@@ -614,6 +614,7 @@ def __init__(
614
614
data : pd .DataFrame ,
615
615
formula : str ,
616
616
group_variable_name : str ,
617
+ pretreatment_variable_name : str ,
617
618
prediction_model = None ,
618
619
** kwargs ,
619
620
):
@@ -622,6 +623,7 @@ def __init__(
622
623
self .expt_type = "Pretest/posttest Nonequivalent Group Design"
623
624
self .formula = formula
624
625
self .group_variable_name = group_variable_name
626
+ self .pretreatment_variable_name = pretreatment_variable_name
625
627
626
628
y , X = dmatrices (formula , self .data )
627
629
self ._y_design_info = y .design_info
@@ -645,6 +647,33 @@ def __init__(
645
647
COORDS = {"coeffs" : self .labels , "obs_indx" : np .arange (self .X .shape [0 ])}
646
648
self .prediction_model .fit (X = self .X , y = self .y , coords = COORDS )
647
649
650
+ # Calculate the posterior predictive for the treatment and control for an
651
+ # interpolated set of pretest values
652
+ # get the model predictions of the observed data
653
+ self .pred_xi = np .linspace (
654
+ np .min (self .data [self .pretreatment_variable_name ]),
655
+ np .max (self .data [self .pretreatment_variable_name ]),
656
+ 200 ,
657
+ )
658
+ # untreated
659
+ x_pred_untreated = pd .DataFrame (
660
+ {
661
+ self .pretreatment_variable_name : self .pred_xi ,
662
+ self .group_variable_name : np .zeros (self .pred_xi .shape ),
663
+ }
664
+ )
665
+ (new_x ,) = build_design_matrices ([self ._x_design_info ], x_pred_untreated )
666
+ self .pred_untreated = self .prediction_model .predict (X = np .asarray (new_x ))
667
+ # treated
668
+ x_pred_untreated = pd .DataFrame (
669
+ {
670
+ self .pretreatment_variable_name : self .pred_xi ,
671
+ self .group_variable_name : np .ones (self .pred_xi .shape ),
672
+ }
673
+ )
674
+ (new_x ,) = build_design_matrices ([self ._x_design_info ], x_pred_untreated )
675
+ self .pred_treated = self .prediction_model .predict (X = np .asarray (new_x ))
676
+
648
677
# Evaluate causal impact as equal to the trestment effect
649
678
self .causal_impact = self .prediction_model .idata .posterior ["beta" ].sel (
650
679
{"coeffs" : self ._get_treatment_effect_coeff ()}
@@ -668,6 +697,23 @@ def plot(self):
668
697
ax = ax [0 ],
669
698
)
670
699
ax [0 ].set (xlabel = "Pretest" , ylabel = "Posttest" )
700
+
701
+ # plot posterior predictive of untreated
702
+ plot_xY (
703
+ self .pred_xi ,
704
+ self .pred_untreated ["posterior_predictive" ].y_hat ,
705
+ ax = ax [0 ],
706
+ plot_hdi_kwargs = {"color" : "C0" },
707
+ )
708
+
709
+ # plot posterior predictive of treated
710
+ plot_xY (
711
+ self .pred_xi ,
712
+ self .pred_treated ["posterior_predictive" ].y_hat ,
713
+ ax = ax [0 ],
714
+ plot_hdi_kwargs = {"color" : "C1" },
715
+ )
716
+
671
717
ax [0 ].legend (fontsize = LEGEND_FONT_SIZE )
672
718
673
719
# Plot estimated caual impact / treatment effect
@@ -698,5 +744,7 @@ def _get_treatment_effect_coeff(self) -> str:
698
744
the labels are `['Intercept', 'C(group)[T.1]', 'pre']`
699
745
then we want `C(group)[T.1]`.
700
746
"""
701
- mask = [self .group_variable_name in label for label in self .labels ]
702
- return self .labels [np .argmax (mask )]
747
+ for label in self .labels :
748
+ if ("group" in label ) & (":" not in label ):
749
+ return label
750
+ # TODO: raise an exception if we fail to find the coefficient we want
0 commit comments