@@ -177,6 +177,24 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[pd.Se
177
177
ci_high = pd .Series (treatment_outcome ["mean_ci_upper" ] - control_outcome ["mean_ci_lower" ])
178
178
return pd .Series (treatment_outcome ["mean" ] - control_outcome ["mean" ]), [ci_low , ci_high ]
179
179
180
+ def estimate_prediction (self , adjustment_config : dict = None ) -> tuple [pd .Series , list [pd .Series , pd .Series ]]:
181
+ """Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
182
+ by changing the treatment variable from the control value to the treatment value. Here, we actually
183
+ calculate the expected outcomes under control and treatment and divide one by the other. This
184
+ allows for custom terms to be put in such as squares, inverses, products, etc.
185
+
186
+ :param: adjustment_config: The configuration of the adjustment set as a dict mapping variable names to
187
+ their values. N.B. Every variable in the adjustment set MUST have a value in
188
+ order to estimate the outcome under control and treatment.
189
+
190
+ :return: The average treatment effect and the 95% Wald confidence intervals.
191
+ """
192
+ prediction = self ._predict (adjustment_config = adjustment_config )
193
+ outcome = prediction .iloc [1 ]
194
+ ci_low = pd .Series (outcome ["mean_ci_upper" ])
195
+ ci_high = pd .Series (outcome ["mean_ci_lower" ])
196
+ return pd .Series (outcome ["mean" ]), [ci_low , ci_high ]
197
+
180
198
def _get_confidence_intervals (self , model , treatment ):
181
199
confidence_intervals = model .conf_int (alpha = self .alpha , cols = None )
182
200
ci_low , ci_high = (
0 commit comments