Skip to content

Commit e028e84

Browse files
authored
Merge pull request #213 from CITCOM-project/no-engine
Removed the big `if` statement in causal test engine
2 parents 7856456 + e6fa220 commit e028e84

File tree

7 files changed

+36
-81
lines changed

7 files changed

+36
-81
lines changed

causal_testing/testing/causal_test_engine.py

Lines changed: 11 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
6363
raise ValueError("No data has been loaded. Please call load_data prior to executing a causal test case.")
6464
test_suite_results = {}
6565
for edge in test_suite:
66-
print("edge: ")
67-
print(edge)
6866
logger.info("treatment: %s", edge.treatment_variable)
6967
logger.info("outcome: %s", edge.outcome_variable)
7068
minimal_adjustment_set = self.causal_dag.identification(edge)
@@ -148,63 +146,17 @@ def _return_causal_test_results(self, estimator, causal_test_case):
148146
:param causal_test_case: The concrete test case to be executed
149147
:return: a CausalTestResult object containing the confidence intervals
150148
"""
151-
if causal_test_case.estimate_type == "cate":
152-
logger.debug("calculating cate")
153-
if not hasattr(estimator, "estimate_cates"):
154-
raise NotImplementedError(f"{estimator.__class__} has no CATE method.")
155-
156-
cates_df, confidence_intervals = estimator.estimate_cates()
157-
causal_test_result = CausalTestResult(
158-
estimator=estimator,
159-
test_value=TestValue("ate", cates_df),
160-
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
161-
confidence_intervals=confidence_intervals,
162-
)
163-
elif causal_test_case.estimate_type == "risk_ratio":
164-
logger.debug("calculating risk_ratio")
165-
risk_ratio, confidence_intervals = estimator.estimate_risk_ratio(**causal_test_case.estimate_params)
166-
167-
causal_test_result = CausalTestResult(
168-
estimator=estimator,
169-
test_value=TestValue("risk_ratio", risk_ratio),
170-
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
171-
confidence_intervals=confidence_intervals,
172-
)
173-
elif causal_test_case.estimate_type == "coefficient":
174-
logger.debug("calculating coefficient")
175-
coefficient, confidence_intervals = estimator.estimate_unit_ate()
176-
causal_test_result = CausalTestResult(
177-
estimator=estimator,
178-
test_value=TestValue("coefficient", coefficient),
179-
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
180-
confidence_intervals=confidence_intervals,
181-
)
182-
elif causal_test_case.estimate_type == "ate":
183-
logger.debug("calculating ate")
184-
ate, confidence_intervals = estimator.estimate_ate()
185-
causal_test_result = CausalTestResult(
186-
estimator=estimator,
187-
test_value=TestValue("ate", ate),
188-
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
189-
confidence_intervals=confidence_intervals,
190-
)
191-
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
192-
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
193-
elif causal_test_case.estimate_type == "ate_calculated":
194-
logger.debug("calculating ate")
195-
ate, confidence_intervals = estimator.estimate_ate_calculated()
196-
causal_test_result = CausalTestResult(
197-
estimator=estimator,
198-
test_value=TestValue("ate", ate),
199-
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
200-
confidence_intervals=confidence_intervals,
201-
)
202-
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
203-
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
204-
else:
205-
raise ValueError(
206-
f"Invalid estimate type {causal_test_case.estimate_type}, expected 'ate', 'cate', or 'risk_ratio'"
207-
)
149+
if not hasattr(estimator, f"estimate_{causal_test_case.estimate_type}"):
150+
raise AttributeError(f"{estimator.__class__} has no {causal_test_case.estimate_type} method.")
151+
estimate_effect = getattr(estimator, f"estimate_{causal_test_case.estimate_type}")
152+
effect, confidence_intervals = estimate_effect(**causal_test_case.estimate_params)
153+
causal_test_result = CausalTestResult(
154+
estimator=estimator,
155+
test_value=TestValue(causal_test_case.estimate_type, effect),
156+
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
157+
confidence_intervals=confidence_intervals,
158+
)
159+
208160
return causal_test_result
209161

210162
def _check_positivity_violation(self, variables_list):

causal_testing/testing/estimators.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def add_modelling_assumptions(self):
321321
"do not need to be linear."
322322
)
323323

324-
def estimate_unit_ate(self) -> float:
324+
def estimate_coefficient(self) -> float:
325325
"""Estimate the unit average treatment effect of the treatment on the outcome. That is, the change in outcome
326326
caused by a unit change in treatment.
327327
@@ -475,7 +475,7 @@ def add_modelling_assumptions(self):
475475
(iii) Instrument and outcome do not share causes
476476
"""
477477

478-
def estimate_coefficient(self, df):
478+
def estimate_iv_coefficient(self, df):
479479
"""
480480
Estimate the linear regression coefficient of the treatment on the
481481
outcome.
@@ -489,19 +489,19 @@ def estimate_coefficient(self, df):
489489
# Estimate the coefficient of I on X by cancelling
490490
return ab / a
491491

