Skip to content

Commit c5e1c2d

Browse files
committed
Reduced size of _return_causal_test_results
1 parent c583759 commit c5e1c2d

File tree

5 files changed

+34
-77
lines changed

5 files changed

+34
-77
lines changed

causal_testing/testing/causal_test_engine.py

Lines changed: 11 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -148,63 +148,17 @@ def _return_causal_test_results(self, estimator, causal_test_case):
148148
:param causal_test_case: The concrete test case to be executed
149149
:return: a CausalTestResult object containing the confidence intervals
150150
"""
151-
if causal_test_case.estimate_type == "cate":
152-
logger.debug("calculating cate")
153-
if not hasattr(estimator, "estimate_cates"):
154-
raise NotImplementedError(f"{estimator.__class__} has no CATE method.")
155-
156-
cates_df, confidence_intervals = estimator.estimate_cates()
157-
causal_test_result = CausalTestResult(
158-
estimator=estimator,
159-
test_value=TestValue("ate", cates_df),
160-
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
161-
confidence_intervals=confidence_intervals,
162-
)
163-
elif causal_test_case.estimate_type == "risk_ratio":
164-
logger.debug("calculating risk_ratio")
165-
risk_ratio, confidence_intervals = estimator.estimate_risk_ratio(**causal_test_case.estimate_params)
166-
167-
causal_test_result = CausalTestResult(
168-
estimator=estimator,
169-
test_value=TestValue("risk_ratio", risk_ratio),
170-
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
171-
confidence_intervals=confidence_intervals,
172-
)
173-
elif causal_test_case.estimate_type == "coefficient":
174-
logger.debug("calculating coefficient")
175-
coefficient, confidence_intervals = estimator.estimate_unit_ate()
176-
causal_test_result = CausalTestResult(
177-
estimator=estimator,
178-
test_value=TestValue("coefficient", coefficient),
179-
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
180-
confidence_intervals=confidence_intervals,
181-
)
182-
elif causal_test_case.estimate_type == "ate":
183-
logger.debug("calculating ate")
184-
ate, confidence_intervals = estimator.estimate_ate()
185-
causal_test_result = CausalTestResult(
186-
estimator=estimator,
187-
test_value=TestValue("ate", ate),
188-
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
189-
confidence_intervals=confidence_intervals,
190-
)
191-
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
192-
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
193-
elif causal_test_case.estimate_type == "ate_calculated":
194-
logger.debug("calculating ate")
195-
ate, confidence_intervals = estimator.estimate_ate_calculated()
196-
causal_test_result = CausalTestResult(
197-
estimator=estimator,
198-
test_value=TestValue("ate", ate),
199-
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
200-
confidence_intervals=confidence_intervals,
201-
)
202-
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
203-
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
204-
else:
205-
raise ValueError(
206-
f"Invalid estimate type {causal_test_case.estimate_type}, expected 'ate', 'cate', or 'risk_ratio'"
207-
)
151+
if not hasattr(estimator, f"estimate_{causal_test_case.estimate_type}"):
152+
raise NotImplementedError(f"{estimator.__class__} has no {causal_test_case.estimate_type} method.")
153+
estimate_effect = getattr(estimator, f"estimate_{causal_test_case.estimate_type}")
154+
effect, confidence_intervals = estimate_effect(**causal_test_case.estimate_params)
155+
causal_test_result = CausalTestResult(
156+
estimator=estimator,
157+
test_value=TestValue(causal_test_case.estimate_type, effect),
158+
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
159+
confidence_intervals=confidence_intervals,
160+
)
161+
208162
return causal_test_result
209163

210164
def _check_positivity_violation(self, variables_list):

causal_testing/testing/estimators.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def add_modelling_assumptions(self):
341341
"do not need to be linear."
342342
)
343343

344-
def estimate_unit_ate(self) -> float:
344+
def estimate_coefficient(self) -> float:
345345
"""Estimate the unit average treatment effect of the treatment on the outcome. That is, the change in outcome
346346
caused by a unit change in treatment.
347347
@@ -495,7 +495,7 @@ def add_modelling_assumptions(self):
495495
(iii) Instrument and outcome do not share causes
496496
"""
497497

498-
def estimate_coefficient(self, df):
498+
def estimate_coefficient_aux(self, df):
499499
"""
500500
Estimate the linear regression coefficient of the treatment on the
501501
outcome.
@@ -509,19 +509,19 @@ def estimate_coefficient(self, df):
509509
# Estimate the coefficient of I on X by cancelling
510510
return ab / a
511511

