Skip to content

Commit 410da1c

Browse files
Merge branch 'main' into adjustment_set_formula_check
2 parents 478fcd7 + 18b6ccd commit 410da1c

File tree

2 files changed

+21
-18
lines changed

2 files changed

+21
-18
lines changed

causal_testing/testing/estimators.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,16 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
161161
# x = x[model.params.index]
162162
return model.predict(x)
163163

164-
def estimate_control_treatment(self, bootstrap_size, adjustment_config) -> tuple[pd.Series, pd.Series]:
164+
def estimate_control_treatment(
165+
self, adjustment_config: dict = None, bootstrap_size: int = 100
166+
) -> tuple[pd.Series, pd.Series]:
165167
"""Estimate the outcomes under control and treatment.
166168
167169
:return: The estimated control and treatment values and their confidence
168170
intervals in the form ((ci_low, control, ci_high), (ci_low, treatment, ci_high)).
169171
"""
170-
172+
if adjustment_config is None:
173+
adjustment_config = {}
171174
y = self.estimate(self.df, adjustment_config=adjustment_config)
172175

173176
try:
@@ -197,18 +200,16 @@ def estimate_control_treatment(self, bootstrap_size, adjustment_config) -> tuple
197200

198201
return (y.iloc[1], np.array(control)), (y.iloc[0], np.array(treatment))
199202

200-
def estimate_ate(self, estimator_params: dict = None) -> float:
203+
def estimate_ate(self, adjustment_config: dict = None, bootstrap_size: int = 100) -> float:
201204
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
202205
by changing the treatment variable from the control value to the treatment value. Here, we actually
203206
calculate the expected outcomes under control and treatment and take one away from the other. This
204207
allows for custom terms to be put in such as squares, inverses, products, etc.
205208
206209
:return: The estimated average treatment effect and 95% confidence intervals
207210
"""
208-
if estimator_params is None:
209-
estimator_params = {}
210-
bootstrap_size = estimator_params.get("bootstrap_size", 100)
211-
adjustment_config = estimator_params.get("adjustment_config", None)
211+
if adjustment_config is None:
212+
adjustment_config = {}
212213
(control_outcome, control_bootstraps), (
213214
treatment_outcome,
214215
treatment_bootstraps,
@@ -231,18 +232,16 @@ def estimate_ate(self, estimator_params: dict = None) -> float:
231232

232233
return estimate, (ci_low, ci_high)
233234

234-
def estimate_risk_ratio(self, estimator_params: dict = None) -> float:
235+
def estimate_risk_ratio(self, adjustment_config: dict = None, bootstrap_size: int = 100) -> float:
235236
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
236237
by changing the treatment variable from the control value to the treatment value. Here, we actually
237238
calculate the expected outcomes under control and treatment and divide one by the other. This
238239
allows for custom terms to be put in such as squares, inverses, products, etc.
239240
240241
:return: The estimated risk ratio and 95% confidence intervals.
241242
"""
242-
if estimator_params is None:
243-
estimator_params = {}
244-
bootstrap_size = estimator_params.get("bootstrap_size", 100)
245-
adjustment_config = estimator_params.get("adjustment_config", None)
243+
if adjustment_config is None:
244+
adjustment_config = {}
246245
(control_outcome, control_bootstraps), (
247246
treatment_outcome,
248247
treatment_bootstraps,
@@ -371,7 +370,6 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
371370
"""
372371
if adjustment_config is None:
373372
adjustment_config = {}
374-
375373
model = self._run_linear_regression()
376374

377375
x = pd.DataFrame(columns=self.df.columns)
@@ -390,13 +388,15 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
390388

391389
return y.iloc[1], y.iloc[0]
392390

393-
def estimate_risk_ratio(self) -> tuple[float, list[float, float]]:
391+
def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
394392
"""Estimate the risk_ratio effect of the treatment on the outcome. That is, the change in outcome caused
395393
by changing the treatment variable from the control value to the treatment value.
396394
397395
:return: The average treatment effect and the 95% Wald confidence intervals.
398396
"""
399-
control_outcome, treatment_outcome = self.estimate_control_treatment()
397+
if adjustment_config is None:
398+
adjustment_config = {}
399+
control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
400400
ci_low = treatment_outcome["mean_ci_lower"] / control_outcome["mean_ci_upper"]
401401
ci_high = treatment_outcome["mean_ci_upper"] / control_outcome["mean_ci_lower"]
402402

@@ -410,6 +410,8 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float
410410
411411
:return: The average treatment effect and the 95% Wald confidence intervals.
412412
"""
413+
if adjustment_config is None:
414+
adjustment_config = {}
413415
control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
414416
ci_low = treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"]
415417
ci_high = treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"]

tests/testing_tests/test_estimators.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,15 @@ def test_ate_adjustment(self):
124124
logistic_regression_estimator = LogisticRegressionEstimator(
125125
"length_in", 65, 55, {"large_gauge"}, "completed", df
126126
)
127-
ate, _ = logistic_regression_estimator.estimate_ate(estimator_params={"adjustment_config": {"large_gauge": 0}})
127+
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config = {"large_gauge": 0})
128128
self.assertEqual(round(ate, 4), -0.3388)
129129

130130
def test_ate_invalid_adjustment(self):
131131
df = self.scarf_df.copy()
132132
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df)
133133
with self.assertRaises(ValueError):
134134
ate, _ = logistic_regression_estimator.estimate_ate(
135-
estimator_params={"adjustment_config": {"large_gauge": 0}}
135+
adjustment_config = {"large_gauge": 0}
136136
)
137137

138138
def test_ate_effect_modifiers(self):
@@ -392,8 +392,9 @@ def test_program_15_no_interaction_ate_calculated(self):
392392
)
393393
# terms_to_square = ["age", "wt71", "smokeintensity", "smokeyrs"]
394394
# for term_to_square in terms_to_square:
395+
395396
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_ate_calculated(
396-
{k: self.nhefs_df.mean()[k] for k in covariates}
397+
adjustment_config = {k: self.nhefs_df.mean()[k] for k in covariates}
397398
)
398399
self.assertEqual(round(ate, 1), 3.5)
399400
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [1.9, 5])

0 commit comments

Comments
 (0)