Skip to content

Commit 47974a8

Browse files
committed
create method to more robustly grab the treatment coefficient name
1 parent e885cfe commit 47974a8

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

causalpy/pymc_experiments.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -646,11 +646,8 @@ def __init__(
646646
self.prediction_model.fit(X=self.X, y=self.y, coords=COORDS)
647647

648648
# Evaluate causal impact as equal to the trestment effect
649-
# TODO: Grabbing the appropriate name for the group parameter (to describe
650-
# the treatment effect) is brittle and will likely break.
651-
treatment_effect_parameter = f"C({self.group_variable_name})[T.1]"
652649
self.causal_impact = self.prediction_model.idata.posterior["beta"].sel(
653-
{"coeffs": treatment_effect_parameter}
650+
{"coeffs": self._get_treatment_effect_coeff()}
654651
)
655652

656653
# ================================================================
@@ -693,3 +690,13 @@ def summary(self):
693690
# TODO: extra experiment specific outputs here
694691
print(self._causal_impact_summary_stat())
695692
self.print_coefficients()
693+
694+
def _get_treatment_effect_coeff(self) -> str:
695+
"""Find the beta regression coefficient corresponding to the
696+
group (i.e. treatment) effect.
697+
For example if self.group_variable_name is 'group' and
698+
the labels are `['Intercept', 'C(group)[T.1]', 'pre']`
699+
then we want `C(group)[T.1]`.
700+
"""
701+
mask = [self.group_variable_name in label for label in self.labels]
702+
return self.labels[np.argmax(mask)]

0 commit comments

Comments
 (0)