492-
def estimate_unit_ate(self, bootstrap_size=100):
492+
def estimate_coefficient(self, bootstrap_size=100):
493493
"""
494494
Estimate the unit ate (i.e. coefficient) of the treatment on the
495495
outcome.
496496
"""
497497
bootstraps = sorted(
498-
[self.estimate_coefficient(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)]
498+
[self.estimate_iv_coefficient(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)]
499499
)
500500
bound = ceil((bootstrap_size * self.alpha) / 2)
501501
ci_low = bootstraps[bound]
502502
ci_high = bootstraps[bootstrap_size - bound]
503503

504-
return self.estimate_coefficient(self.df), (ci_low, ci_high)
504+
return self.estimate_iv_coefficient(self.df), (ci_low, ci_high)
505505

506506

507507
class CausalForestEstimator(Estimator):

examples/covasim_/vaccinating_elderly/example_vaccine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def run_sim_with_pars(pars_dict: dict, desired_outputs: [str], n_runs: int = 1,
212212
# Append outputs to results
213213
for output in desired_outputs:
214214
if output not in results:
215-
raise IndexError(f"{output} is not in the Covasim outputs.")
215+
raise IndexError(f"{output} is not in the Covasim outputs. Are you using v3.0.7?")
216216
results_dict[output].append(
217217
results[output][-1]
218218
) # Append the final recorded value for each variable

examples/poisson/example_run_causal_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def test_run_causal_tests():
157157
dag_path = f"{ROOT}/dag.dot"
158158
data_path = f"{ROOT}/data.csv"
159159

160-
json_utility = JsonUtility(log_path) # Create an instance of the extended JsonUtility class
160+
json_utility = JsonUtility(log_path, output_overwrite=True) # Create an instance of the extended JsonUtility class
161161
json_utility.set_paths(
162162
json_path, dag_path, [data_path]
163163
) # Set the path to the data.csv, dag.dot and causal_tests.json file

tests/json_front_tests/test_json_class.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_f_flag(self):
9898
effects = {"NoEffect": NoEffect()}
9999
mutates = {
100100
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
101-
> self.json_class.scenario.variables[x].z3
101+
> self.json_class.scenario.variables[x].z3
102102
}
103103
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
104104
with self.assertRaises(StatisticsError):
@@ -147,7 +147,7 @@ def test_run_json_tests_from_json(self):
147147
effects = {"NoEffect": NoEffect()}
148148
mutates = {
149149
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
150-
> self.json_class.scenario.variables[x].z3
150+
> self.json_class.scenario.variables[x].z3
151151
}
152152
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
153153

@@ -177,7 +177,7 @@ def test_generate_tests_from_json_no_dist(self):
177177
effects = {"NoEffect": NoEffect()}
178178
mutates = {
179179
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
180-
> self.json_class.scenario.variables[x].z3
180+
> self.json_class.scenario.variables[x].z3
181181
}
182182
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
183183

@@ -207,7 +207,7 @@ def test_formula_in_json_test(self):
207207
effects = {"Positive": Positive()}
208208
mutates = {
209209
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
210-
> self.json_class.scenario.variables[x].z3
210+
> self.json_class.scenario.variables[x].z3
211211
}
212212
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
213213

@@ -239,6 +239,7 @@ def test_run_concrete_json_testcase(self):
239239
with open("temp_out.txt", "r") as reader:
240240
temp_out = reader.readlines()
241241
self.assertIn("FAILED", temp_out[-1])
242+
242243
def test_concrete_generate_params(self):
243244
example_test = {
244245
"tests": [
@@ -259,7 +260,7 @@ def test_concrete_generate_params(self):
259260
effects = {"NoEffect": NoEffect()}
260261
mutates = {
261262
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
262-
> self.json_class.scenario.variables[x].z3
263+
> self.json_class.scenario.variables[x].z3
263264
}
264265
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
265266

tests/testing_tests/test_causal_test_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def test_invalid_estimate_type(self):
220220
self.causal_test_engine.scenario_execution_data_df,
221221
)
222222
self.causal_test_case.estimate_type = "invalid"
223-
with self.assertRaises(ValueError):
223+
with self.assertRaises(AttributeError):
224224
self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
225225

226226
def test_execute_test_observational_linear_regression_estimator_squared_term(self):
@@ -257,7 +257,7 @@ def test_execute_observational_causal_forest_estimator_cates(self):
257257
self.causal_test_engine.scenario_execution_data_df,
258258
effect_modifiers={"M": None},
259259
)
260-
self.causal_test_case.estimate_type = "cate"
260+
self.causal_test_case.estimate_type = "cates"
261261
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
262262
causal_test_result = causal_test_result.test_value.value
263263
# Check that each effect modifier's strata has a greater ATE than the last (ascending order)

tests/testing_tests/test_estimators.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def setUpClass(cls) -> None:
9898
def test_linear_regression_categorical_ate(self):
9999
df = self.scarf_df.copy()
100100
logistic_regression_estimator = LinearRegressionEstimator("color", None, None, set(), "completed", df)
101-
ate, confidence = logistic_regression_estimator.estimate_unit_ate()
101+
ate, confidence = logistic_regression_estimator.estimate_coefficient()
102102
self.assertTrue(all([ci_low < 0 < ci_high for ci_low, ci_high in zip(confidence[0], confidence[1])]))
103103

104104
def test_ate(self):
@@ -131,7 +131,9 @@ def test_ate_invalid_adjustment(self):
131131
df = self.scarf_df.copy()
132132
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df)
133133
with self.assertRaises(ValueError):
134-
ate, _ = logistic_regression_estimator.estimate_ate(estimator_params={"adjustment_config": {"large_gauge": 0}})
134+
ate, _ = logistic_regression_estimator.estimate_ate(
135+
estimator_params={"adjustment_config": {"large_gauge": 0}}
136+
)
135137

