@@ -600,3 +600,84 @@ def summary(self):
600
600
f"Discontinuity at threshold = { self .discontinuity_at_threshold .mean ():.2f} "
601
601
)
602
602
self .print_coefficients ()
603
+
604
+
605
+ # =============================================================================
606
+ # =============================================================================
607
+
608
+
609
+ class PrePostNEGD (ExperimentalDesign ):
610
+ """A class to analyse data from pretest/posttest designs"""
611
+
612
+ def __init__ (
613
+ self ,
614
+ data : pd .DataFrame ,
615
+ formula : str ,
616
+ group_variable_name : str ,
617
+ prediction_model = None ,
618
+ ** kwargs ,
619
+ ):
620
+ super ().__init__ (prediction_model = prediction_model , ** kwargs )
621
+ self .data = data
622
+ self .expt_type = "Difference in Differences"
623
+ self .formula = formula
624
+ self .group_variable_name = group_variable_name
625
+
626
+ y , X = dmatrices (formula , self .data )
627
+ self ._y_design_info = y .design_info
628
+ self ._x_design_info = X .design_info
629
+ self .labels = X .design_info .column_names
630
+ self .y , self .X = np .asarray (y ), np .asarray (X )
631
+ self .outcome_variable_name = y .design_info .column_names [0 ]
632
+
633
+ # Input validation ----------------------------------------------------
634
+ # Check that `group_variable_name` has TWO levels, representing the
635
+ # treated/untreated. But it does not matter what the actual names of
636
+ # the levels are.
637
+ assert (
638
+ len (pd .Categorical (self .data [self .group_variable_name ]).categories ) == 2
639
+ ), f"""
640
+ There must be 2 levels of the grouping variable { self .group_variable_name }
641
+ .I.e. the treated and untreated.
642
+ """
643
+
644
+ # fit the model to the observed (pre-intervention) data
645
+ COORDS = {"coeffs" : self .labels , "obs_indx" : np .arange (self .X .shape [0 ])}
646
+ self .prediction_model .fit (X = self .X , y = self .y , coords = COORDS )
647
+ # ================================================================
648
+
649
+ def plot (self ):
650
+ """Plot the results"""
651
+ fig , ax = plt .subplots (
652
+ 2 , 1 , figsize = (7 , 9 ), gridspec_kw = {"height_ratios" : [3 , 1 ]}
653
+ )
654
+
655
+ # Plot raw data
656
+ sns .scatterplot (
657
+ x = "pre" ,
658
+ y = "post" ,
659
+ hue = "group" ,
660
+ alpha = 0.5 ,
661
+ palette = "muted" ,
662
+ data = self .data ,
663
+ ax = ax [0 ],
664
+ )
665
+ ax [0 ].set (xlabel = "Pretest" , ylabel = "Posttest" )
666
+ ax [0 ].legend (fontsize = LEGEND_FONT_SIZE )
667
+
668
+ # Post estimated treatment effect
669
+ # TODO: Grabbing the appropriate name for the group parameter (to describe
670
+ # the treatment effect) is brittle and will likely break.
671
+ treatment_effect_parameter = f"C({ self .group_variable_name } )[T.1]"
672
+ az .plot_posterior (
673
+ self .prediction_model .idata .posterior ["beta" ].sel (
674
+ {"coeffs" : treatment_effect_parameter }
675
+ ),
676
+ ref_val = 0 ,
677
+ ax = ax [1 ],
678
+ )
679
+ ax [1 ].set (title = "Estimated treatment effect" )
680
+ return fig , ax
681
+
682
+ def summary (self ):
683
+ raise NotImplementedError
0 commit comments