Skip to content

Commit adecf3d

Browse files
authored
Merge pull request #119 from CITCOM-project/somers/test_value
causal_test_result + new dataclass for test values
2 parents 6e9a369 + d86b8bc commit adecf3d

File tree

11 files changed

+158
-132
lines changed

11 files changed

+158
-132
lines changed

causal_testing/json_front/json_class.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,10 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
177177
result_string = str()
178178
if causal_test_result.ci_low() and causal_test_result.ci_high():
179179
result_string = (
180-
f"{causal_test_result.ci_low()} < {causal_test_result.ate} < {causal_test_result.ci_high()}"
180+
f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < {causal_test_result.ci_high()}"
181181
)
182182
else:
183-
result_string = causal_test_result.ate
183+
result_string = causal_test_result.test_value.value
184184
if f_flag:
185185
assert test_passes, (
186186
f"{causal_test_case}\n FAILED - expected {causal_test_case.expected_causal_effect}, "
@@ -191,7 +191,7 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
191191
logger.warning(
192192
" FAILED- expected %s, got %s",
193193
causal_test_case.expected_causal_effect,
194-
causal_test_result.ate,
194+
causal_test_result.test_value.value,
195195
)
196196
return failed
197197

causal_testing/testing/causal_test_engine.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from causal_testing.data_collection.data_collector import DataCollector
66
from causal_testing.specification.causal_specification import CausalSpecification
77
from causal_testing.testing.causal_test_case import CausalTestCase
8-
from causal_testing.testing.causal_test_outcome import CausalTestResult
8+
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
99
from causal_testing.testing.estimators import Estimator
1010

1111
logger = logging.getLogger(__name__)
@@ -138,7 +138,7 @@ def execute_test(
138138
treatment_value=estimator.treatment_values,
139139
control_value=estimator.control_values,
140140
adjustment_set=estimator.adjustment_set,
141-
ate=cates_df,
141+
test_value=TestValue("ate", cates_df),
142142
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
143143
confidence_intervals=confidence_intervals,
144144
)
@@ -151,7 +151,7 @@ def execute_test(
151151
treatment_value=estimator.treatment_values,
152152
control_value=estimator.control_values,
153153
adjustment_set=estimator.adjustment_set,
154-
ate=risk_ratio,
154+
test_value=TestValue("risk_ratio", risk_ratio),
155155
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
156156
confidence_intervals=confidence_intervals,
157157
)
@@ -164,7 +164,7 @@ def execute_test(
164164
treatment_value=estimator.treatment_values,
165165
control_value=estimator.control_values,
166166
adjustment_set=estimator.adjustment_set,
167-
ate=ate,
167+
test_value=TestValue("ate", ate),
168168
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
169169
confidence_intervals=confidence_intervals,
170170
)
@@ -179,7 +179,7 @@ def execute_test(
179179
treatment_value=estimator.treatment_values,
180180
control_value=estimator.control_values,
181181
adjustment_set=estimator.adjustment_set,
182-
ate=ate,
182+
test_value=TestValue("ate", ate),
183183
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
184184
confidence_intervals=confidence_intervals,
185185
)
Lines changed: 18 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,7 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Union
3-
2+
from causal_testing.testing.causal_test_result import CausalTestResult
43
import numpy as np
54

6-
from causal_testing.specification.variable import Variable
7-
8-
9-
class CausalTestResult:
10-
"""A container to hold the results of a causal test case. Every causal test case provides a point estimate of
11-
the ATE, given a particular treatment, outcome, and adjustment set. Some but not all estimators can provide
12-
confidence intervals."""
13-
14-
def __init__(
15-
self,
16-
treatment: tuple,
17-
outcome: tuple,
18-
treatment_value: Union[int, float, str],
19-
control_value: Union[int, float, str],
20-
adjustment_set: set,
21-
ate: float,
22-
confidence_intervals: [float, float] = None,
23-
effect_modifier_configuration: {Variable: Any} = None,
24-
):
25-
self.treatment = treatment
26-
self.outcome = outcome
27-
self.treatment_value = treatment_value
28-
self.control_value = control_value
29-
if adjustment_set:
30-
self.adjustment_set = adjustment_set
31-
else:
32-
self.adjustment_set = set()
33-
self.ate = ate
34-
self.confidence_intervals = confidence_intervals
35-
36-
if effect_modifier_configuration is not None:
37-
self.effect_modifier_configuration = effect_modifier_configuration
38-
else:
39-
self.effect_modifier_configuration = {}
40-
41-
def __str__(self):
42-
base_str = (
43-
f"Causal Test Result\n==============\n"
44-
f"Treatment: {self.treatment[0]}\n"
45-
f"Control value: {self.control_value}\n"
46-
f"Treatment value: {self.treatment_value}\n"
47-
f"Outcome: {self.outcome[0]}\n"
48-
f"Adjustment set: {self.adjustment_set}\n"
49-
f"ATE: {self.ate}\n"
50-
)
51-
confidence_str = ""
52-
if self.confidence_intervals:
53-
confidence_str += f"Confidence intervals: {self.confidence_intervals}\n"
54-
return base_str + confidence_str
55-
56-
def to_dict(self):
57-
base_dict = {
58-
"treatment": self.treatment[0],
59-
"control_value": self.control_value,
60-
"treatment_value": self.treatment_value,
61-
"outcome": self.outcome[0],
62-
"adjustment_set": self.adjustment_set,
63-
"ate": self.ate,
64-
}
65-
if self.confidence_intervals:
66-
base_dict["ci_low"] = min(self.confidence_intervals)
67-
base_dict["ci_high"] = max(self.confidence_intervals)
68-
return base_dict
69-
70-
def ci_low(self):
71-
"""Return the lower bracket of the confidence intervals."""
72-
if not self.confidence_intervals:
73-
return None
74-
return min(self.confidence_intervals)
75-
76-
def ci_high(self):
77-
"""Return the higher bracket of the confidence intervals."""
78-
if not self.confidence_intervals:
79-
return None
80-
return max(self.confidence_intervals)
81-
82-
def summary(self):
83-
"""Summarise the causal test result as an intuitive sentence."""
84-
print(
85-
f"The causal effect of changing {self.treatment[0]} = {self.control_value} to "
86-
f"{self.treatment[0]}' = {self.treatment_value} is {self.ate} (95% confidence intervals: "
87-
f"{self.confidence_intervals})."
88-
)
89-
905

916
class CausalTestOutcome(ABC):
927
"""An abstract class representing an expected causal effect."""
@@ -110,7 +25,7 @@ def __init__(self, value: float, tolerance: float = None):
11025
self.tolerance = tolerance
11126

11227
def apply(self, res: CausalTestResult) -> bool:
113-
return np.isclose(res.ate, self.value, atol=self.tolerance)
28+
return np.isclose(res.test_value.value, self.value, atol=self.tolerance)
11429

11530
def __str__(self):
11631
return f"ExactValue: {self.value}±{self.tolerance}"
@@ -121,22 +36,31 @@ class Positive(CausalTestOutcome):
12136

12237
def apply(self, res: CausalTestResult) -> bool:
12338
# TODO: confidence intervals?
124-
return res.ate > 0
39+
if res.test_value.type == "ate":
40+
return res.test_value.value > 0
41+
elif res.test_value.type == "risk_ratio":
42+
return res.test_value.value > 1
12543

12644

12745
class Negative(CausalTestOutcome):
12846
"""An extension of TestOutcome representing that the expected causal effect should be negative."""
12947

13048
def apply(self, res: CausalTestResult) -> bool:
13149
# TODO: confidence intervals?
132-
return res.ate < 0
50+
if res.test_value.type == "ate":
51+
return res.test_value.value < 0
52+
elif res.test_value.type == "risk_ratio":
53+
return res.test_value.value < 1
13354

13455

13556
class SomeEffect(CausalTestOutcome):
13657
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
13758

13859
def apply(self, res: CausalTestResult) -> bool:
139-
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
60+
if res.test_value.type == "ate":
61+
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
62+
elif res.test_value.type == "risk_ratio":
63+
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
14064

14165
def __str__(self):
14266
return "Changed"
@@ -146,7 +70,10 @@ class NoEffect(CausalTestOutcome):
14670
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
14771

14872
def apply(self, res: CausalTestResult) -> bool:
149-
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.ate) < 1e-10)
73+
if res.test_value.type == "ate":
74+
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
75+
elif res.test_value.type == "risk_ratio":
76+
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
15077

