@@ -161,13 +161,14 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
161
161
# x = x[model.params.index]
162
162
return model .predict (x )
163
163
164
- def estimate_control_treatment (self , bootstrap_size , adjustment_config ) -> tuple [pd .Series , pd .Series ]:
164
+ def estimate_control_treatment (self , adjustment_config : dict = None , bootstrap_size : int = 100 ) -> tuple [pd .Series , pd .Series ]:
165
165
"""Estimate the outcomes under control and treatment.
166
166
167
167
:return: The estimated control and treatment values and their confidence
168
168
intervals in the form ((ci_low, control, ci_high), (ci_low, treatment, ci_high)).
169
169
"""
170
-
170
+ if adjustment_config is None :
171
+ adjustment_config = {}
171
172
y = self .estimate (self .df , adjustment_config = adjustment_config )
172
173
173
174
try :
@@ -197,18 +198,16 @@ def estimate_control_treatment(self, bootstrap_size, adjustment_config) -> tuple
197
198
198
199
return (y .iloc [1 ], np .array (control )), (y .iloc [0 ], np .array (treatment ))
199
200
200
- def estimate_ate (self , estimator_params : dict = None ) -> float :
201
+ def estimate_ate (self , adjustment_config : dict = None , bootstrap_size : int = 100 ) -> float :
201
202
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
202
203
by changing the treatment variable from the control value to the treatment value. Here, we actually
203
204
calculate the expected outcomes under control and treatment and take one away from the other. This
204
205
allows for custom terms to be put in such as squares, inverses, products, etc.
205
206
206
207
:return: The estimated average treatment effect and 95% confidence intervals
207
208
"""
208
- if estimator_params is None :
209
- estimator_params = {}
210
- bootstrap_size = estimator_params .get ("bootstrap_size" , 100 )
211
- adjustment_config = estimator_params .get ("adjustment_config" , None )
209
+ if adjustment_config is None :
210
+ adjustment_config = {}
212
211
(control_outcome , control_bootstraps ), (
213
212
treatment_outcome ,
214
213
treatment_bootstraps ,
@@ -231,18 +230,16 @@ def estimate_ate(self, estimator_params: dict = None) -> float:
231
230
232
231
return estimate , (ci_low , ci_high )
233
232
234
- def estimate_risk_ratio (self , estimator_params : dict = None ) -> float :
233
+ def estimate_risk_ratio (self , adjustment_config : dict = None , bootstrap_size : int = 100 ) -> float :
235
234
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
236
235
by changing the treatment variable from the control value to the treatment value. Here, we actually
237
236
calculate the expected outcomes under control and treatment and divide one by the other. This
238
237
allows for custom terms to be put in such as squares, inverses, products, etc.
239
238
240
239
:return: The estimated risk ratio and 95% confidence intervals.
241
240
"""
242
- if estimator_params is None :
243
- estimator_params = {}
244
- bootstrap_size = estimator_params .get ("bootstrap_size" , 100 )
245
- adjustment_config = estimator_params .get ("adjustment_config" , None )
241
+ if adjustment_config is None :
242
+ adjustment_config = {}
246
243
(control_outcome , control_bootstraps ), (
247
244
treatment_outcome ,
248
245
treatment_bootstraps ,
@@ -392,33 +389,30 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
392
389
393
390
return y .iloc [1 ], y .iloc [0 ]
394
391
395
- def estimate_risk_ratio (self , estimator_params : dict = None ) -> tuple [float , list [float , float ]]:
392
+ def estimate_risk_ratio (self , adjustment_config : dict = None ) -> tuple [float , list [float , float ]]:
396
393
"""Estimate the risk_ratio effect of the treatment on the outcome. That is, the change in outcome caused
397
394
by changing the treatment variable from the control value to the treatment value.
398
395
399
396
:return: The average treatment effect and the 95% Wald confidence intervals.
400
397
"""
401
- if estimator_params is None :
402
- estimator_params = {}
403
- adjustment_config = estimator_params .get ("adjustment_config" , None )
404
-
398
+ if adjustment_config is None :
399
+ adjustment_config = {}
405
400
control_outcome , treatment_outcome = self .estimate_control_treatment (adjustment_config = adjustment_config )
406
401
ci_low = treatment_outcome ["mean_ci_lower" ] / control_outcome ["mean_ci_upper" ]
407
402
ci_high = treatment_outcome ["mean_ci_upper" ] / control_outcome ["mean_ci_lower" ]
408
403
409
404
return (treatment_outcome ["mean" ] / control_outcome ["mean" ]), [ci_low , ci_high ]
410
405
411
- def estimate_ate_calculated (self , estimator_params : dict = None ) -> tuple [float , list [float , float ]]:
406
+ def estimate_ate_calculated (self , adjustment_config : dict = None ) -> tuple [float , list [float , float ]]:
412
407
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
413
408
by changing the treatment variable from the control value to the treatment value. Here, we actually
414
409
calculate the expected outcomes under control and treatment and divide one by the other. This
415
410
allows for custom terms to be put in such as squares, inverses, products, etc.
416
411
417
412
:return: The average treatment effect and the 95% Wald confidence intervals.
418
413
"""
419
- if estimator_params is None :
420
- estimator_params = {}
421
- adjustment_config = estimator_params .get ("adjustment_config" , None )
414
+ if adjustment_config is None :
415
+ adjustment_config = {}
422
416
control_outcome , treatment_outcome = self .estimate_control_treatment (adjustment_config = adjustment_config )
423
417
ci_low = treatment_outcome ["mean_ci_lower" ] - control_outcome ["mean_ci_upper" ]
424
418
ci_high = treatment_outcome ["mean_ci_upper" ] - control_outcome ["mean_ci_lower" ]
0 commit comments