Skip to content

Commit 6ed62f8

Browse files
Revert to keywords in estimate_.. functions
1 parent b372d22 commit 6ed62f8

File tree

2 files changed

+18
-24
lines changed

2 files changed

+18
-24
lines changed

causal_testing/testing/estimators.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,14 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
161161
# x = x[model.params.index]
162162
return model.predict(x)
163163

164-
def estimate_control_treatment(self, bootstrap_size, adjustment_config) -> tuple[pd.Series, pd.Series]:
164+
def estimate_control_treatment(self, adjustment_config: dict = None, bootstrap_size: int = 100) -> tuple[pd.Series, pd.Series]:
165165
"""Estimate the outcomes under control and treatment.
166166
167167
:return: The estimated control and treatment values and their confidence
168168
intervals in the form ((ci_low, control, ci_high), (ci_low, treatment, ci_high)).
169169
"""
170-
170+
if adjustment_config is None:
171+
adjustment_config = {}
171172
y = self.estimate(self.df, adjustment_config=adjustment_config)
172173

173174
try:
@@ -197,18 +198,16 @@ def estimate_control_treatment(self, bootstrap_size, adjustment_config) -> tuple
197198

198199
return (y.iloc[1], np.array(control)), (y.iloc[0], np.array(treatment))
199200

200-
def estimate_ate(self, estimator_params: dict = None) -> float:
201+
def estimate_ate(self, adjustment_config: dict = None, bootstrap_size: int = 100) -> float:
201202
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
202203
by changing the treatment variable from the control value to the treatment value. Here, we actually
203204
calculate the expected outcomes under control and treatment and take one away from the other. This
204205
allows for custom terms to be put in such as squares, inverses, products, etc.
205206
206207
:return: The estimated average treatment effect and 95% confidence intervals
207208
"""
208-
if estimator_params is None:
209-
estimator_params = {}
210-
bootstrap_size = estimator_params.get("bootstrap_size", 100)
211-
adjustment_config = estimator_params.get("adjustment_config", None)
209+
if adjustment_config is None:
210+
adjustment_config = {}
212211
(control_outcome, control_bootstraps), (
213212
treatment_outcome,
214213
treatment_bootstraps,
@@ -231,18 +230,16 @@ def estimate_ate(self, estimator_params: dict = None) -> float:
231230

232231
return estimate, (ci_low, ci_high)
233232

234-
def estimate_risk_ratio(self, estimator_params: dict = None) -> float:
233+
def estimate_risk_ratio(self, adjustment_config: dict = None, bootstrap_size: int = 100) -> float:
235234
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
236235
by changing the treatment variable from the control value to the treatment value. Here, we actually
237236
calculate the expected outcomes under control and treatment and divide one by the other. This
238237
allows for custom terms to be put in such as squares, inverses, products, etc.
239238
240239
:return: The estimated risk ratio and 95% confidence intervals.
241240
"""
242-
if estimator_params is None:
243-
estimator_params = {}
244-
bootstrap_size = estimator_params.get("bootstrap_size", 100)
245-
adjustment_config = estimator_params.get("adjustment_config", None)
241+
if adjustment_config is None:
242+
adjustment_config = {}
246243
(control_outcome, control_bootstraps), (
247244
treatment_outcome,
248245
treatment_bootstraps,
@@ -392,33 +389,30 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
392389

393390
return y.iloc[1], y.iloc[0]
394391

395-
def estimate_risk_ratio(self, estimator_params: dict = None) -> tuple[float, list[float, float]]:
392+
def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
396393
"""Estimate the risk_ratio effect of the treatment on the outcome. That is, the change in outcome caused
397394
by changing the treatment variable from the control value to the treatment value.
398395
399396
:return: The average treatment effect and the 95% Wald confidence intervals.
400397
"""
401-
if estimator_params is None:
402-
estimator_params = {}
403-
adjustment_config = estimator_params.get("adjustment_config", None)
404-
398+
if adjustment_config is None:
399+
adjustment_config = {}
405400
control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
406401
ci_low = treatment_outcome["mean_ci_lower"] / control_outcome["mean_ci_upper"]
407402
ci_high = treatment_outcome["mean_ci_upper"] / control_outcome["mean_ci_lower"]
408403

409404
return (treatment_outcome["mean"] / control_outcome["mean"]), [ci_low, ci_high]
410405

411-
def estimate_ate_calculated(self, estimator_params: dict = None) -> tuple[float, list[float, float]]:
406+
def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
412407
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
413408
by changing the treatment variable from the control value to the treatment value. Here, we actually
414409
calculate the expected outcomes under control and treatment and divide one by the other. This
415410
allows for custom terms to be put in such as squares, inverses, products, etc.
416411
417412
:return: The average treatment effect and the 95% Wald confidence intervals.
418413
"""
419-
if estimator_params is None:
420-
estimator_params = {}
421-
adjustment_config = estimator_params.get("adjustment_config", None)
414+
if adjustment_config is None:
415+
adjustment_config = {}
422416
control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
423417
ci_low = treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"]
424418
ci_high = treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"]

tests/testing_tests/test_estimators.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,15 @@ def test_ate_adjustment(self):
124124
logistic_regression_estimator = LogisticRegressionEstimator(
125125
"length_in", 65, 55, {"large_gauge"}, "completed", df
126126
)
127-
ate, _ = logistic_regression_estimator.estimate_ate(estimator_params={"adjustment_config": {"large_gauge": 0}})
127+
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config = {"large_gauge": 0})
128128
self.assertEqual(round(ate, 4), -0.3388)
129129

130130
def test_ate_invalid_adjustment(self):
131131
df = self.scarf_df.copy()
132132
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df)
133133
with self.assertRaises(ValueError):
134134
ate, _ = logistic_regression_estimator.estimate_ate(
135-
estimator_params={"adjustment_config": {"large_gauge": 0}}
135+
adjustment_config = {"large_gauge": 0}
136136
)
137137

138138
def test_ate_effect_modifiers(self):
@@ -394,7 +394,7 @@ def test_program_15_no_interaction_ate_calculated(self):
394394
# for term_to_square in terms_to_square:
395395

396396
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_ate_calculated(
397-
estimator_params={"adjustment_config": {k: self.nhefs_df.mean()[k] for k in covariates}}
397+
adjustment_config = {k: self.nhefs_df.mean()[k] for k in covariates}
398398
)
399399
self.assertEqual(round(ate, 1), 3.5)
400400
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [1.9, 5])

0 commit comments

Comments
 (0)