@@ -619,7 +619,7 @@ def __init__(
619
619
):
620
620
super ().__init__ (prediction_model = prediction_model , ** kwargs )
621
621
self .data = data
622
- self .expt_type = "Difference in Differences "
622
+ self .expt_type = "Pretest/posttest Nonequivalent Group Design "
623
623
self .formula = formula
624
624
self .group_variable_name = group_variable_name
625
625
@@ -644,6 +644,15 @@ def __init__(
644
644
# fit the model to the observed (pre-intervention) data
645
645
COORDS = {"coeffs" : self .labels , "obs_indx" : np .arange (self .X .shape [0 ])}
646
646
self .prediction_model .fit (X = self .X , y = self .y , coords = COORDS )
647
+
648
+ # Evaluate causal impact as equal to the trestment effect
649
+ # TODO: Grabbing the appropriate name for the group parameter (to describe
650
+ # the treatment effect) is brittle and will likely break.
651
+ treatment_effect_parameter = f"C({ self .group_variable_name } )[T.1]"
652
+ self .causal_impact = self .prediction_model .idata .posterior ["beta" ].sel (
653
+ {"coeffs" : treatment_effect_parameter }
654
+ )
655
+
647
656
# ================================================================
648
657
649
658
def plot (self ):
@@ -658,26 +667,29 @@ def plot(self):
658
667
y = "post" ,
659
668
hue = "group" ,
660
669
alpha = 0.5 ,
661
- palette = "muted" ,
662
670
data = self .data ,
663
671
ax = ax [0 ],
664
672
)
665
673
ax [0 ].set (xlabel = "Pretest" , ylabel = "Posttest" )
666
674
ax [0 ].legend (fontsize = LEGEND_FONT_SIZE )
667
675
668
- # Post estimated treatment effect
669
- # TODO: Grabbing the appropriate name for the group parameter (to describe
670
- # the treatment effect) is brittle and will likely break.
671
- treatment_effect_parameter = f"C({ self .group_variable_name } )[T.1]"
672
- az .plot_posterior (
673
- self .prediction_model .idata .posterior ["beta" ].sel (
674
- {"coeffs" : treatment_effect_parameter }
675
- ),
676
- ref_val = 0 ,
677
- ax = ax [1 ],
678
- )
676
+ # Plot estimated caual impact / treatment effect
677
+ az .plot_posterior (self .causal_impact , ref_val = 0 , ax = ax [1 ])
679
678
ax [1 ].set (title = "Estimated treatment effect" )
680
679
return fig , ax
681
680
681
+ def _causal_impact_summary_stat (self ):
682
+ percentiles = self .causal_impact .quantile ([0.03 , 1 - 0.03 ]).values
683
+ ci = r"$CI_{94\%}$" + f"[{ percentiles [0 ]:.2f} , { percentiles [1 ]:.2f} ]"
684
+ causal_impact = f"{ self .causal_impact .mean ():.2f} , "
685
+ return f"Causal impact = { causal_impact + ci } "
686
+
682
687
def summary (self ):
683
- raise NotImplementedError
688
+ """Print text output summarising the results"""
689
+
690
+ print (f"{ self .expt_type :=^80} " )
691
+ print (f"Formula: { self .formula } " )
692
+ print ("\n Results:" )
693
+ # TODO: extra experiment specific outputs here
694
+ print (self ._causal_impact_summary_stat ())
695
+ self .print_coefficients ()
0 commit comments