Skip to content

Commit aa90884

Browse files
authored
Merge pull request #202 from CITCOM-project/somers/estimates
Estimates now taken from causal test case when executing test
2 parents 9310ebe + 45cb719 commit aa90884

File tree

8 files changed

+28
-36
lines changed

8 files changed

+28
-36
lines changed

causal_testing/json_front/json_class.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,7 @@ def _execute_test_case(
229229
causal_test_engine, estimation_model = self._setup_test(
230230
causal_test_case, test, test["conditions"] if "conditions" in test else None
231231
)
232-
causal_test_result = causal_test_engine.execute_test(
233-
estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
234-
)
232+
causal_test_result = causal_test_engine.execute_test(estimation_model, causal_test_case)
235233

236234
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
237235

causal_testing/testing/causal_test_engine.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
8181

8282
estimators = test_suite[edge]["estimators"]
8383
tests = test_suite[edge]["tests"]
84-
estimate_type = test_suite[edge]["estimate_type"]
8584
results = {}
8685
for estimator_class in estimators:
8786
causal_test_results = []
@@ -96,16 +95,14 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
9695
)
9796
if estimator.df is None:
9897
estimator.df = self.scenario_execution_data_df
99-
causal_test_result = self._return_causal_test_results(estimate_type, estimator, test)
98+
causal_test_result = self._return_causal_test_results(estimator, test)
10099
causal_test_results.append(causal_test_result)
101100

102101
results[estimator_class.__name__] = causal_test_results
103102
test_suite_results[edge] = results
104103
return test_suite_results
105104

106-
def execute_test(
107-
self, estimator: type(Estimator), causal_test_case: CausalTestCase, estimate_type: str = "ate"
108-
) -> CausalTestResult:
105+
def execute_test(self, estimator: type(Estimator), causal_test_case: CausalTestCase) -> CausalTestResult:
109106
"""Execute a causal test case and return the causal test result.
110107
111108
Test case execution proceeds with the following steps:
@@ -120,7 +117,6 @@ def execute_test(
120117
121118
:param estimator: A reference to an Estimator class.
122119
:param causal_test_case: The CausalTestCase object to be tested
123-
:param estimate_type: A string which denotes the type of estimate to return, ATE or CATE.
124120
:return causal_test_result: A CausalTestResult for the executed causal test case.
125121
"""
126122
if self.scenario_execution_data_df.empty:
@@ -142,18 +138,17 @@ def execute_test(
142138
if self._check_positivity_violation(variables_for_positivity):
143139
raise ValueError("POSITIVITY VIOLATION -- Cannot proceed.")
144140

145-
causal_test_result = self._return_causal_test_results(estimate_type, estimator, causal_test_case)
141+
causal_test_result = self._return_causal_test_results(estimator, causal_test_case)
146142
return causal_test_result
147143

148-
def _return_causal_test_results(self, estimate_type, estimator, causal_test_case):
144+
def _return_causal_test_results(self, estimator, causal_test_case):
149145
"""Depending on the estimator used, calculate the 95% confidence intervals and return in a causal_test_result
150146
151-
:param estimate_type: A string which denotes the type of estimate to return
152147
:param estimator: An Estimator class object
153148
:param causal_test_case: The concrete test case to be executed
154149
:return: a CausalTestResult object containing the confidence intervals
155150
"""
156-
if estimate_type == "cate":
151+
if causal_test_case.estimate_type == "cate":
157152
logger.debug("calculating cate")
158153
if not hasattr(estimator, "estimate_cates"):
159154
raise NotImplementedError(f"{estimator.__class__} has no CATE method.")
@@ -165,7 +160,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
165160
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
166161
confidence_intervals=confidence_intervals,
167162
)
168-
elif estimate_type == "risk_ratio":
163+
elif causal_test_case.estimate_type == "risk_ratio":
169164
logger.debug("calculating risk_ratio")
170165
risk_ratio, confidence_intervals = estimator.estimate_risk_ratio()
171166
causal_test_result = CausalTestResult(
@@ -174,7 +169,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
174169
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
175170
confidence_intervals=confidence_intervals,
176171
)
177-
elif estimate_type == "coefficient":
172+
elif causal_test_case.estimate_type == "coefficient":
178173
logger.debug("calculating coefficient")
179174
coefficient, confidence_intervals = estimator.estimate_unit_ate()
180175
causal_test_result = CausalTestResult(
@@ -183,7 +178,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
183178
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
184179
confidence_intervals=confidence_intervals,
185180
)
186-
elif estimate_type == "ate":
181+
elif causal_test_case.estimate_type == "ate":
187182
logger.debug("calculating ate")
188183
ate, confidence_intervals = estimator.estimate_ate()
189184
causal_test_result = CausalTestResult(
@@ -194,7 +189,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
194189
)
195190
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
196191
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
197-
elif estimate_type == "ate_calculated":
192+
elif causal_test_case.estimate_type == "ate_calculated":
198193
logger.debug("calculating ate")
199194
ate, confidence_intervals = estimator.estimate_ate_calculated()
200195
causal_test_result = CausalTestResult(
@@ -206,7 +201,9 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
206201
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
207202
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
208203
else:
209-
raise ValueError(f"Invalid estimate type {estimate_type}, expected 'ate', 'cate', or 'risk_ratio'")
204+
raise ValueError(
205+
f"Invalid estimate type {causal_test_case.estimate_type}, expected 'ate', 'cate', or 'risk_ratio'"
206+
)
210207
return causal_test_result
211208

212209
def _check_positivity_violation(self, variables_list):

docs/source/usage.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,7 @@ various information. Here, we simply assert that the observed result is (on aver
107107
108108
causal_test_result = causal_test_engine.execute_test(
109109
estimator = estimation_model,
110-
causal_test_case = causal_test_case,
111-
estimate_type = "ate")
110+
causal_test_case = causal_test_case)
112111
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
113112
assert test_passes, "Expected to see a positive change in y."
114113

examples/covasim_/doubling_beta/example_beta.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def doubling_beta_CATE_on_csv(
6565
)
6666

6767
# Add squared terms for beta, since it has a quadratic relationship with cumulative infections
68-
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, "ate")
68+
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case)
6969