512-
def estimate_unit_ate(self, bootstrap_size=100):
512+
def estimate_coefficient(self, bootstrap_size=100):
513513
"""
514514
Estimate the unit ate (i.e. coefficient) of the treatment on the
515515
outcome.
516516
"""
517517
bootstraps = sorted(
518-
[self.estimate_coefficient(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)]
518+
[self.estimate_coefficient_aux(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)]
519519
)
520520
bound = ceil((bootstrap_size * self.alpha) / 2)
521521
ci_low = bootstraps[bound]
522522
ci_high = bootstraps[bootstrap_size - bound]
523523

524-
return self.estimate_coefficient(self.df), (ci_low, ci_high)
524+
return self.estimate_coefficient_aux(self.df), (ci_low, ci_high)
525525

526526

527527
class CausalForestEstimator(Estimator):

tests/json_front_tests/test_json_class.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_f_flag(self):
9898
effects = {"NoEffect": NoEffect()}
9999
mutates = {
100100
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
101-
> self.json_class.scenario.variables[x].z3
101+
> self.json_class.scenario.variables[x].z3
102102
}
103103
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
104104
with self.assertRaises(StatisticsError):
@@ -147,7 +147,7 @@ def test_run_json_tests_from_json(self):
147147
effects = {"NoEffect": NoEffect()}
148148
mutates = {
149149
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
150-
> self.json_class.scenario.variables[x].z3
150+
> self.json_class.scenario.variables[x].z3
151151
}
152152
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
153153

@@ -177,7 +177,7 @@ def test_generate_tests_from_json_no_dist(self):
177177
effects = {"NoEffect": NoEffect()}
178178
mutates = {
179179
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
180-
> self.json_class.scenario.variables[x].z3
180+
> self.json_class.scenario.variables[x].z3
181181
}
182182
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
183183

@@ -207,7 +207,7 @@ def test_formula_in_json_test(self):
207207
effects = {"Positive": Positive()}
208208
mutates = {
209209
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
210-
> self.json_class.scenario.variables[x].z3
210+
> self.json_class.scenario.variables[x].z3
211211
}
212212
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
213213

@@ -239,6 +239,7 @@ def test_run_concrete_json_testcase(self):
239239
with open("temp_out.txt", "r") as reader:
240240
temp_out = reader.readlines()
241241
self.assertIn("FAILED", temp_out[-1])
242+
242243
def test_concrete_generate_params(self):
243244
example_test = {
244245
"tests": [
@@ -259,7 +260,7 @@ def test_concrete_generate_params(self):
259260
effects = {"NoEffect": NoEffect()}
260261
mutates = {
261262
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
262-
> self.json_class.scenario.variables[x].z3
263+
> self.json_class.scenario.variables[x].z3
263264
}
264265
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
265266

tests/testing_tests/test_causal_test_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def test_invalid_estimate_type(self):
220220
self.causal_test_engine.scenario_execution_data_df,
221221
)
222222
self.causal_test_case.estimate_type = "invalid"
223-
with self.assertRaises(ValueError):
223+
with self.assertRaises(NotImplementedError):
224224
self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
225225

226226
def test_execute_test_observational_linear_regression_estimator_squared_term(self):
@@ -257,7 +257,7 @@ def test_execute_observational_causal_forest_estimator_cates(self):
257257
self.causal_test_engine.scenario_execution_data_df,
258258
effect_modifiers={"M": None},
259259
)
260-
self.causal_test_case.estimate_type = "cate"
260+
self.causal_test_case.estimate_type = "cates"
261261
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
262262
causal_test_result = causal_test_result.test_value.value
263263
# Check that each effect modifier's strata has a greater ATE than the last (ascending order)

