@@ -181,7 +181,7 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
181
181
# x = x[model.params.index]
182
182
return model .predict (x )
183
183
184
- def estimate_control_treatment (self , bootstrap_size = 100 , adjustment_config = None ) -> tuple [pd .Series , pd .Series ]:
184
+ def estimate_control_treatment (self , bootstrap_size , adjustment_config ) -> tuple [pd .Series , pd .Series ]:
185
185
"""Estimate the outcomes under control and treatment.
186
186
187
187
:return: The estimated control and treatment values and their confidence
@@ -217,14 +217,18 @@ def estimate_control_treatment(self, bootstrap_size=100, adjustment_config=None)
217
217
218
218
return (y .iloc [1 ], np .array (control )), (y .iloc [0 ], np .array (treatment ))
219
219
220
- def estimate_ate (self , bootstrap_size = 100 , adjustment_config = None ) -> float :
220
+ def estimate_ate (self , estimator_params : dict = None ) -> float :
221
221
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
222
222
by changing the treatment variable from the control value to the treatment value. Here, we actually
223
223
calculate the expected outcomes under control and treatment and take one away from the other. This
224
224
allows for custom terms to be put in such as squares, inverses, products, etc.
225
225
226
226
:return: The estimated average treatment effect and 95% confidence intervals
227
227
"""
228
+ if estimator_params is None :
229
+ estimator_params = {}
230
+ bootstrap_size = estimator_params .get ("bootstrap_size" , 100 )
231
+ adjustment_config = estimator_params .get ("adjustment_config" , None )
228
232
(control_outcome , control_bootstraps ), (
229
233
treatment_outcome ,
230
234
treatment_bootstraps ,
@@ -247,14 +251,18 @@ def estimate_ate(self, bootstrap_size=100, adjustment_config=None) -> float:
247
251
248
252
return estimate , (ci_low , ci_high )
249
253
250
- def estimate_risk_ratio (self , bootstrap_size = 100 , adjustment_config = None ) -> float :
254
+ def estimate_risk_ratio (self , estimator_params : dict = None ) -> float :
251
255
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
252
256
by changing the treatment variable from the control value to the treatment value. Here, we actually
253
257
calculate the expected outcomes under control and treatment and divide one by the other. This
254
258
allows for custom terms to be put in such as squares, inverses, products, etc.
255
259
256
260
:return: The estimated risk ratio and 95% confidence intervals.
257
261
"""
262
+ if estimator_params is None :
263
+ estimator_params = {}
264
+ bootstrap_size = estimator_params .get ("bootstrap_size" , 100 )
265
+ adjustment_config = estimator_params .get ("adjustment_config" , None )
258
266
(control_outcome , control_bootstraps ), (
259
267
treatment_outcome ,
260
268
treatment_bootstraps ,
0 commit comments