Skip to content

Commit b3bb0b4

Browse files
add estimator params for linear regression estimate methods
1 parent bf0bbe3 commit b3bb0b4

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

causal_testing/testing/estimators.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,6 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
374374
"""
375375
if adjustment_config is None:
376376
adjustment_config = {}
377-
378377
model = self._run_linear_regression()
379378

380379
x = pd.DataFrame(columns=self.df.columns)
@@ -393,26 +392,33 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
393392

394393
return y.iloc[1], y.iloc[0]
395394

396-
def estimate_risk_ratio(self) -> tuple[float, list[float, float]]:
395+
def estimate_risk_ratio(self, estimator_params: dict = None) -> tuple[float, list[float, float]]:
397396
"""Estimate the risk_ratio effect of the treatment on the outcome. That is, the change in outcome caused
398397
by changing the treatment variable from the control value to the treatment value.
399398
400399
:return: The average treatment effect and the 95% Wald confidence intervals.
401400
"""
402-
control_outcome, treatment_outcome = self.estimate_control_treatment()
401+
if estimator_params is None:
402+
estimator_params = {}
403+
adjustment_config = estimator_params.get("adjustment_config", None)
404+
405+
control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
403406
ci_low = treatment_outcome["mean_ci_lower"] / control_outcome["mean_ci_upper"]
404407
ci_high = treatment_outcome["mean_ci_upper"] / control_outcome["mean_ci_lower"]
405408

406409
return (treatment_outcome["mean"] / control_outcome["mean"]), [ci_low, ci_high]
407410

408-
def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
411+
def estimate_ate_calculated(self, estimator_params: dict = None) -> tuple[float, list[float, float]]:
409412
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
410413
by changing the treatment variable from the control value to the treatment value. Here, we actually
411414
calculate the expected outcomes under control and treatment and divide one by the other. This
412415
allows for custom terms to be put in such as squares, inverses, products, etc.
413416
414417
:return: The average treatment effect and the 95% Wald confidence intervals.
415418
"""
419+
if estimator_params is None:
420+
estimator_params = {}
421+
adjustment_config = estimator_params.get("adjustment_config", None)
416422
control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
417423
ci_low = treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"]
418424
ci_high = treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"]

0 commit comments

Comments
 (0)