@@ -374,7 +374,6 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
374
374
"""
375
375
if adjustment_config is None :
376
376
adjustment_config = {}
377
-
378
377
model = self ._run_linear_regression ()
379
378
380
379
x = pd .DataFrame (columns = self .df .columns )
@@ -393,26 +392,33 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
393
392
394
393
return y .iloc [1 ], y .iloc [0 ]
395
394
396
- def estimate_risk_ratio (self ) -> tuple [float , list [float , float ]]:
395
+ def estimate_risk_ratio (self , estimator_params : dict = None ) -> tuple [float , list [float , float ]]:
397
396
"""Estimate the risk_ratio effect of the treatment on the outcome. That is, the change in outcome caused
398
397
by changing the treatment variable from the control value to the treatment value.
399
398
400
399
:return: The average treatment effect and the 95% Wald confidence intervals.
401
400
"""
402
- control_outcome , treatment_outcome = self .estimate_control_treatment ()
401
+ if estimator_params is None :
402
+ estimator_params = {}
403
+ adjustment_config = estimator_params .get ("adjustment_config" , None )
404
+
405
+ control_outcome , treatment_outcome = self .estimate_control_treatment (adjustment_config = adjustment_config )
403
406
ci_low = treatment_outcome ["mean_ci_lower" ] / control_outcome ["mean_ci_upper" ]
404
407
ci_high = treatment_outcome ["mean_ci_upper" ] / control_outcome ["mean_ci_lower" ]
405
408
406
409
return (treatment_outcome ["mean" ] / control_outcome ["mean" ]), [ci_low , ci_high ]
407
410
408
- def estimate_ate_calculated (self , adjustment_config : dict = None ) -> tuple [float , list [float , float ]]:
411
+ def estimate_ate_calculated (self , estimator_params : dict = None ) -> tuple [float , list [float , float ]]:
409
412
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
410
413
by changing the treatment variable from the control value to the treatment value. Here, we actually
411
414
calculate the expected outcomes under control and treatment and divide one by the other. This
412
415
allows for custom terms to be put in such as squares, inverses, products, etc.
413
416
414
417
:return: The average treatment effect and the 95% Wald confidence intervals.
415
418
"""
419
+ if estimator_params is None :
420
+ estimator_params = {}
421
+ adjustment_config = estimator_params .get ("adjustment_config" , None )
416
422
control_outcome , treatment_outcome = self .estimate_control_treatment (adjustment_config = adjustment_config )
417
423
ci_low = treatment_outcome ["mean_ci_lower" ] - control_outcome ["mean_ci_upper" ]
418
424
ci_high = treatment_outcome ["mean_ci_upper" ] - control_outcome ["mean_ci_lower" ]
0 commit comments