Skip to content

Commit 4951029

Browse files
committed
Temporal estimation in
1 parent 4a86d9c commit 4951029

File tree

6 files changed

+341
-12
lines changed

6 files changed

+341
-12
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from causal_testing.specification.variable import Variable
2+
3+
4+
class Capability:
5+
"""
6+
Data class to encapsulate temporal interventions.
7+
"""
8+
9+
def __init__(self, variable: Variable, value: any, start_time: int, end_time: float):
10+
self.variable = variable
11+
self.value = value
12+
self.start_time = start_time
13+
self.end_time = end_time
14+
15+
def __eq__(self, other):
16+
return (
17+
type(other) == type(self)
18+
and self.variable == other.variable
19+
and self.value == other.value
20+
and self.start_time == other.start_time
21+
and self.end_time == other.end_time
22+
)
23+
24+
def __repr__(self):
25+
return f"({self.variable}, {self.value}, {self.start_time}-{self.end_time})"
26+
27+
28+
class TreatmentSequence:
29+
"""
30+
Class to represent a list of capabilities, i.e. a treatment regime.
31+
"""
32+
33+
def __init__(self, timesteps_per_intervention, capabilities):
34+
self.timesteps_per_intervention = timesteps_per_intervention
35+
self.capabilities = [
36+
Capability(var, val, t, t + timesteps_per_intervention)
37+
for (var, val), t in zip(
38+
capabilities,
39+
range(
40+
timesteps_per_intervention,
41+
(len(capabilities) * timesteps_per_intervention) + 1,
42+
timesteps_per_intervention,
43+
),
44+
)
45+
]
46+
# This is a bodge so that causal test adequacy works
47+
self.name = tuple([c.variable for c in self.capabilities])
48+
49+
def set_value(self, index: int, value: float):
50+
"""
51+
Set the value of capability at the given index.
52+
:param index - the index of the element to update.
53+
:param value - the desired value of the capability.
54+
"""
55+
self.capabilities[index].value = value
56+
57+
def copy(self):
58+
"""
59+
Return a deep copy of the capability list.
60+
"""
61+
strategy = TreatmentSequence(
62+
self.timesteps_per_intervention,
63+
[(c.variable, c.value) for c in self.capabilities],
64+
)
65+
return strategy
66+
67+
def total_time(self):
68+
"""
69+
Calculate the total duration of the treatment strategy.
70+
"""
71+
return (len(self.capabilities) + 1) * self.timesteps_per_intervention

causal_testing/testing/causal_test_adequacy.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,21 @@ class DataAdequacy:
7171
"""
7272

7373
def __init__(
74-
self, test_case: CausalTestCase, estimator: Estimator, data_collector: DataCollector, bootstrap_size: int = 100
74+
self,
75+
test_case: CausalTestCase,
76+
estimator: Estimator,
77+
data_collector: DataCollector,
78+
bootstrap_size: int = 100,
79+
group_by=None,
7580
):
7681
self.test_case = test_case
7782
self.estimator = estimator
7883
self.data_collector = data_collector
7984
self.kurtosis = None
8085
self.outcomes = None
86+
self.successful = None
8187
self.bootstrap_size = bootstrap_size
88+
self.group_by = group_by
8289

8390
def measure_adequacy(self):
8491
"""
@@ -87,11 +94,14 @@ def measure_adequacy(self):
8794
results = []
8895
for i in range(self.bootstrap_size):
8996
estimator = deepcopy(self.estimator)
90-
estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
91-
# try:
97+
98+
if self.group_by is not None:
99+
ids = pd.Series(estimator.df[self.group_by].unique())
100+
ids = ids.sample(len(ids), replace=True, random_state=i)
101+
estimator.df = estimator.df[estimator.df[self.group_by].isin(ids)]
102+
else:
103+
estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
92104
results.append(self.test_case.execute_test(estimator, self.data_collector))
93-
# except np.LinAlgError:
94-
# continue
95105
outcomes = [self.test_case.expected_causal_effect.apply(c) for c in results]
96106
results = pd.DataFrame(c.to_dict() for c in results)[["effect_estimate", "ci_low", "ci_high"]]
97107

@@ -111,8 +121,14 @@ def convert_to_df(field):
111121

112122
effect_estimate = pd.concat(results["effect_estimate"].tolist(), axis=1).transpose().reset_index(drop=True)
113123
self.kurtosis = effect_estimate.kurtosis()
114-
self.outcomes = sum(outcomes)
124+
self.outcomes = sum(filter(lambda x: x is not None, outcomes))
125+
self.successful = sum([x is not None for x in outcomes])
115126

116127
def to_dict(self):
117128
"Returns the adequacy object as a dictionary."
118-
return {"kurtosis": self.kurtosis.to_dict(), "bootstrap_size": self.bootstrap_size, "passing": self.outcomes}
129+
return {
130+
"kurtosis": self.kurtosis.to_dict(),
131+
"bootstrap_size": self.bootstrap_size,
132+
"passing": self.outcomes,
133+
"successful": self.successful,
134+
}

causal_testing/testing/causal_test_case.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _return_causal_test_results(self, estimator) -> CausalTestResult:
9292
except np.linalg.LinAlgError:
9393
return CausalTestResult(
9494
estimator=estimator,
95-
test_value=TestValue(self.estimate_type, "LinAlgError"),
95+
test_value=TestValue(self.estimate_type, None),
9696
effect_modifier_configuration=self.effect_modifier_configuration,
9797
confidence_intervals=None,
9898
)

causal_testing/testing/causal_test_outcome.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ class SomeEffect(CausalTestOutcome):
2727
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
2828

2929
def apply(self, res: CausalTestResult) -> bool:
30-
if res.test_value.type == "risk_ratio":
30+
if res.ci_low() is None or res.ci_high() is None:
31+
return None
32+
if res.test_value.type in ("risk_ratio", "hazard_ratio"):
3133
return any(
3234
1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low, ci_high in zip(res.ci_low(), res.ci_high())
3335
)
@@ -52,7 +54,7 @@ def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
5254
self.ctol = ctol
5355

5456
def apply(self, res: CausalTestResult) -> bool:
55-
if res.test_value.type == "risk_ratio":
57+
if res.test_value.type in ("risk_ratio", "hazard_ratio"):
5658
return any(
5759
ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol)
5860
for ci_low, ci_high, value in zip(res.ci_low(), res.ci_high(), res.test_value.value)

0 commit comments

Comments
 (0)