@@ -179,7 +179,7 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
179
179
# x = x[model.params.index]
180
180
return model .predict (x )
181
181
182
- def estimate_control_treatment (self , bootstrap_size = 100 , adjustment_config = None ) -> tuple [pd .Series , pd .Series ]:
182
+ def estimate_control_treatment (self , bootstrap_size , adjustment_config ) -> tuple [pd .Series , pd .Series ]:
183
183
"""Estimate the outcomes under control and treatment.
184
184
185
185
:return: The estimated control and treatment values and their confidence
@@ -215,14 +215,18 @@ def estimate_control_treatment(self, bootstrap_size=100, adjustment_config=None)
215
215
216
216
return (y .iloc [1 ], np .array (control )), (y .iloc [0 ], np .array (treatment ))
217
217
218
- def estimate_ate (self , bootstrap_size = 100 , adjustment_config = None ) -> float :
218
+ def estimate_ate (self , estimator_params : dict = None ) -> float :
219
219
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
220
220
by changing the treatment variable from the control value to the treatment value. Here, we actually
221
221
calculate the expected outcomes under control and treatment and take one away from the other. This
222
222
allows for custom terms to be put in such as squares, inverses, products, etc.
223
223
224
224
:return: The estimated average treatment effect and 95% confidence intervals
225
225
"""
226
+ if estimator_params is None :
227
+ estimator_params = {}
228
+ bootstrap_size = estimator_params .get ("bootstrap_size" , 100 )
229
+ adjustment_config = estimator_params .get ("adjustment_config" , None )
226
230
(control_outcome , control_bootstraps ), (
227
231
treatment_outcome ,
228
232
treatment_bootstraps ,
@@ -245,14 +249,18 @@ def estimate_ate(self, bootstrap_size=100, adjustment_config=None) -> float:
245
249
246
250
return estimate , (ci_low , ci_high )
247
251
248
- def estimate_risk_ratio (self , bootstrap_size = 100 , adjustment_config = None ) -> float :
252
+ def estimate_risk_ratio (self , estimator_params : dict = None ) -> float :
249
253
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
250
254
by changing the treatment variable from the control value to the treatment value. Here, we actually
251
255
calculate the expected outcomes under control and treatment and divide one by the other. This
252
256
allows for custom terms to be put in such as squares, inverses, products, etc.
253
257
254
258
:return: The estimated risk ratio and 95% confidence intervals.
255
259
"""
260
+ if estimator_params is None :
261
+ estimator_params = {}
262
+ bootstrap_size = estimator_params .get ("bootstrap_size" , 100 )
263
+ adjustment_config = estimator_params .get ("adjustment_config" , None )
256
264
(control_outcome , control_bootstraps ), (
257
265
treatment_outcome ,
258
266
treatment_bootstraps ,
0 commit comments