Skip to content

Commit fd7f79d

Browse files
Refactor other estimator classes to return pd.Series
1 parent c777503 commit fd7f79d

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

causal_testing/testing/estimators.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def __init__(
492492
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
493493
self.formula = f"{outcome} ~ cr({'+'.join(terms)}, df={basis})"
494494

495-
def estimate_ate_calculated(self, adjustment_config: dict = None) -> float:
495+
def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
496496
model = self._run_linear_regression()
497497

498498
x = {"Intercept": 1, self.treatment: self.treatment_value}
@@ -508,7 +508,7 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> float:
508508
x[self.treatment] = self.control_value
509509
control = model.predict(x).iloc[0]
510510

511-
return treatment - control
511+
return pd.Series(treatment - control)
512512

513513

514514
class InstrumentalVariableEstimator(Estimator):
@@ -564,7 +564,7 @@ def add_modelling_assumptions(self):
564564
"""
565565
)
566566

567-
def estimate_iv_coefficient(self, df):
567+
def estimate_iv_coefficient(self, df) -> float:
568568
"""
569569
Estimate the linear regression coefficient of the treatment on the
570570
outcome.
@@ -578,7 +578,7 @@ def estimate_iv_coefficient(self, df):
578578
# Estimate the coefficient of I on X by cancelling
579579
return ab / a
580580

581-
def estimate_coefficient(self, bootstrap_size=100):
581+
def estimate_coefficient(self, bootstrap_size=100) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
582582
"""
583583
Estimate the unit ate (i.e. coefficient) of the treatment on the
584584
outcome.
@@ -587,10 +587,10 @@ def estimate_coefficient(self, bootstrap_size=100):
587587
[self.estimate_iv_coefficient(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)]
588588
)
589589
bound = ceil((bootstrap_size * self.alpha) / 2)
590-
ci_low = bootstraps[bound]
591-
ci_high = bootstraps[bootstrap_size - bound]
590+
ci_low = pd.Series(bootstraps[bound])
591+
ci_high = pd.Series(bootstraps[bootstrap_size - bound])
592592

593-
return self.estimate_iv_coefficient(self.df), (ci_low, ci_high)
593+
return pd.Series(self.estimate_iv_coefficient(self.df)), [ci_low, ci_high]
594594

595595

596596
class CausalForestEstimator(Estimator):
@@ -607,7 +607,7 @@ def add_modelling_assumptions(self):
607607
"""
608608
self.modelling_assumptions.append("Non-parametric estimator: no restrictions imposed on the data.")
609609

610-
def estimate_ate(self) -> float:
610+
def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
611611
"""Estimate the average treatment effect.
612612
613613
:return ate, confidence_intervals: The average treatment effect and 95% confidence intervals.
@@ -635,9 +635,9 @@ def estimate_ate(self) -> float:
635635
model.fit(outcome_df, treatment_df, X=effect_modifier_df, W=confounders_df)
636636

637637
# Obtain the ATE and 95% confidence intervals
638-
ate = model.ate(effect_modifier_df, T0=self.control_value, T1=self.treatment_value)
638+
ate = pd.Series(model.ate(effect_modifier_df, T0=self.control_value, T1=self.treatment_value))
639639
ate_interval = model.ate_interval(effect_modifier_df, T0=self.control_value, T1=self.treatment_value)
640-
ci_low, ci_high = ate_interval[0], ate_interval[1]
640+
ci_low, ci_high = pd.Series(ate_interval[0]), pd.Series(ate_interval[1])
641641
return ate, [ci_low, ci_high]
642642

643643
def estimate_cates(self) -> pd.DataFrame:

0 commit comments

Comments
 (0)