15178
def __str__(self):
15279
return "Unchanged"
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from typing import Any, Union
2+
from dataclasses import dataclass
3+
4+
from causal_testing.specification.variable import Variable
5+
6+
7+
@dataclass
8+
class TestValue:
9+
"""A dataclass to hold both the type and value of a causal test result"""
10+
11+
type: str
12+
value: float
13+
14+
15+
class CausalTestResult:
16+
"""A container to hold the results of a causal test case. Every causal test case provides a point estimate of
17+
the ATE, given a particular treatment, outcome, and adjustment set. Some but not all estimators can provide
18+
confidence intervals."""
19+
20+
def __init__(
21+
self,
22+
treatment: tuple,
23+
outcome: tuple,
24+
treatment_value: Union[int, float, str],
25+
control_value: Union[int, float, str],
26+
adjustment_set: set,
27+
test_value: TestValue,
28+
confidence_intervals: [float, float] = None,
29+
effect_modifier_configuration: {Variable: Any} = None,
30+
):
31+
self.treatment = treatment
32+
self.outcome = outcome
33+
self.treatment_value = treatment_value
34+
self.control_value = control_value
35+
if adjustment_set:
36+
self.adjustment_set = adjustment_set
37+
else:
38+
self.adjustment_set = set()
39+
self.test_value = test_value
40+
self.confidence_intervals = confidence_intervals
41+
42+
if effect_modifier_configuration is not None:
43+
self.effect_modifier_configuration = effect_modifier_configuration
44+
else:
45+
self.effect_modifier_configuration = {}
46+
47+
def __str__(self):
48+
base_str = (
49+
f"Causal Test Result\n==============\n"
50+
f"Treatment: {self.treatment[0]}\n"
51+
f"Control value: {self.control_value}\n"
52+
f"Treatment value: {self.treatment_value}\n"
53+
f"Outcome: {self.outcome[0]}\n"
54+
f"Adjustment set: {self.adjustment_set}\n"
55+
f"{self.test_value.type}: {self.test_value.value}\n"
56+
)
57+
confidence_str = ""
58+
if self.confidence_intervals:
59+
confidence_str += f"Confidence intervals: {self.confidence_intervals}\n"
60+
return base_str + confidence_str
61+
62+
def to_dict(self):
63+
base_dict = {
64+
"treatment": self.treatment[0],
65+
"control_value": self.control_value,
66+
"treatment_value": self.treatment_value,
67+
"outcome": self.outcome[0],
68+
"adjustment_set": self.adjustment_set,
69+
"test_value": self.test_value,
70+
}
71+
if self.confidence_intervals:
72+
base_dict["ci_low"] = min(self.confidence_intervals)
73+
base_dict["ci_high"] = max(self.confidence_intervals)
74+
return base_dict
75+
76+
def ci_low(self):
77+
"""Return the lower bracket of the confidence intervals."""
78+
if not self.confidence_intervals:
79+
return None
80+
return min(self.confidence_intervals)
81+
82+
def ci_high(self):
83+
"""Return the higher bracket of the confidence intervals."""
84+
if not self.confidence_intervals:
85+
return None
86+
return max(self.confidence_intervals)
87+
88+
def summary(self):
89+
"""Summarise the causal test result as an intuitive sentence."""
90+
print(
91+
f"The causal effect of changing {self.treatment[0]} = {self.control_value} to "
92+
f"{self.treatment[0]}' = {self.treatment_value} is {self.test_value.value} (95% confidence intervals: "
93+
f"{self.confidence_intervals})."
94+
)

