Skip to content

Commit 7b9e990

Browse files
committed
Added confidence intervals to the logistic regression estimator for risk ratio and ATE
1 parent 6e9a369 commit 7b9e990

File tree

1 file changed

+32
-10
lines changed

1 file changed

+32
-10
lines changed

causal_testing/testing/estimators.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ def _run_logistic_regression(self) -> RegressionResultsWrapper:
149149
def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
150150
"""Estimate the outcomes under control and treatment.
151151
152-
:return: The average treatment effect and the 95% Wald confidence intervals.
152+
:return: The estimated control and treatment values and their confidence
153+
intervals in the form ((ci_low, control, ci_high), (ci_low, treatment, ci_high)).
153154
"""
154155
model = self._run_logistic_regression()
155156
self.model = model
@@ -168,31 +169,51 @@ def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
168169
x = x[model.params.index]
169170

170171
y = model.predict(x)
171-
return y.iloc[1], y.iloc[0]
172+
173+
# Delta method confidence intervals from
174+
# https://stackoverflow.com/questions/47414842/confidence-interval-of-probability-prediction-from-logistic-regression-statsmode
175+
cov = model.cov_params()
176+
gradient = (y * (1 - y) * x.T).T # matrix of gradients for each observation
177+
std_errors = np.array([np.sqrt(np.dot(np.dot(g, cov), g)) for g in gradient.to_numpy()])
178+
c = 1.96 # multiplier for confidence interval
179+
upper = np.maximum(0, np.minimum(1, y + std_errors * c))
180+
lower = np.maximum(0, np.minimum(1, y - std_errors * c))
181+
182+
return (lower.iloc[1], y.iloc[1], upper.iloc[1]), (lower.iloc[0], y.iloc[0], upper.iloc[0])
172183

173184
def estimate_ate(self) -> tuple[float, tuple[float, float]]:
    """Estimate the average treatment effect (ATE) of the treatment on the outcome.

    That is, the change in outcome caused by changing the treatment variable from the
    control value to the treatment value. The expected outcomes under control and
    treatment are estimated and one is subtracted from the other, which allows custom
    terms such as squares, inverses, products, etc. to be included in the model.

    :return: The estimated average treatment effect and its confidence interval, in
             the form (estimate, (ci_low, ci_high)).
    :raises AssertionError: If the estimate falls outside the computed interval.
    """
    (cci_low, control_outcome, cci_high), (tci_low, treatment_outcome, tci_high) = (
        self.estimate_control_treatment()
    )

    estimate = treatment_outcome - control_outcome

    # The bounds pair the extremes of the two per-outcome intervals: the smallest
    # treatment value against the largest control value, and vice versa.
    # NOTE(review): this is not the standard error of the difference; presumably it is
    # meant as a conservative interval - confirm the intended coverage.
    ci_low = tci_low - cci_high
    ci_high = tci_high - cci_low

    # Lazy %-style arguments so the message is only formatted if INFO is enabled.
    logger.info(
        "Changing %s from %s to %s gives an estimated ATE of %s < %s < %s",
        self.treatment,
        self.control_values,
        self.treatment_values,
        ci_low,
        estimate,
        ci_high,
    )
    # Sanity check on the internal invariant. The comparison is non-strict because a
    # degenerate interval (zero standard error) legitimately has bounds equal to the
    # estimate.
    assert ci_low <= estimate <= ci_high, f"Expecting {ci_low} <= {estimate} <= {ci_high}"

    return estimate, (ci_low, ci_high)
184202

185203
def estimate_risk_ratio(self) -> tuple[float, tuple[float, float]]:
    """Estimate the risk ratio of the treatment on the outcome.

    That is, the ratio of the expected outcome under the treatment value to the
    expected outcome under the control value. The expected outcomes under control and
    treatment are estimated and one is divided by the other, which allows custom terms
    such as squares, inverses, products, etc. to be included in the model.

    :return: The estimated risk ratio and its confidence interval, in the form
             (estimate, (ci_low, ci_high)). ``ci_high`` is ``inf`` when the lower
             confidence bound of the control outcome is zero.
    """
    (cci_low, control_outcome, cci_high), (tci_low, treatment_outcome, tci_high) = (
        self.estimate_control_treatment()
    )

    # The bounds pair the extremes of the two per-outcome intervals: the smallest
    # treatment value over the largest control value, and vice versa.
    ci_low = tci_low / cci_high
    # The control lower bound can be exactly 0 when the interval is clipped to [0, 1];
    # the ratio is then unbounded above, so report infinity instead of dividing by zero.
    ci_high = tci_high / cci_low if cci_low != 0 else float("inf")

    return treatment_outcome / control_outcome, (ci_low, ci_high)
196217

197218
def estimate_unit_odds_ratio(self) -> float:
198219
"""Estimate the odds ratio of increasing the treatment by one. In logistic regression, this corresponds to the
@@ -214,7 +235,7 @@ def __init__(
214235
treatment: tuple,
215236
treatment_values: float,
216237
control_values: float,
217-
adjustment_set: set,
238+
adjustment_set: list[float],
218239
outcome: tuple,
219240
df: pd.DataFrame = None,
220241
effect_modifiers: dict[Variable:Any] = None,
@@ -332,7 +353,8 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
332353
def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
333354
"""Estimate the outcomes under control and treatment.
334355
335-
:return: The average treatment effect and the 95% Wald confidence intervals.
356+
:return: The estimated outcome under control and treatment in the form
357+
(control_outcome, treatment_outcome).
336358
"""
337359
model = self._run_linear_regression()
338360
self.model = model

0 commit comments

Comments
 (0)