Skip to content

Commit c777503

Browse files
Update return typings
1 parent ea8c273 commit c777503

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

causal_testing/testing/estimators.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def add_modelling_assumptions(self):
343343
"do not need to be linear."
344344
)
345345

346-
def estimate_coefficient(self) -> float:
346+
def estimate_coefficient(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
347347
"""Estimate the unit average treatment effect of the treatment on the outcome. That is, the change in outcome
348348
caused by a unit change in treatment.
349349
@@ -364,7 +364,7 @@ def estimate_coefficient(self) -> float:
364364
[ci_low, ci_high] = self._get_confidence_intervals(model, treatment)
365365
return unit_effect, [ci_low, ci_high]
366366

367-
def estimate_ate(self) -> tuple[float, list[float, float], float]:
367+
def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
368368
"""Estimate the average treatment effect of the treatment on the outcome. That is, the change in outcome caused
369369
by changing the treatment variable from the control value to the treatment value.
370370
@@ -413,7 +413,7 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
413413

414414
return y.iloc[1], y.iloc[0]
415415

416-
def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
416+
def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
417417
"""Estimate the risk_ratio effect of the treatment on the outcome. That is, the change in outcome caused
418418
by changing the treatment variable from the control value to the treatment value.
419419
@@ -426,9 +426,7 @@ def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[float, li
426426
ci_high = pd.Series(treatment_outcome["mean_ci_upper"] / control_outcome["mean_ci_lower"])
427427
return pd.Series(treatment_outcome["mean"] / control_outcome["mean"]), [ci_low, ci_high]
428428

429-
return (treatment_outcome["mean"] / control_outcome["mean"]), [ci_low, ci_high]
430-
431-
def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
429+
def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
432430
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
433431
by changing the treatment variable from the control value to the treatment value. Here, we actually
434432
calculate the expected outcomes under control and treatment and divide one by the other. This

0 commit comments

Comments
 (0)