7070
# Repeat for association estimate (no adjustment)
7171
no_adjustment_linear_regression_estimator = LinearRegressionEstimator(
@@ -78,7 +78,7 @@ def doubling_beta_CATE_on_csv(
7878
formula="cum_infections ~ beta + np.power(beta, 2)",
7979
)
8080
association_test_result = causal_test_engine.execute_test(
81-
no_adjustment_linear_regression_estimator, causal_test_case, "ate"
81+
no_adjustment_linear_regression_estimator, causal_test_case
8282
)
8383

8484
# Store results for plotting
@@ -110,7 +110,7 @@ def doubling_beta_CATE_on_csv(
110110
formula="cum_infections ~ beta + np.power(beta, 2) + avg_age + contacts",
111111
)
112112
counterfactual_causal_test_result = causal_test_engine.execute_test(
113-
linear_regression_estimator, causal_test_case, "ate"
113+
linear_regression_estimator, causal_test_case
114114
)
115115
results_dict["counterfactual"] = {
116116
"ate": counterfactual_causal_test_result.test_value.value,

examples/covasim_/vaccinating_elderly/example_vaccine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_experimental_vaccinate_elderly(runs_per_test_per_config: int = 30, verb
9999
)
100100

101101
# 10. Execute test and save results in dict
102-
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, "ate")
102+
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case)
103103
if verbose:
104104
logging.info("Causation:\n%s", causal_test_result)
105105
results_dict[outcome_variable.name]["ate"] = causal_test_result.test_value.value

examples/lr91/example_max_conductances.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
148148
)
149149

150150
# 10. Run the causal test and print results
151-
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, "ate")
151+
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case)
152152
logger.info("%s", causal_test_result)
153153
return causal_test_result.test_value.value, causal_test_result.confidence_intervals
154154

examples/poisson-line-process/example_poisson_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def causal_test_intensity_num_shapes(
123123
)
124124

125125
# 10. Execute the test
126-
causal_test_result = causal_test_engine.execute_test(estimator, causal_test_case, causal_test_case.estimate_type)
126+
causal_test_result = causal_test_engine.execute_test(estimator, causal_test_case)
127127

128128
return causal_test_result
129129

tests/testing_tests/test_causal_test_engine.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,8 @@ def test_execute_test_observational_linear_regression_estimator_coefficient(self
189189
"A",
190190
self.causal_test_engine.scenario_execution_data_df,
191191
)
192-
causal_test_result = self.causal_test_engine.execute_test(
193-
estimation_model, self.causal_test_case, estimate_type="coefficient"
194-
)
192+
self.causal_test_case.estimate_type = "coefficient"
193+
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
195194
self.assertEqual(int(causal_test_result.test_value.value), 0)
196195

197196
def test_execute_test_observational_linear_regression_estimator_risk_ratio(self):
@@ -205,9 +204,8 @@ def test_execute_test_observational_linear_regression_estimator_risk_ratio(self)
205204
"A",
206205
self.causal_test_engine.scenario_execution_data_df,
207206
)
208-
causal_test_result = self.causal_test_engine.execute_test(
209-
estimation_model, self.causal_test_case, estimate_type="risk_ratio"
210-
)
207+
self.causal_test_case.estimate_type = "risk_ratio"
208+
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
211209
self.assertEqual(int(causal_test_result.test_value.value), 0)
212210

213211
def test_invalid_estimate_type(self):
@@ -221,8 +219,9 @@ def test_invalid_estimate_type(self):
221219
"A",
222220
self.causal_test_engine.scenario_execution_data_df,
223221
)
222+
self.causal_test_case.estimate_type = "invalid"
224223
with self.assertRaises(ValueError):
225-
self.causal_test_engine.execute_test(estimation_model, self.causal_test_case, estimate_type="invalid")
224+
self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
226225

227226
def test_execute_test_observational_linear_regression_estimator_squared_term(self):
228227
"""Check that executing the causal test case returns the correct results for dummy data with a squared term
@@ -258,9 +257,8 @@ def test_execute_observational_causal_forest_estimator_cates(self):
258257
self.causal_test_engine.scenario_execution_data_df,
259258
effect_modifiers={"M": None},
260259
)
261-
causal_test_result = self.causal_test_engine.execute_test(
262-
estimation_model, self.causal_test_case, estimate_type="cate"
263-
)
260+
self.causal_test_case.estimate_type = "cate"
261+
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
264262
causal_test_result = causal_test_result.test_value.value
265263
# Check that each effect modifier's strata has a greater ATE than the last (ascending order)
266264
causal_test_result_m1 = causal_test_result.loc[causal_test_result["M"] == 1]

0 commit comments

Comments
 (0)