examples/covasim_/doubling_beta/causal_test_beta.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ def doubling_beta_CATE_on_csv(observational_data_path: str, simulate_counterfact
6464
association_test_result = causal_test_engine.execute_test(no_adjustment_linear_regression_estimator, causal_test_case, 'ate')
6565

6666
# Store results for plotting
67-
results_dict['association'] = {'ate': association_test_result.ate,
67+
results_dict['association'] = {'ate': association_test_result.test_value.value,
6868
'cis': association_test_result.confidence_intervals,
6969
'df': past_execution_df}
70-
results_dict['causation'] = {'ate': causal_test_result.ate,
70+
results_dict['causation'] = {'ate': causal_test_result.test_value.value,
7171
'cis': causal_test_result.confidence_intervals,
7272
'df': past_execution_df}
7373

@@ -84,7 +84,7 @@ def doubling_beta_CATE_on_csv(observational_data_path: str, simulate_counterfact
8484
df=counterfactual_past_execution_df)
8585
counterfactual_linear_regression_estimator.add_squared_term_to_df('beta')
8686
counterfactual_causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, 'ate')
87-
results_dict['counterfactual'] = {'ate': counterfactual_causal_test_result.ate,
87+
results_dict['counterfactual'] = {'ate': counterfactual_causal_test_result.test_value.value,
8888
'cis': counterfactual_causal_test_result.confidence_intervals,
8989
'df': counterfactual_past_execution_df}
9090
if verbose:

examples/covasim_/vaccinating_elderly/causal_test_vaccine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def experimental_causal_test_vaccinate_elderly(runs_per_test_per_config: int = 3
9191
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, 'ate')
9292
if verbose:
9393
print(f"Causation:\n{causal_test_result}")
94-
results_dict[outcome_variable.name]['ate'] = causal_test_result.ate
94+
results_dict[outcome_variable.name]['ate'] = causal_test_result.test_value.value
9595
results_dict[outcome_variable.name]['cis'] = causal_test_result.confidence_intervals
9696
results_dict[outcome_variable.name]['test_passes'] = causal_test_case.expected_causal_effect.apply(
9797
causal_test_result)

examples/lr91/causal_test_max_conductances.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
132132
# 10. Run the causal test and print results
133133
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, 'ate')
134134
print(causal_test_result)
135-
return causal_test_result.ate, causal_test_result.confidence_intervals
135+
return causal_test_result.test_value.value, causal_test_result.confidence_intervals
136136

137137

138138
def plot_ates_with_cis(results_dict: dict, xs: list, save: bool = True):

examples/poisson-line-process/causal_test_poisson.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ def test_intensity_num_shapes(
182182
"height": wh,
183183
"control": control_value,
184184
"treatment": treatment_value,
185-
"smt_risk_ratio": smt_causal_test_result.ate,
186-
"obs_risk_ratio": obs_causal_test_result.ate,
185+
"smt_risk_ratio": smt_causal_test_result.test_value.value,
186+
"obs_risk_ratio": obs_causal_test_result.test_value.value,
187187
}
188188
intensity_num_shapes_results.append(results)
189189

@@ -218,7 +218,7 @@ def test_intensity_num_shapes(
218218
"control": control_value,
219219
"treatment": treatment_value,
220220
"intensity": i,
221-
"ate": causal_test_result.ate,
221+
"ate": causal_test_result.test_value.value,
222222
"ci_low": min(causal_test_result.confidence_intervals),
223223
"ci_high": max(causal_test_result.confidence_intervals),
224224
}

examples/poisson/run_causal_tests.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import scipy
44

55
from causal_testing.testing.estimators import LinearRegressionEstimator, CausalForestEstimator
6-
from causal_testing.testing.causal_test_outcome import ExactValue, Positive, Negative, NoEffect, CausalTestOutcome, \
7-
CausalTestResult
6+
from causal_testing.testing.causal_test_outcome import ExactValue, Positive, Negative, NoEffect, CausalTestOutcome
7+
from causal_testing.testing.causal_test_result import CausalTestResult
88
from causal_testing.json_front.json_class import JsonUtility
99
from causal_testing.testing.estimators import Estimator
1010
from causal_testing.specification.scenario import Scenario
@@ -70,8 +70,8 @@ def apply(self, res: CausalTestResult) -> bool:
7070
i = effect_modifier_configuration['intensity']
7171
self.i2c = i * 2 * c
7272
print("2ic:", f"2*{i}*{c}={self.i2c}")
73-
print("ate:", res.ate)
74-
return np.isclose(res.ate, self.i2c, atol=self.tolerance)
73+
print("ate:", res.test_value.value)
74+
return np.isclose(res.test_value.value, self.i2c, atol=self.tolerance)
7575

7676
def __str__(self):
7777
if self.i2c is None:

0 commit comments

Comments
 (0)