@@ -646,11 +646,8 @@ def __init__(
646
646
self .prediction_model .fit (X = self .X , y = self .y , coords = COORDS )
647
647
648
648
# Evaluate causal impact as equal to the trestment effect
649
- # TODO: Grabbing the appropriate name for the group parameter (to describe
650
- # the treatment effect) is brittle and will likely break.
651
- treatment_effect_parameter = f"C({ self .group_variable_name } )[T.1]"
652
649
self .causal_impact = self .prediction_model .idata .posterior ["beta" ].sel (
653
- {"coeffs" : treatment_effect_parameter }
650
+ {"coeffs" : self . _get_treatment_effect_coeff () }
654
651
)
655
652
656
653
# ================================================================
@@ -693,3 +690,13 @@ def summary(self):
693
690
# TODO: extra experiment specific outputs here
694
691
print (self ._causal_impact_summary_stat ())
695
692
self .print_coefficients ()
693
+
694
+ def _get_treatment_effect_coeff (self ) -> str :
695
+ """Find the beta regression coefficient corresponding to the
696
+ group (i.e. treatment) effect.
697
+ For example if self.group_variable_name is 'group' and
698
+ the labels are `['Intercept', 'C(group)[T.1]', 'pre']`
699
+ then we want `C(group)[T.1]`.
700
+ """
701
+ mask = [self .group_variable_name in label for label in self .labels ]
702
+ return self .labels [np .argmax (mask )]
0 commit comments