tests/testing_tests/test_estimators.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def setUpClass(cls) -> None:
9898
def test_linear_regression_categorical_ate(self):
9999
df = self.scarf_df.copy()
100100
logistic_regression_estimator = LinearRegressionEstimator("color", None, None, set(), "completed", df)
101-
ate, confidence = logistic_regression_estimator.estimate_unit_ate()
101+
ate, confidence = logistic_regression_estimator.estimate_coefficient()
102102
self.assertTrue(all([ci_low < 0 < ci_high for ci_low, ci_high in zip(confidence[0], confidence[1])]))
103103

104104
def test_ate(self):
@@ -131,7 +131,9 @@ def test_ate_invalid_adjustment(self):
131131
df = self.scarf_df.copy()
132132
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df)
133133
with self.assertRaises(ValueError):
134-
ate, _ = logistic_regression_estimator.estimate_ate(estimator_params={"adjustment_config": {"large_gauge": 0}})
134+
ate, _ = logistic_regression_estimator.estimate_ate(
135+
estimator_params={"adjustment_config": {"large_gauge": 0}}
136+
)
135137

136138
def test_ate_effect_modifiers(self):
137139
df = self.scarf_df.copy()
@@ -184,7 +186,7 @@ def test_estimate_coefficient(self):
184186
)
185187
self.assertEqual(iv_estimator.estimate_coefficient(self.df), 2)
186188

187-
def test_estimate_unit_ate(self):
189+
def test_estimate_coefficient(self):
188190
"""
189191
Test we get the correct coefficient.
190192
"""
@@ -197,8 +199,8 @@ def test_estimate_unit_ate(self):
197199
outcome="Y",
198200
instrument="Z",
199201
)
200-
unit_ate, [low, high] = iv_estimator.estimate_unit_ate()
201-
self.assertEqual(unit_ate, 2)
202+
coefficient, [low, high] = iv_estimator.estimate_coefficient()
203+
self.assertEqual(coefficient, 2)
202204

203205

204206
class TestLinearRegressionEstimator(unittest.TestCase):
@@ -218,7 +220,7 @@ def test_program_11_2(self):
218220
df = self.chapter_11_df
219221
linear_regression_estimator = LinearRegressionEstimator("treatments", None, None, set(), "outcomes", df)
220222
model = linear_regression_estimator._run_linear_regression()
221-
ate, _ = linear_regression_estimator.estimate_unit_ate()
223+
ate, _ = linear_regression_estimator.estimate_coefficient()
222224

223225
self.assertEqual(round(model.params["Intercept"] + 90 * model.params["treatments"], 1), 216.9)
224226

@@ -232,7 +234,7 @@ def test_program_11_3(self):
232234
"treatments", None, None, set(), "outcomes", df, formula="outcomes ~ treatments + np.power(treatments, 2)"
233235
)
234236
model = linear_regression_estimator._run_linear_regression()
235-
ate, _ = linear_regression_estimator.estimate_unit_ate()
237+
ate, _ = linear_regression_estimator.estimate_coefficient()
236238
self.assertEqual(
237239
round(
238240
model.params["Intercept"]
@@ -320,7 +322,7 @@ def test_program_15_no_interaction(self):
320322
)
321323
# terms_to_square = ["age", "wt71", "smokeintensity", "smokeyrs"]
322324
# for term_to_square in terms_to_square:
323-
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_unit_ate()
325+
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_coefficient()
324326
self.assertEqual(round(ate, 1), 3.5)
325327
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [2.6, 4.3])
326328

0 commit comments

Comments
 (0)