@@ -99,13 +99,13 @@ def __init__(
99
99
# causal impact pre (ie the residuals of the model fit to observed)
100
100
pre_data = xr .DataArray (self .pre_y [:, 0 ], dims = ["obs_ind" ])
101
101
self .pre_impact = (
102
- pre_data - self .pre_pred ["posterior_predictive" ].y_hat
102
+ pre_data - self .pre_pred ["posterior_predictive" ].mu
103
103
).transpose (..., "obs_ind" )
104
104
105
105
# causal impact post (ie the residuals of the model fit to observed)
106
106
post_data = xr .DataArray (self .post_y [:, 0 ], dims = ["obs_ind" ])
107
107
self .post_impact = (
108
- post_data - self .post_pred ["posterior_predictive" ].y_hat
108
+ post_data - self .post_pred ["posterior_predictive" ].mu
109
109
).transpose (..., "obs_ind" )
110
110
111
111
# cumulative impact post
@@ -118,31 +118,43 @@ def plot(self):
118
118
119
119
# TOP PLOT --------------------------------------------------
120
120
# pre-intervention period
121
- plot_xY (
121
+ h_line , h_patch = plot_xY (
122
122
self .datapre .index ,
123
- self .pre_pred ["posterior_predictive" ].y_hat ,
123
+ self .pre_pred ["posterior_predictive" ].mu ,
124
124
ax = ax [0 ],
125
+ plot_hdi_kwargs = {"color" : "C0" },
125
126
)
126
- ax [0 ].plot (self .datapre .index , self .pre_y , "k." , label = "Observations" )
127
+ handles = [(h_line , h_patch )]
128
+ labels = ["Pre-intervention period" ]
129
+
130
+ (h ,) = ax [0 ].plot (self .datapre .index , self .pre_y , "k." , label = "Observations" )
131
+ handles .append (h )
132
+ labels .append ("Observations" )
133
+
127
134
# post intervention period
128
- plot_xY (
135
+ h_line , h_patch = plot_xY (
129
136
self .datapost .index ,
130
- self .post_pred ["posterior_predictive" ].y_hat ,
137
+ self .post_pred ["posterior_predictive" ].mu ,
131
138
ax = ax [0 ],
132
- include_label = False ,
139
+ plot_hdi_kwargs = { "color" : "C1" } ,
133
140
)
141
+ handles .append ((h_line , h_patch ))
142
+ labels .append ("Synthetic control" )
143
+
134
144
ax [0 ].plot (self .datapost .index , self .post_y , "k." )
135
145
# Shaded causal effect
136
- ax [0 ].fill_between (
146
+ h = ax [0 ].fill_between (
137
147
self .datapost .index ,
138
148
y1 = az .extract (
139
- self .post_pred , group = "posterior_predictive" , var_names = "y_hat "
149
+ self .post_pred , group = "posterior_predictive" , var_names = "mu "
140
150
).mean ("sample" ),
141
151
y2 = np .squeeze (self .post_y ),
142
152
color = "C0" ,
143
153
alpha = 0.25 ,
144
- label = "Causal impact" ,
145
154
)
155
+ handles .append (h )
156
+ labels .append ("Causal impact" )
157
+
146
158
ax [0 ].set (
147
159
title = f"""
148
160
Pre-intervention Bayesian $R^2$: { self .score .r2 :.3f}
@@ -155,12 +167,13 @@ def plot(self):
155
167
self .datapre .index ,
156
168
self .pre_impact ,
157
169
ax = ax [1 ],
170
+ plot_hdi_kwargs = {"color" : "C0" },
158
171
)
159
172
plot_xY (
160
173
self .datapost .index ,
161
174
self .post_impact ,
162
175
ax = ax [1 ],
163
- include_label = False ,
176
+ plot_hdi_kwargs = { "color" : "C1" } ,
164
177
)
165
178
ax [1 ].axhline (y = 0 , c = "k" )
166
179
ax [1 ].fill_between (
@@ -173,12 +186,12 @@ def plot(self):
173
186
ax [1 ].set (title = "Causal Impact" )
174
187
175
188
# BOTTOM PLOT -----------------------------------------------
176
-
177
189
ax [2 ].set (title = "Cumulative Causal Impact" )
178
190
plot_xY (
179
191
self .datapost .index ,
180
192
self .post_impact_cumulative ,
181
193
ax = ax [2 ],
194
+ plot_hdi_kwargs = {"color" : "C1" },
182
195
)
183
196
ax [2 ].axhline (y = 0 , c = "k" )
184
197
@@ -189,10 +202,13 @@ def plot(self):
189
202
ls = "-" ,
190
203
lw = 3 ,
191
204
color = "r" ,
192
- label = "Treatment time" ,
193
205
)
194
206
195
- ax [0 ].legend (fontsize = LEGEND_FONT_SIZE )
207
+ ax [0 ].legend (
208
+ handles = (h_tuple for h_tuple in handles ),
209
+ labels = labels ,
210
+ fontsize = LEGEND_FONT_SIZE ,
211
+ )
196
212
197
213
return (fig , ax )
198
214
@@ -353,39 +369,46 @@ def __init__(
353
369
)
354
370
355
371
def plot (self ):
356
- """Plot the results"""
372
+ """Plot the results.
373
+ Creating the combined mean + HDI legend entries is a bit involved.
374
+ """
357
375
fig , ax = plt .subplots ()
358
376
359
377
# Plot raw data
360
- # NOTE: This will not work when there is just ONE unit in each group
361
- sns .lineplot (
378
+ sns .scatterplot (
362
379
self .data ,
363
380
x = self .time_variable_name ,
364
381
y = self .outcome_variable_name ,
365
382
hue = self .group_variable_name ,
366
- units = "unit" , # NOTE: assumes we have a `unit` predictor variable
367
- estimator = None ,
368
- alpha = 0.5 ,
383
+ alpha = 1 ,
384
+ legend = False ,
385
+ markers = True ,
369
386
ax = ax ,
370
387
)
371
388
372
389
# Plot model fit to control group
373
390
time_points = self .x_pred_control [self .time_variable_name ].values
374
- plot_xY (
391
+ h_line , h_patch = plot_xY (
375
392
time_points ,
376
- self .y_pred_control .posterior_predictive .y_hat ,
393
+ self .y_pred_control .posterior_predictive .mu ,
377
394
ax = ax ,
378
395
plot_hdi_kwargs = {"color" : "C0" },
396
+ label = "Control group" ,
379
397
)
398
+ handles = [(h_line , h_patch )]
399
+ labels = ["Control group" ]
380
400
381
401
# Plot model fit to treatment group
382
402
time_points = self .x_pred_control [self .time_variable_name ].values
383
- plot_xY (
403
+ h_line , h_patch = plot_xY (
384
404
time_points ,
385
- self .y_pred_treatment .posterior_predictive .y_hat ,
405
+ self .y_pred_treatment .posterior_predictive .mu ,
386
406
ax = ax ,
387
407
plot_hdi_kwargs = {"color" : "C1" },
408
+ label = "Treatment group" ,
388
409
)
410
+ handles .append ((h_line , h_patch ))
411
+ labels .append ("Treatment group" )
389
412
390
413
# Plot counterfactual - post-test for treatment group IF no treatment
391
414
# had occurred.
@@ -403,26 +426,34 @@ def plot(self):
403
426
widths = 0.2 ,
404
427
)
405
428
for pc in parts ["bodies" ]:
406
- pc .set_facecolor ("C2 " )
429
+ pc .set_facecolor ("C0 " )
407
430
pc .set_edgecolor ("None" )
408
431
pc .set_alpha (0.5 )
409
432
else :
410
- plot_xY (
433
+ h_line , h_patch = plot_xY (
411
434
time_points ,
412
- self .y_pred_counterfactual .posterior_predictive .y_hat ,
435
+ self .y_pred_counterfactual .posterior_predictive .mu ,
413
436
ax = ax ,
414
437
plot_hdi_kwargs = {"color" : "C2" },
438
+ label = "Counterfactual" ,
415
439
)
440
+ handles .append ((h_line , h_patch ))
441
+ labels .append ("Counterfactual" )
416
442
417
443
# arrow to label the causal impact
418
444
self ._plot_causal_impact_arrow (ax )
445
+
419
446
# formatting
420
447
ax .set (
421
448
xticks = self .x_pred_treatment [self .time_variable_name ].values ,
422
449
title = self ._causal_impact_summary_stat (),
423
450
)
424
- ax .legend (fontsize = LEGEND_FONT_SIZE )
425
- return (fig , ax )
451
+ ax .legend (
452
+ handles = (h_tuple for h_tuple in handles ),
453
+ labels = labels ,
454
+ fontsize = LEGEND_FONT_SIZE ,
455
+ )
456
+ return fig , ax
426
457
427
458
def _plot_causal_impact_arrow (self , ax ):
428
459
"""
@@ -582,12 +613,17 @@ def plot(self):
582
613
c = "k" , # hue="treated",
583
614
ax = ax ,
584
615
)
616
+
585
617
# Plot model fit to data
586
- plot_xY (
618
+ h_line , h_patch = plot_xY (
587
619
self .x_pred [self .running_variable_name ],
588
620
self .pred ["posterior_predictive" ].mu ,
589
621
ax = ax ,
622
+ plot_hdi_kwargs = {"color" : "C1" },
590
623
)
624
+ handles = [(h_line , h_patch )]
625
+ labels = ["Posterior mean" ]
626
+
591
627
# create strings to compose title
592
628
title_info = f"{ self .score .r2 :.3f} (std = { self .score .r2_std :.3f} )"
593
629
r2 = f"Bayesian $R^2$ on all data = { title_info } "
@@ -605,7 +641,11 @@ def plot(self):
605
641
color = "r" ,
606
642
label = "treatment threshold" ,
607
643
)
608
- ax .legend (fontsize = LEGEND_FONT_SIZE )
644
+ ax .legend (
645
+ handles = (h_tuple for h_tuple in handles ),
646
+ labels = labels ,
647
+ fontsize = LEGEND_FONT_SIZE ,
648
+ )
609
649
return (fig , ax )
610
650
611
651
def summary (self ):
@@ -710,27 +750,38 @@ def plot(self):
710
750
hue = "group" ,
711
751
alpha = 0.5 ,
712
752
data = self .data ,
753
+ legend = True ,
713
754
ax = ax [0 ],
714
755
)
715
756
ax [0 ].set (xlabel = "Pretest" , ylabel = "Posttest" )
716
757
717
758
# plot posterior predictive of untreated
718
- plot_xY (
759
+ h_line , h_patch = plot_xY (
719
760
self .pred_xi ,
720
- self .pred_untreated ["posterior_predictive" ].y_hat ,
761
+ self .pred_untreated ["posterior_predictive" ].mu ,
721
762
ax = ax [0 ],
722
763
plot_hdi_kwargs = {"color" : "C0" },
764
+ label = "Control group" ,
723
765
)
766
+ handles = [(h_line , h_patch )]
767
+ labels = ["Control group" ]
724
768
725
769
# plot posterior predictive of treated
726
- plot_xY (
770
+ h_line , h_patch = plot_xY (
727
771
self .pred_xi ,
728
- self .pred_treated ["posterior_predictive" ].y_hat ,
772
+ self .pred_treated ["posterior_predictive" ].mu ,
729
773
ax = ax [0 ],
730
774
plot_hdi_kwargs = {"color" : "C1" },
775
+ label = "Treatment group" ,
731
776
)
777
+ handles .append ((h_line , h_patch ))
778
+ labels .append ("Treatment group" )
732
779
733
- ax [0 ].legend (fontsize = LEGEND_FONT_SIZE )
780
+ ax [0 ].legend (
781
+ handles = (h_tuple for h_tuple in handles ),
782
+ labels = labels ,
783
+ fontsize = LEGEND_FONT_SIZE ,
784
+ )
734
785
735
786
# Plot estimated caual impact / treatment effect
736
787
az .plot_posterior (self .causal_impact , ref_val = 0 , ax = ax [1 ])
0 commit comments