Skip to content

Commit 53c6cbe

Browse files
Remove causal specification from execute_test
1 parent 01589f6 commit 53c6cbe

File tree

6 files changed

+18
-23
lines changed

6 files changed

+18
-23
lines changed

causal_testing/json_front/json_class.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,7 @@ def _execute_test_case(
265265

266266
estimation_model = self._setup_test(causal_test_case=causal_test_case, test=test)
267267
causal_test_result = causal_test_case.execute_test(estimator=estimation_model,
268-
data_collector=self.data_collector,
269-
causal_specification=self.causal_specification)
268+
data_collector=self.data_collector)
270269

271270
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
272271

@@ -310,7 +309,7 @@ def _setup_test(
310309
"control_value": causal_test_case.control_value,
311310
"adjustment_set": minimal_adjustment_set,
312311
"outcome": causal_test_case.outcome_variable.name,
313-
"df": self.data_collector.data,
312+
"df": self.data_collector.collect_data(),
314313
"effect_modifiers": causal_test_case.effect_modifier_configuration,
315314
"alpha": test["alpha"] if "alpha" in test else 0.05,
316315
}

causal_testing/testing/causal_test_suite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def execute_test_suite(self, data_collector: ObservationalDataCollector,
7777
test.outcome_variable.name,
7878
)
7979
if estimator.df is None:
80-
estimator.df = data_collector.data
80+
estimator.df = data_collector.collect_data()
8181
causal_test_result = test._return_causal_test_results(estimator)
8282
causal_test_results.append(causal_test_result)
8383

examples/covasim_/doubling_beta/example_beta.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,7 @@ def doubling_beta_CATE_on_csv(
6565

6666
# Add squared terms for beta, since it has a quadratic relationship with cumulative infections
6767
causal_test_result = causal_test_case.execute_test(estimator=linear_regression_estimator,
68-
data_collector=data_collector,
69-
causal_specification=causal_specification)
68+
data_collector=data_collector)
7069

7170
# Repeat for association estimate (no adjustment)
7271
no_adjustment_linear_regression_estimator = LinearRegressionEstimator(
@@ -79,8 +78,7 @@ def doubling_beta_CATE_on_csv(
7978
formula="cum_infections ~ beta + np.power(beta, 2)",
8079
)
8180
association_test_result = causal_test_case.execute_test(estimator=no_adjustment_linear_regression_estimator,
82-
data_collector=data_collector,
83-
causal_specification=causal_specification)
81+
data_collector=data_collector)
8482

8583
# Store results for plotting
8684
results_dict["association"] = {
@@ -111,9 +109,8 @@ def doubling_beta_CATE_on_csv(
111109
formula="cum_infections ~ beta + np.power(beta, 2) + avg_age + contacts",
112110
)
113111
counterfactual_causal_test_result = causal_test_case.execute_test(
114-
estimator=linear_regression_estimator, data_collector=data_collector,
115-
causal_specification=causal_specification
116-
)
112+
estimator=linear_regression_estimator, data_collector=data_collector)
113+
117114
results_dict["counterfactual"] = {
118115
"ate": counterfactual_causal_test_result.test_value.value,
119116
"cis": counterfactual_causal_test_result.confidence_intervals,

examples/covasim_/vaccinating_elderly/example_vaccine.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,7 @@ def test_experimental_vaccinate_elderly(runs_per_test_per_config: int = 30, verb
9595
)
9696

9797
# 9. Execute test and save results in dict
98-
causal_test_result = causal_test_case.execute_test(linear_regression_estimator, data_collector,
99-
causal_specification)
98+
causal_test_result = causal_test_case.execute_test(linear_regression_estimator, data_collector)
10099
if verbose:
101100
logging.info("Causation:\n%s", causal_test_result)
102101
results_dict[outcome_variable.name]["ate"] = causal_test_result.test_value.value

examples/lr91/example_max_conductances.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
144144
)
145145

146146
# 9. Run the causal test and print results
147-
causal_test_result = causal_test_case.execute_test(linear_regression_estimator, data_collector, causal_specification)
147+
causal_test_result = causal_test_case.execute_test(linear_regression_estimator, data_collector)
148148
logger.info("%s", causal_test_result)
149149
return causal_test_result.test_value.value, causal_test_result.confidence_intervals
150150

