Skip to content

Commit ef4df17

Browse files
committed
LR Estimate prediction
1 parent d0c9ee2 commit ef4df17

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

causal_testing/estimation/linear_regression_estimator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,24 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[pd.Se
177177
ci_high = pd.Series(treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"])
178178
return pd.Series(treatment_outcome["mean"] - control_outcome["mean"]), [ci_low, ci_high]
179179

180+
def estimate_prediction(self, adjustment_config: dict = None) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
181+
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
182+
by changing the treatment variable from the control value to the treatment value. Here, we actually
183+
calculate the expected outcomes under control and treatment and divide one by the other. This
184+
allows for custom terms to be put in such as squares, inverses, products, etc.
185+
186+
:param: adjustment_config: The configuration of the adjustment set as a dict mapping variable names to
187+
their values. N.B. Every variable in the adjustment set MUST have a value in
188+
order to estimate the outcome under control and treatment.
189+
190+
:return: The average treatment effect and the 95% Wald confidence intervals.
191+
"""
192+
prediction = self._predict(adjustment_config=adjustment_config)
193+
outcome = prediction.iloc[1]
194+
ci_low = pd.Series(outcome["mean_ci_upper"])
195+
ci_high = pd.Series(outcome["mean_ci_lower"])
196+
return pd.Series(outcome["mean"]), [ci_low, ci_high]
197+
180198
def _get_confidence_intervals(self, model, treatment):
181199
confidence_intervals = model.conf_int(alpha=self.alpha, cols=None)
182200
ci_low, ci_high = (

0 commit comments

Comments
 (0)