@@ -99,13 +99,13 @@ def __init__(
99
99
# causal impact pre (ie the residuals of the model fit to observed)
100
100
pre_data = xr .DataArray (self .pre_y [:, 0 ], dims = ["obs_ind" ])
101
101
self .pre_impact = (
102
- pre_data - self .pre_pred ["posterior_predictive" ].y_hat
102
+ pre_data - self .pre_pred ["posterior_predictive" ].mu
103
103
).transpose (..., "obs_ind" )
104
104
105
105
# causal impact post (ie the residuals of the model fit to observed)
106
106
post_data = xr .DataArray (self .post_y [:, 0 ], dims = ["obs_ind" ])
107
107
self .post_impact = (
108
- post_data - self .post_pred ["posterior_predictive" ].y_hat
108
+ post_data - self .post_pred ["posterior_predictive" ].mu
109
109
).transpose (..., "obs_ind" )
110
110
111
111
# cumulative impact post
@@ -118,31 +118,47 @@ def plot(self):
118
118
119
119
# TOP PLOT --------------------------------------------------
120
120
# pre-intervention period
121
- plot_xY (
121
+ h_line , h_patch = plot_xY (
122
122
self .datapre .index ,
123
- self .pre_pred ["posterior_predictive" ].y_hat ,
123
+ self .pre_pred ["posterior_predictive" ].mu ,
124
124
ax = ax [0 ],
125
+ include_label = False ,
126
+ plot_hdi_kwargs = {"color" : "C0" },
125
127
)
126
- ax [0 ].plot (self .datapre .index , self .pre_y , "k." , label = "Observations" )
128
+ handles = [(h_line , h_patch )]
129
+ labels = ["Pre-intervention period" ]
130
+
131
+ (h ,) = ax [0 ].plot (self .datapre .index , self .pre_y , "k." , label = "Observations" )
132
+ handles .append (h )
133
+ labels .append ("Observations" )
134
+
127
135
# post intervention period
128
- plot_xY (
136
+ h_line , h_patch = plot_xY (
129
137
self .datapost .index ,
130
- self .post_pred ["posterior_predictive" ].y_hat ,
138
+ self .post_pred ["posterior_predictive" ].mu ,
131
139
ax = ax [0 ],
132
140
include_label = False ,
141
+ # label="Synthetic control",
142
+ plot_hdi_kwargs = {"color" : "C1" },
133
143
)
144
+ handles .append ((h_line , h_patch ))
145
+ labels .append ("Synthetic control" )
146
+
134
147
ax [0 ].plot (self .datapost .index , self .post_y , "k." )
135
148
# Shaded causal effect
136
- ax [0 ].fill_between (
149
+ h = ax [0 ].fill_between (
137
150
self .datapost .index ,
138
151
y1 = az .extract (
139
- self .post_pred , group = "posterior_predictive" , var_names = "y_hat "
152
+ self .post_pred , group = "posterior_predictive" , var_names = "mu "
140
153
).mean ("sample" ),
141
154
y2 = np .squeeze (self .post_y ),
142
- color = "C0 " ,
155
+ color = "C2 " ,
143
156
alpha = 0.25 ,
144
- label = "Causal impact" ,
157
+ # label="Causal impact",
145
158
)
159
+ handles .append (h )
160
+ labels .append ("Causal impact" )
161
+
146
162
ax [0 ].set (
147
163
title = f"""
148
164
Pre-intervention Bayesian $R^2$: { self .score .r2 :.3f}
@@ -155,30 +171,34 @@ def plot(self):
155
171
self .datapre .index ,
156
172
self .pre_impact ,
157
173
ax = ax [1 ],
174
+ include_label = False ,
175
+ plot_hdi_kwargs = {"color" : "C0" },
158
176
)
159
177
plot_xY (
160
178
self .datapost .index ,
161
179
self .post_impact ,
162
180
ax = ax [1 ],
163
181
include_label = False ,
182
+ plot_hdi_kwargs = {"color" : "C1" },
164
183
)
165
184
ax [1 ].axhline (y = 0 , c = "k" )
166
185
ax [1 ].fill_between (
167
186
self .datapost .index ,
168
187
y1 = self .post_impact .mean (["chain" , "draw" ]),
169
- color = "C0 " ,
188
+ color = "C2 " ,
170
189
alpha = 0.25 ,
171
190
label = "Causal impact" ,
172
191
)
173
192
ax [1 ].set (title = "Causal Impact" )
174
193
175
194
# BOTTOM PLOT -----------------------------------------------
176
-
177
195
ax [2 ].set (title = "Cumulative Causal Impact" )
178
196
plot_xY (
179
197
self .datapost .index ,
180
198
self .post_impact_cumulative ,
181
199
ax = ax [2 ],
200
+ include_label = False ,
201
+ plot_hdi_kwargs = {"color" : "C1" },
182
202
)
183
203
ax [2 ].axhline (y = 0 , c = "k" )
184
204
@@ -189,10 +209,14 @@ def plot(self):
189
209
ls = "-" ,
190
210
lw = 3 ,
191
211
color = "r" ,
192
- label = "Treatment time" ,
212
+ # label="Treatment time",
193
213
)
194
214
195
- ax [0 ].legend (fontsize = LEGEND_FONT_SIZE )
215
+ ax [0 ].legend (
216
+ handles = (h_tuple for h_tuple in handles ),
217
+ labels = labels ,
218
+ fontsize = LEGEND_FONT_SIZE ,
219
+ )
196
220
197
221
return (fig , ax )
198
222
@@ -353,39 +377,46 @@ def __init__(
353
377
)
354
378
355
379
def plot (self ):
356
- """Plot the results"""
380
+ """Plot the results.
381
+ Creating the combined mean + HDI legend entries is a bit involved.
382
+ """
357
383
fig , ax = plt .subplots ()
358
384
359
385
# Plot raw data
360
- # NOTE: This will not work when there is just ONE unit in each group
361
- sns .lineplot (
386
+ sns .scatterplot (
362
387
self .data ,
363
388
x = self .time_variable_name ,
364
389
y = self .outcome_variable_name ,
365
390
hue = self .group_variable_name ,
366
- units = "unit" , # NOTE: assumes we have a `unit` predictor variable
367
- estimator = None ,
368
- alpha = 0.5 ,
391
+ alpha = 1 ,
392
+ legend = False ,
393
+ markers = True ,
369
394
ax = ax ,
370
395
)
371
396
372
397
# Plot model fit to control group
373
398
time_points = self .x_pred_control [self .time_variable_name ].values
374
- plot_xY (
399
+ h_line , h_patch = plot_xY (
375
400
time_points ,
376
- self .y_pred_control .posterior_predictive .y_hat ,
401
+ self .y_pred_control .posterior_predictive .mu ,
377
402
ax = ax ,
378
403
plot_hdi_kwargs = {"color" : "C0" },
404
+ label = "Control group" ,
379
405
)
406
+ handles = [(h_line , h_patch )]
407
+ labels = ["Control group" ]
380
408
381
409
# Plot model fit to treatment group
382
410
time_points = self .x_pred_control [self .time_variable_name ].values
383
- plot_xY (
411
+ h_line , h_patch = plot_xY (
384
412
time_points ,
385
- self .y_pred_treatment .posterior_predictive .y_hat ,
413
+ self .y_pred_treatment .posterior_predictive .mu ,
386
414
ax = ax ,
387
415
plot_hdi_kwargs = {"color" : "C1" },
416
+ label = "Treatment group" ,
388
417
)
418
+ handles .append ((h_line , h_patch ))
419
+ labels .append ("Treatment group" )
389
420
390
421
# Plot counterfactual - post-test for treatment group IF no treatment
391
422
# had occurred.
@@ -407,22 +438,30 @@ def plot(self):
407
438
pc .set_edgecolor ("None" )
408
439
pc .set_alpha (0.5 )
409
440
else :
410
- plot_xY (
441
+ h_line , h_patch = plot_xY (
411
442
time_points ,
412
- self .y_pred_counterfactual .posterior_predictive .y_hat ,
443
+ self .y_pred_counterfactual .posterior_predictive .mu ,
413
444
ax = ax ,
414
445
plot_hdi_kwargs = {"color" : "C2" },
446
+ label = "Counterfactual" ,
415
447
)
448
+ handles .append ((h_line , h_patch ))
449
+ labels .append ("Counterfactual" )
416
450
417
451
# arrow to label the causal impact
418
452
self ._plot_causal_impact_arrow (ax )
453
+
419
454
# formatting
420
455
ax .set (
421
456
xticks = self .x_pred_treatment [self .time_variable_name ].values ,
422
457
title = self ._causal_impact_summary_stat (),
423
458
)
424
- ax .legend (fontsize = LEGEND_FONT_SIZE )
425
- return (fig , ax )
459
+ ax .legend (
460
+ handles = (h_tuple for h_tuple in handles ),
461
+ labels = labels ,
462
+ fontsize = LEGEND_FONT_SIZE ,
463
+ )
464
+ return fig , ax
426
465
427
466
def _plot_causal_impact_arrow (self , ax ):
428
467
"""
@@ -582,12 +621,17 @@ def plot(self):
582
621
c = "k" , # hue="treated",
583
622
ax = ax ,
584
623
)
624
+
585
625
# Plot model fit to data
586
- plot_xY (
626
+ h_line , h_patch = plot_xY (
587
627
self .x_pred [self .running_variable_name ],
588
628
self .pred ["posterior_predictive" ].mu ,
589
629
ax = ax ,
630
+ plot_hdi_kwargs = {"color" : "C1" },
590
631
)
632
+ handles = [(h_line , h_patch )]
633
+ labels = ["Posterior mean" ]
634
+
591
635
# create strings to compose title
592
636
title_info = f"{ self .score .r2 :.3f} (std = { self .score .r2_std :.3f} )"
593
637
r2 = f"Bayesian $R^2$ on all data = { title_info } "
@@ -605,7 +649,11 @@ def plot(self):
605
649
color = "r" ,
606
650
label = "treatment threshold" ,
607
651
)
608
- ax .legend (fontsize = LEGEND_FONT_SIZE )
652
+ ax .legend (
653
+ handles = (h_tuple for h_tuple in handles ),
654
+ labels = labels ,
655
+ fontsize = LEGEND_FONT_SIZE ,
656
+ )
609
657
return (fig , ax )
610
658
611
659
def summary (self ):
@@ -710,27 +758,38 @@ def plot(self):
710
758
hue = "group" ,
711
759
alpha = 0.5 ,
712
760
data = self .data ,
761
+ legend = True ,
713
762
ax = ax [0 ],
714
763
)
715
764
ax [0 ].set (xlabel = "Pretest" , ylabel = "Posttest" )
716
765
717
766
# plot posterior predictive of untreated
718
- plot_xY (
767
+ h_line , h_patch = plot_xY (
719
768
self .pred_xi ,
720
- self .pred_untreated ["posterior_predictive" ].y_hat ,
769
+ self .pred_untreated ["posterior_predictive" ].mu ,
721
770
ax = ax [0 ],
722
771
plot_hdi_kwargs = {"color" : "C0" },
772
+ label = "Control group" ,
723
773
)
774
+ handles = [(h_line , h_patch )]
775
+ labels = ["Control group" ]
724
776
725
777
# plot posterior predictive of treated
726
- plot_xY (
778
+ h_line , h_patch = plot_xY (
727
779
self .pred_xi ,
728
- self .pred_treated ["posterior_predictive" ].y_hat ,
780
+ self .pred_treated ["posterior_predictive" ].mu ,
729
781
ax = ax [0 ],
730
782
plot_hdi_kwargs = {"color" : "C1" },
783
+ label = "Treatment group" ,
731
784
)
785
+ handles .append ((h_line , h_patch ))
786
+ labels .append ("Treatment group" )
732
787
733
- ax [0 ].legend (fontsize = LEGEND_FONT_SIZE )
788
+ ax [0 ].legend (
789
+ handles = (h_tuple for h_tuple in handles ),
790
+ labels = labels ,
791
+ fontsize = LEGEND_FONT_SIZE ,
792
+ )
734
793
735
794
# Plot estimated caual impact / treatment effect
736
795
az .plot_posterior (self .causal_impact , ref_val = 0 , ax = ax [1 ])
0 commit comments