Skip to content

Commit 2b7042d

Browse files
authored
Merge pull request #286 from CITCOM-project/temporal-shenanigans
Enabling causal effect estimation in the presence of time-varying confounding using IPCW.
2 parents b6ce637 + e6284c5 commit 2b7042d

File tree

15 files changed

+593
-27
lines changed

15 files changed

+593
-27
lines changed

causal_testing/json_front/json_class.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def _execute_test_case(
276276
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
277277

278278
if "coverage" in test and test["coverage"]:
279-
adequacy_metric = DataAdequacy(causal_test_case, estimation_model, self.data_collector)
279+
adequacy_metric = DataAdequacy(causal_test_case, estimation_model)
280280
adequacy_metric.measure_adequacy()
281281
causal_test_result.adequacy = adequacy_metric
282282

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""
2+
This module contains the Capability and TreatmentSequence classes to implement
3+
treatment sequences that operate over time.
4+
"""
5+
from typing import Any
6+
from causal_testing.specification.variable import Variable
7+
8+
9+
class Capability:
    """
    Data class to encapsulate temporal interventions.

    Stores a variable, the value associated with it, and the start and end
    times of the intervention.

    Instances compare equal when all four fields are equal. Because the class
    defines ``__eq__`` without ``__hash__``, instances are unhashable and
    cannot be used as dictionary keys or set members.
    """

    def __init__(self, variable: "Variable", value: Any, start_time: int, end_time: int):
        """
        :param variable: The variable the intervention applies to.
        :param value: The value associated with the variable.
        :param start_time: The time at which the intervention starts.
        :param end_time: The time at which the intervention ends.
        """
        self.variable = variable
        self.value = value
        self.start_time = start_time
        self.end_time = end_time

    def __eq__(self, other):
        """
        Compare two capabilities field by field.

        :param other: The object to compare against.
        :return: True or False when ``other`` is an instance of this object's
                 class; ``NotImplemented`` otherwise, so that Python can try
                 the reflected comparison before falling back to identity.
        """
        if not isinstance(other, type(self)):
            return NotImplemented
        return (
            self.variable == other.variable
            and self.value == other.value
            and self.start_time == other.start_time
            and self.end_time == other.end_time
        )

    def __repr__(self):
        """Return the capability as ``(variable, value, start_time-end_time)``."""
        return f"({self.variable}, {self.value}, {self.start_time}-{self.end_time})"
31+
32+
33+
class TreatmentSequence:
    """
    Class to represent a list of capabilities, i.e. a treatment regime.

    Each ``(variable, value)`` pair is turned into a ``Capability`` lasting
    ``timesteps_per_intervention`` time units. The first capability starts at
    ``timesteps_per_intervention``, and each subsequent one starts where the
    previous one ends.
    """

    def __init__(self, timesteps_per_intervention, capabilities):
        """
        :param timesteps_per_intervention: The duration of each intervention.
                                           Must be strictly positive.
        :param capabilities: An iterable of ``(variable, value)`` pairs, in the
                             order in which they are applied.
        :raises ValueError: If ``timesteps_per_intervention`` is not positive.
        """
        # A negative step previously produced fewer time points than
        # capabilities, silently dropping interventions. Fail loudly instead.
        if timesteps_per_intervention <= 0:
            raise ValueError(
                f"timesteps_per_intervention must be positive, got {timesteps_per_intervention}"
            )
        self.timesteps_per_intervention = timesteps_per_intervention
        # The i-th capability (counting from 1) starts at i * timesteps_per_intervention.
        self.capabilities = [
            Capability(
                var,
                val,
                position * timesteps_per_intervention,
                (position + 1) * timesteps_per_intervention,
            )
            for position, (var, val) in enumerate(capabilities, start=1)
        ]
        # This is a bodge so that causal test adequacy works
        self.name = tuple(c.variable for c in self.capabilities)

    def set_value(self, index: int, value: Any):
        """
        Set the value of capability at the given index.

        :param index: The index of the element to update.
        :param value: The desired value of the capability.
        :raises IndexError: If ``index`` is outside the capability list.
        """
        self.capabilities[index].value = value

    def copy(self):
        """
        Return a new TreatmentSequence built from this sequence's capabilities.

        The new sequence holds fresh ``Capability`` objects, so calling
        ``set_value`` on the copy does not affect the original. The variable
        and value objects themselves are shared, not duplicated.

        :return: A new TreatmentSequence with the same timesteps per
                 intervention and the same (variable, value) pairs.
        """
        return TreatmentSequence(
            self.timesteps_per_intervention,
            [(c.variable, c.value) for c in self.capabilities],
        )

    def total_time(self):
        """
        Calculate the total duration of the treatment strategy.

        :return: ``(number of capabilities + 1) * timesteps_per_intervention``,
                 which equals the end time of the final capability when the
                 sequence is not empty.
        """
        return (len(self.capabilities) + 1) * self.timesteps_per_intervention

causal_testing/testing/causal_test_adequacy.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,20 @@
22
This module contains code to measure various aspects of causal test adequacy.
33
"""
44

5+
import logging
56
from itertools import combinations
67
from copy import deepcopy
78
import pandas as pd
9+
from numpy.linalg import LinAlgError
10+
from lifelines.exceptions import ConvergenceError
811

912
from causal_testing.testing.causal_test_suite import CausalTestSuite
10-
from causal_testing.data_collection.data_collector import DataCollector
1113
from causal_testing.specification.causal_dag import CausalDAG
1214
from causal_testing.testing.estimators import Estimator
1315
from causal_testing.testing.causal_test_case import CausalTestCase
1416

17+
logger = logging.getLogger(__name__)
18+
1519

1620
class DAGAdequacy:
1721
"""
@@ -70,15 +74,21 @@ class DataAdequacy:
7074
- Zero kurtosis is optimal.
7175
"""
7276

77+
# pylint: disable=too-many-instance-attributes
7378
def __init__(
74-
self, test_case: CausalTestCase, estimator: Estimator, data_collector: DataCollector, bootstrap_size: int = 100
79+
self,
80+
test_case: CausalTestCase,
81+
estimator: Estimator,
82+
bootstrap_size: int = 100,
83+
group_by=None,
7584
):
7685
self.test_case = test_case
7786
self.estimator = estimator
78-
self.data_collector = data_collector
7987
self.kurtosis = None
8088
self.outcomes = None
89+
self.successful = None
8190
self.bootstrap_size = bootstrap_size
91+
self.group_by = group_by
8292

8393
def measure_adequacy(self):
8494
"""
@@ -87,11 +97,24 @@ def measure_adequacy(self):
8797
results = []
8898
for i in range(self.bootstrap_size):
8999
estimator = deepcopy(self.estimator)
90-
estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
91-
# try:
92-
results.append(self.test_case.execute_test(estimator, self.data_collector))
93-
# except np.LinAlgError:
94-
# continue
100+
101+
if self.group_by is not None:
102+
ids = pd.Series(estimator.df[self.group_by].unique())
103+
ids = ids.sample(len(ids), replace=True, random_state=i)
104+
estimator.df = estimator.df[estimator.df[self.group_by].isin(ids)]
105+
else:
106+
estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
107+
try:
108+
results.append(self.test_case.execute_test(estimator, None))
109+
except LinAlgError:
110+
logger.warning("Adequacy LinAlgError")
111+
continue
112+
except ConvergenceError:
113+
logger.warning("Adequacy ConvergenceError")
114+
continue
115+
except ValueError as e:
116+
logger.warning(f"Adequacy ValueError: {e}")
117+
continue
95118
outcomes = [self.test_case.expected_causal_effect.apply(c) for c in results]
96119
results = pd.DataFrame(c.to_dict() for c in results)[["effect_estimate", "ci_low", "ci_high"]]
97120

@@ -111,8 +134,14 @@ def convert_to_df(field):
111134

112135
effect_estimate = pd.concat(results["effect_estimate"].tolist(), axis=1).transpose().reset_index(drop=True)
113136
self.kurtosis = effect_estimate.kurtosis()
114-
self.outcomes = sum(outcomes)
137+
self.outcomes = sum(filter(lambda x: x is not None, outcomes))
138+
self.successful = sum(x is not None for x in outcomes)
115139

116140
def to_dict(self):
117141
"Returns the adequacy object as a dictionary."
118-
return {"kurtosis": self.kurtosis.to_dict(), "bootstrap_size": self.bootstrap_size, "passing": self.outcomes}
142+
return {
143+
"kurtosis": self.kurtosis.to_dict(),
144+
"bootstrap_size": self.bootstrap_size,
145+
"passing": self.outcomes,
146+
"successful": self.successful,
147+
}

causal_testing/testing/causal_test_case.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _return_causal_test_results(self, estimator) -> CausalTestResult:
9292
except np.linalg.LinAlgError:
9393
return CausalTestResult(
9494
estimator=estimator,
95-
test_value=TestValue(self.estimate_type, "LinAlgError"),
95+
test_value=TestValue(self.estimate_type, None),
9696
effect_modifier_configuration=self.effect_modifier_configuration,
9797
confidence_intervals=None,
9898
)

causal_testing/testing/causal_test_outcome.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ class SomeEffect(CausalTestOutcome):
2727
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
2828

2929
def apply(self, res: CausalTestResult) -> bool:
30-
if res.test_value.type == "risk_ratio":
30+
if res.ci_low() is None or res.ci_high() is None:
31+
return None
32+
if res.test_value.type in ("risk_ratio", "hazard_ratio"):
3133
return any(
3234
1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low, ci_high in zip(res.ci_low(), res.ci_high())
3335
)
@@ -52,7 +54,7 @@ def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
5254
self.ctol = ctol
5355

5456
def apply(self, res: CausalTestResult) -> bool:
55-
if res.test_value.type == "risk_ratio":
57+
if res.test_value.type in ("risk_ratio", "hazard_ratio"):
5658
return any(
5759
ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol)
5860
for ci_low, ci_high, value in zip(res.ci_low(), res.ci_high(), res.test_value.value)

0 commit comments

Comments
 (0)