136138
def test_ate_effect_modifiers(self):
137139
df = self.scarf_df.copy()
@@ -184,7 +186,7 @@ def test_estimate_coefficient(self):
184186
)
185187
self.assertEqual(iv_estimator.estimate_coefficient(self.df), 2)
186188

187-
def test_estimate_unit_ate(self):
189+
def test_estimate_coefficient(self):
188190
"""
189191
Test we get the correct coefficient.
190192
"""
@@ -197,8 +199,8 @@ def test_estimate_unit_ate(self):
197199
outcome="Y",
198200
instrument="Z",
199201
)
200-
unit_ate, [low, high] = iv_estimator.estimate_unit_ate()
201-
self.assertEqual(unit_ate, 2)
202+
coefficient, [low, high] = iv_estimator.estimate_coefficient()
203+
self.assertEqual(coefficient, 2)
202204

203205

204206
class TestLinearRegressionEstimator(unittest.TestCase):
@@ -218,7 +220,7 @@ def test_program_11_2(self):
218220
df = self.chapter_11_df
219221
linear_regression_estimator = LinearRegressionEstimator("treatments", None, None, set(), "outcomes", df)
220222
model = linear_regression_estimator._run_linear_regression()
221-
ate, _ = linear_regression_estimator.estimate_unit_ate()
223+
ate, _ = linear_regression_estimator.estimate_coefficient()
222224

223225
self.assertEqual(round(model.params["Intercept"] + 90 * model.params["treatments"], 1), 216.9)
224226

@@ -232,7 +234,7 @@ def test_program_11_3(self):
232234
"treatments", None, None, set(), "outcomes", df, formula="outcomes ~ treatments + np.power(treatments, 2)"
233235
)
234236
model = linear_regression_estimator._run_linear_regression()
235-
ate, _ = linear_regression_estimator.estimate_unit_ate()
237+
ate, _ = linear_regression_estimator.estimate_coefficient()
236238
self.assertEqual(
237239
round(
238240
model.params["Intercept"]
@@ -320,7 +322,7 @@ def test_program_15_no_interaction(self):
320322
)
321323
# terms_to_square = ["age", "wt71", "smokeintensity", "smokeyrs"]
322324
# for term_to_square in terms_to_square:
323-
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_unit_ate()
325+
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_coefficient()
324326
self.assertEqual(round(ate, 1), 3.5)
325327
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [2.6, 4.3])
326328

0 commit comments

Comments
 (0)