@@ -238,6 +238,8 @@ def __init__(
238
238
239
239
# TODO: check that data in column self.group_variable_name has TWO levels
240
240
241
+ # TODO: check we have `unit` as a predictor column which is an vector of labels of unique units
242
+
241
243
# TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data
242
244
243
245
# DEVIATION FROM SKL EXPERIMENT CODE =============================
@@ -303,18 +305,17 @@ def plot(self):
303
305
304
306
# Plot raw data
305
307
# NOTE: This will not work when there is just ONE unit in each group
306
- # sns.lineplot(
307
- # self.data,
308
- # x=self.time_variable_name,
309
- # y=self.outcome_variable_name,
310
- # hue=self.group_variable_name,
311
- # # units="unit",
312
- # estimator=None,
313
- # alpha=0.25 ,
314
- # ax=ax,
315
- # )
308
+ sns .lineplot (
309
+ self .data ,
310
+ x = self .time_variable_name ,
311
+ y = self .outcome_variable_name ,
312
+ hue = self .group_variable_name ,
313
+ units = "unit" , # NOTE: assumes we have a `unit` predictor variable
314
+ estimator = None ,
315
+ alpha = 0.5 ,
316
+ ax = ax ,
317
+ )
316
318
# Plot model fit to control group
317
- # NOTE: This will not work when there is just ONE unit in each group
318
319
parts = ax .violinplot (
319
320
az .extract (
320
321
self .y_pred_control , group = "posterior_predictive" , var_names = "mu"
@@ -330,7 +331,6 @@ def plot(self):
330
331
pc .set_alpha (0.5 )
331
332
332
333
# Plot model fit to treatment group
333
- # NOTE: This will not work when there is just ONE unit in each group
334
334
parts = ax .violinplot (
335
335
az .extract (
336
336
self .y_pred_treatment , group = "posterior_predictive" , var_names = "mu"
@@ -340,20 +340,41 @@ def plot(self):
340
340
showmedians = False ,
341
341
widths = 0.2 ,
342
342
)
343
- # # Plot counterfactual - post-test for treatment group IF no treatment had occurred.
344
- # # NOTE: This will not work when there is just ONE unit in each group
345
- # parts = ax.violinplot(
346
- # az.extract(
347
- # self.y_pred_counterfactual,
348
- # group="posterior_predictive",
349
- # var_names="mu",
350
- # ).values.T,
351
- # positions=self.x_pred_counterfactual[self.time_variable_name].values,
352
- # showmeans=False,
353
- # showmedians=False,
354
- # widths=0.2,
355
- # )
343
+ for pc in parts ["bodies" ]:
344
+ pc .set_facecolor ("C1" )
345
+ pc .set_edgecolor ("None" )
346
+ pc .set_alpha (0.5 )
347
+ # Plot counterfactual - post-test for treatment group IF no treatment had occurred.
348
+ parts = ax .violinplot (
349
+ az .extract (
350
+ self .y_pred_counterfactual ,
351
+ group = "posterior_predictive" ,
352
+ var_names = "mu" ,
353
+ ).values .T ,
354
+ positions = self .x_pred_counterfactual [self .time_variable_name ].values ,
355
+ showmeans = False ,
356
+ showmedians = False ,
357
+ widths = 0.2 ,
358
+ )
359
+ for pc in parts ["bodies" ]:
360
+ pc .set_facecolor ("C2" )
361
+ pc .set_edgecolor ("None" )
362
+ pc .set_alpha (0.5 )
356
363
# arrow to label the causal impact
364
+ self ._plot_causal_impact_arrow (ax )
365
+ # formatting
366
+ ax .set (
367
+ xticks = self .x_pred_treatment [self .time_variable_name ].values ,
368
+ title = self ._causal_impact_summary_stat (),
369
+ )
370
+ ax .legend (fontsize = LEGEND_FONT_SIZE )
371
+ return (fig , ax )
372
+
373
+ def _plot_causal_impact_arrow (self , ax ):
374
+ """
375
+ draw a vertical arrow between `y_pred_counterfactual` and `y_pred_counterfactual`
376
+ """
377
+ # Calculate y values to plot the arrow between
357
378
y_pred_treatment = (
358
379
self .y_pred_treatment ["posterior_predictive" ]
359
380
.mu .isel ({"obs_ind" : 1 })
@@ -363,32 +384,28 @@ def plot(self):
363
384
y_pred_counterfactual = (
364
385
self .y_pred_counterfactual ["posterior_predictive" ].mu .mean ().data
365
386
)
387
+ # Calculate the x position to plot at
388
+ diff = np .ptp (self .x_pred_treatment [self .time_variable_name ].values )
389
+ x = np .max (self .x_pred_treatment [self .time_variable_name ].values ) + 0.1 * diff
390
+ # Plot the arrow
366
391
ax .annotate (
367
392
"" ,
368
- xy = (1.15 , y_pred_counterfactual ),
393
+ xy = (x , y_pred_counterfactual ),
369
394
xycoords = "data" ,
370
- xytext = (1.15 , y_pred_treatment ),
395
+ xytext = (x , y_pred_treatment ),
371
396
textcoords = "data" ,
372
- arrowprops = {"arrowstyle" : "<-> " , "color" : "green" , "lw" : 3 },
397
+ arrowprops = {"arrowstyle" : "<-" , "color" : "green" , "lw" : 3 },
373
398
)
399
+ # Plot text annotation next to arrow
374
400
ax .annotate (
375
401
"causal\n impact" ,
376
- xy = (1.15 , np .mean ([y_pred_counterfactual , y_pred_treatment ])),
402
+ xy = (x , np .mean ([y_pred_counterfactual , y_pred_treatment ])),
377
403
xycoords = "data" ,
378
404
xytext = (5 , 0 ),
379
405
textcoords = "offset points" ,
380
406
color = "green" ,
381
407
va = "center" ,
382
408
)
383
- # formatting
384
- ax .set (
385
- # xlim=[-0.15, 1.25],
386
- xticks = self .x_pred_treatment [self .time_variable_name ].values ,
387
- # xticklabels=["pre", "post"],
388
- title = self ._causal_impact_summary_stat (),
389
- )
390
- ax .legend (fontsize = LEGEND_FONT_SIZE )
391
- return (fig , ax )
392
409
393
410
def _causal_impact_summary_stat (self ):
394
411
percentiles = self .causal_impact .quantile ([0.03 , 1 - 0.03 ]).values
0 commit comments