tests/testing_tests/test_causal_test_case.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def setUp(self) -> None:
108108
# Obsolete?
109109
self.data_collector = ObservationalDataCollector(self.scenario, df)
110110
self.data_collector.collect_data()
111-
self.df = self.data_collector.data
111+
self.df = self.data_collector.collect_data()
112112
self.minimal_adjustment_set = self.causal_dag.identification(self.base_test_case)
113113
# 6. Easier to access treatment and outcome values
114114
self.treatment_value = 1
@@ -130,7 +130,7 @@ def test_execute_test_observational_causal_forest_estimator(self):
130130
"C",
131131
self.df,
132132
)
133-
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
133+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector)
134134
self.assertAlmostEqual(causal_test_result.test_value.value, 4, delta=1)
135135

136136
def test_invalid_causal_effect(self):
@@ -152,7 +152,7 @@ def test_execute_test_observational_linear_regression_estimator(self):
152152
"C",
153153
self.df,
154154
)
155-
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
155+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector)
156156
self.assertAlmostEqual(causal_test_result.test_value.value, 4, delta=1e-10)
157157

158158
def test_execute_test_observational_linear_regression_estimator_direct_effect(self):
@@ -179,7 +179,7 @@ def test_execute_test_observational_linear_regression_estimator_direct_effect(se
179179
"C",
180180
self.df,
181181
)
182-
causal_test_result = causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
182+
causal_test_result = causal_test_case.execute_test(estimation_model, self.data_collector)
183183
self.assertAlmostEqual(causal_test_result.test_value.value, 4, delta=1e-10)
184184

185185
def test_execute_test_observational_linear_regression_estimator_coefficient(self):
@@ -194,7 +194,7 @@ def test_execute_test_observational_linear_regression_estimator_coefficient(self
194194
self.df,
195195
)
196196
self.causal_test_case.estimate_type = "coefficient"
197-
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
197+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector)
198198
self.assertEqual(int(causal_test_result.test_value.value), 0)
199199

200200
def test_execute_test_observational_linear_regression_estimator_risk_ratio(self):
@@ -209,7 +209,7 @@ def test_execute_test_observational_linear_regression_estimator_risk_ratio(self)
209209
self.df,
210210
)
211211
self.causal_test_case.estimate_type = "risk_ratio"
212-
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
212+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector)
213213
self.assertEqual(int(causal_test_result.test_value.value), 0)
214214

215215
def test_invalid_estimate_type(self):
@@ -225,7 +225,7 @@ def test_invalid_estimate_type(self):
225225
)
226226
self.causal_test_case.estimate_type = "invalid"
227227
with self.assertRaises(AttributeError):
228-
self.causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
228+
self.causal_test_case.execute_test(estimation_model, self.data_collector)
229229

230230
def test_execute_test_observational_linear_regression_estimator_squared_term(self):
231231
"""Check that executing the causal test case returns the correct results for dummy data with a squared term
@@ -239,7 +239,7 @@ def test_execute_test_observational_linear_regression_estimator_squared_term(sel
239239
self.df,
240240
formula=f"C ~ A + {'+'.join(self.minimal_adjustment_set)} + (D ** 2)",
241241
)
242-
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
242+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector)
243243
self.assertAlmostEqual(round(causal_test_result.test_value.value, 1), 4, delta=1)
244244

245245
def test_execute_observational_causal_forest_estimator_cates(self):
@@ -262,7 +262,7 @@ def test_execute_observational_causal_forest_estimator_cates(self):
262262
effect_modifiers={"M": None},
263263
)
264264
self.causal_test_case.estimate_type = "cates"
265-
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
265+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector)
266266
causal_test_result = causal_test_result.test_value.value
267267
# Check that each effect modifier's strata has a greater ATE than the last (ascending order)
268268
causal_test_result_m1 = causal_test_result.loc[causal_test_result["M"] == 1]

0 commit comments

Comments
 (0)