Skip to content

Commit 18b6ccd

Browse files
Merge pull request #225 from CITCOM-project/estimator_params_for_linear
add estimator params for linear regression estimate methods
2 parents ee8be4d + d8439a7 commit 18b6ccd

File tree

2 files changed

+21
-18
lines changed

2 files changed

+21
-18
lines changed

causal_testing/testing/estimators.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,16 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
161161
# x = x[model.params.index]
162162
return model.predict(x)
163163

164-
def estimate_control_treatment(self, bootstrap_size, adjustment_config) -> tuple[pd.Series, pd.Series]:
164+
def estimate_control_treatment(
165+
self, adjustment_config: dict = None, bootstrap_size: int = 100
166+
) -> tuple[pd.Series, pd.Series]:
165167
"""Estimate the outcomes under control and treatment.
166168
167169
:return: The estimated control and treatment values and their bootstrapped
168170
estimates in the form ((control, control_bootstraps), (treatment, treatment_bootstraps)).
169171
"""
170-
172+
if adjustment_config is None:
173+
adjustment_config = {}
171174
y = self.estimate(self.df, adjustment_config=adjustment_config)
172175

173176
try:
@@ -197,18 +200,16 @@ def estimate_control_treatment(self, bootstrap_size, adjustment_config) -> tuple
197200

198201
return (y.iloc[1], np.array(control)), (y.iloc[0], np.array(treatment))
199202

200-
def estimate_ate(self, estimator_params: dict = None) -> float:
203+
def estimate_ate(self, adjustment_config: dict = None, bootstrap_size: int = 100) -> float:
201204
"""Estimate the average treatment effect (ATE) of the treatment on the outcome. That is, the change in outcome caused
202205
by changing the treatment variable from the control value to the treatment value. Here, we actually
203206
calculate the expected outcomes under control and treatment and take one away from the other. This
204207
allows for custom terms to be put in such as squares, inverses, products, etc.
205208
206209
:return: The estimated average treatment effect and 95% confidence intervals
207210
"""
208-
if estimator_params is None:
209-
estimator_params = {}
210-
bootstrap_size = estimator_params.get("bootstrap_size", 100)
211-
adjustment_config = estimator_params.get("adjustment_config", None)
211+
if adjustment_config is None:
212+
adjustment_config = {}
212213
(control_outcome, control_bootstraps), (
213214
treatment_outcome,
214215
treatment_bootstraps,
@@ -231,18 +232,16 @@ def estimate_ate(self, estimator_params: dict = None) -> float:
231232

232233
return estimate, (ci_low, ci_high)
233234

234-
def estimate_risk_ratio(self, estimator_params: dict = None) -> float:
235+
def estimate_risk_ratio(self, adjustment_config: dict = None, bootstrap_size: int = 100) -> float:
235236
"""Estimate the risk ratio of the treatment on the outcome. That is, the relative change in outcome caused
236237
by changing the treatment variable from the control value to the treatment value. Here, we actually
237238
calculate the expected outcomes under control and treatment and divide one by the other. This
238239
allows for custom terms to be put in such as squares, inverses, products, etc.
239240
240241
:return: The estimated risk ratio and 95% confidence intervals.
241242
"""
242-
if estimator_params is None:
243-
estimator_params = {}
244-
bootstrap_size = estimator_params.get("bootstrap_size", 100)
245-
adjustment_config = estimator_params.get("adjustment_config", None)
243+
if adjustment_config is None:
244+
adjustment_config = {}
246245
(control_outcome, control_bootstraps), (
247246
treatment_outcome,
248247
treatment_bootstraps,
@@ -374,7 +373,6 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
374373
"""
375374
if adjustment_config is None:
376375
adjustment_config = {}
377-
378376
model = self._run_linear_regression()
379377

380378
x = pd.DataFrame(columns=self.df.columns)
@@ -393,13 +391,15 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
393391

394392
return y.iloc[1], y.iloc[0]
395393

396-
def estimate_risk_ratio(self) -> tuple[float, list[float, float]]:
394+
def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
397395
"""Estimate the risk ratio of the treatment on the outcome. That is, the relative change in outcome caused
398396
by changing the treatment variable from the control value to the treatment value.
399397
400398
:return: The risk ratio and the 95% Wald confidence intervals.
401399
"""
402-
control_outcome, treatment_outcome = self.estimate_control_treatment()
400+
if adjustment_config is None:
401+
adjustment_config = {}
402+
control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
403403
ci_low = treatment_outcome["mean_ci_lower"] / control_outcome["mean_ci_upper"]
404404
ci_high = treatment_outcome["mean_ci_upper"] / control_outcome["mean_ci_lower"]
405405

@@ -413,6 +413,8 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float
413413
414414
:return: The average treatment effect and the 95% Wald confidence intervals.
415415
"""
416+
if adjustment_config is None:
417+
adjustment_config = {}
416418
control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
417419
ci_low = treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"]
418420
ci_high = treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"]

tests/testing_tests/test_estimators.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,15 @@ def test_ate_adjustment(self):
124124
logistic_regression_estimator = LogisticRegressionEstimator(
125125
"length_in", 65, 55, {"large_gauge"}, "completed", df
126126
)
127-
ate, _ = logistic_regression_estimator.estimate_ate(estimator_params={"adjustment_config": {"large_gauge": 0}})
127+
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config = {"large_gauge": 0})
128128
self.assertEqual(round(ate, 4), -0.3388)
129129

130130
def test_ate_invalid_adjustment(self):
131131
df = self.scarf_df.copy()
132132
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df)
133133
with self.assertRaises(ValueError):
134134
ate, _ = logistic_regression_estimator.estimate_ate(
135-
estimator_params={"adjustment_config": {"large_gauge": 0}}
135+
adjustment_config = {"large_gauge": 0}
136136
)
137137

138138
def test_ate_effect_modifiers(self):
@@ -392,8 +392,9 @@ def test_program_15_no_interaction_ate_calculated(self):
392392
)
393393
# terms_to_square = ["age", "wt71", "smokeintensity", "smokeyrs"]
394394
# for term_to_square in terms_to_square:
395+
395396
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_ate_calculated(
396-
{k: self.nhefs_df.mean()[k] for k in covariates}
397+
adjustment_config = {k: self.nhefs_df.mean()[k] for k in covariates}
397398
)
398399
self.assertEqual(round(ate, 1), 3.5)
399400
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [1.9, 5])

0 commit comments

Comments
 (0)