Skip to content

Commit 4b841ce

Browse files
committed
Further coverage
1 parent 4cd496f commit 4b841ce

File tree

2 files changed

+41
-12
lines changed

2 files changed

+41
-12
lines changed

causal_testing/testing/estimators.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,6 @@ def __init__(
5858
self.df = df
5959
if effect_modifiers is None:
6060
self.effect_modifiers = {}
61-
elif isinstance(effect_modifiers, (list, set)):
62-
self.effect_modifiers = {k for k in effect_modifiers}
6361
elif isinstance(effect_modifiers, dict):
6462
self.effect_modifiers = {k: v for k, v in effect_modifiers.items()}
6563
else:
@@ -119,9 +117,6 @@ def __init__(
119117
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(self.effect_modifiers))
120118
self.formula = f"{outcome} ~ {'+'.join(((terms)))}"
121119

122-
for term in self.effect_modifiers:
123-
self.adjustment_set.add(term)
124-
125120
def add_modelling_assumptions(self):
126121
"""
127122
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
@@ -170,6 +165,10 @@ def estimate(self, data: pd.DataFrame, adjustment_config=None) -> RegressionResu
170165
"""
171166
if adjustment_config is None:
172167
adjustment_config = {}
168+
if set(self.adjustment_set) != set(adjustment_config):
169+
raise ValueError(
170+
f"Invalid adjustment configuration {adjustment_config}. Must specify values for {self.adjustment_set}"
171+
)
173172

174173
model = self._run_logistic_regression(data)
175174
self.model = model
@@ -188,18 +187,19 @@ def estimate(self, data: pd.DataFrame, adjustment_config=None) -> RegressionResu
188187
# x = x[model.params.index]
189188
return model.predict(x)
190189

191-
def estimate_control_treatment(self, bootstrap_size=100) -> tuple[pd.Series, pd.Series]:
190+
def estimate_control_treatment(self, bootstrap_size=100, adjustment_config=None) -> tuple[pd.Series, pd.Series]:
192191
"""Estimate the outcomes under control and treatment.
193192
194193
:return: The estimated control and treatment values and their confidence
195194
intervals in the form ((ci_low, control, ci_high), (ci_low, treatment, ci_high)).
196195
"""
197196

198-
y = self.estimate(self.df)
197+
y = self.estimate(self.df, adjustment_config=adjustment_config)
199198

200199
try:
201200
bootstrap_samples = [
202-
self.estimate(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)
201+
self.estimate(self.df.sample(len(self.df), replace=True), adjustment_config=adjustment_config)
202+
for _ in range(bootstrap_size)
203203
]
204204
control, treatment = zip(*[(x.iloc[1], x.iloc[0]) for x in bootstrap_samples])
205205
except PerfectSeparationError:
@@ -223,7 +223,7 @@ def estimate_control_treatment(self, bootstrap_size=100) -> tuple[pd.Series, pd.
223223

224224
return (y.iloc[1], np.array(control)), (y.iloc[0], np.array(treatment))
225225

226-
def estimate_ate(self, bootstrap_size=100) -> float:
226+
def estimate_ate(self, bootstrap_size=100, adjustment_config=None) -> float:
227227
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
228228
by changing the treatment variable from the control value to the treatment value. Here, we actually
229229
calculate the expected outcomes under control and treatment and take one away from the other. This
@@ -234,7 +234,7 @@ def estimate_ate(self, bootstrap_size=100) -> float:
234234
(control_outcome, control_bootstraps), (
235235
treatment_outcome,
236236
treatment_bootstraps,
237-
) = self.estimate_control_treatment(bootstrap_size=bootstrap_size)
237+
) = self.estimate_control_treatment(bootstrap_size=bootstrap_size, adjustment_config=adjustment_config)
238238
estimate = treatment_outcome - control_outcome
239239

240240
if control_bootstraps is None or treatment_bootstraps is None:
@@ -253,7 +253,7 @@ def estimate_ate(self, bootstrap_size=100) -> float:
253253

254254
return estimate, (ci_low, ci_high)
255255

256-
def estimate_risk_ratio(self, bootstrap_size=100) -> float:
256+
def estimate_risk_ratio(self, bootstrap_size=100, adjustment_config=None) -> float:
257257
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
258258
by changing the treatment variable from the control value to the treatment value. Here, we actually
259259
calculate the expected outcomes under control and treatment and divide one by the other. This
@@ -264,7 +264,7 @@ def estimate_risk_ratio(self, bootstrap_size=100) -> float:
264264
(control_outcome, control_bootstraps), (
265265
treatment_outcome,
266266
treatment_bootstraps,
267-
) = self.estimate_control_treatment(bootstrap_size=bootstrap_size)
267+
) = self.estimate_control_treatment(bootstrap_size=bootstrap_size, adjustment_config=adjustment_config)
268268
estimate = treatment_outcome / control_outcome
269269

270270
if control_bootstraps is None or treatment_bootstraps is None:

tests/testing_tests/test_estimators.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,20 @@ def test_odds_ratio(self):
110110
odds = logistic_regression_estimator.estimate_unit_odds_ratio()
111111
self.assertEqual(round(odds, 4), 0.8948)
112112

113+
def test_ate_adjustment(self):
114+
df = self.scarf_df.copy()
115+
logistic_regression_estimator = LogisticRegressionEstimator(
116+
"length_in", 65, 55, {"large_gauge"}, "completed", df
117+
)
118+
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config={"large_gauge": 0})
119+
self.assertEqual(round(ate, 4), -0.3388)
120+
121+
def test_ate_invalid_adjustment(self):
122+
df = self.scarf_df.copy()
123+
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df)
124+
with self.assertRaises(ValueError):
125+
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config={"large_gauge": 0})
126+
113127
def test_ate_effect_modifiers(self):
114128
df = self.scarf_df.copy()
115129
logistic_regression_estimator = LogisticRegressionEstimator(
@@ -118,6 +132,21 @@ def test_ate_effect_modifiers(self):
118132
ate, _ = logistic_regression_estimator.estimate_ate()
119133
self.assertEqual(round(ate, 4), -0.3388)
120134

135+
def test_ate_effect_modifiers_formula(self):
136+
df = self.scarf_df.copy()
137+
logistic_regression_estimator = LogisticRegressionEstimator(
138+
"length_in",
139+
65,
140+
55,
141+
set(),
142+
"completed",
143+
df,
144+
effect_modifiers={"large_gauge": 0},
145+
formula="completed ~ length_in + large_gauge",
146+
)
147+
ate, _ = logistic_regression_estimator.estimate_ate()
148+
self.assertEqual(round(ate, 4), -0.3388)
149+
121150

122151
class TestInstrumentalVariableEstimator(unittest.TestCase):
123152
"""

0 commit comments

Comments
 (0)