@@ -302,17 +302,19 @@ def plot(self):
302
302
fig , ax = plt .subplots ()
303
303
304
304
# Plot raw data
305
- sns .lineplot (
306
- self .data ,
307
- x = self .time_variable_name ,
308
- y = self .outcome_variable_name ,
309
- hue = self .group_variable_name ,
310
- units = "unit" ,
311
- estimator = None ,
312
- alpha = 0.25 ,
313
- ax = ax ,
314
- )
305
+ # NOTE: This will not work when there is just ONE unit in each group
306
+ # sns.lineplot(
307
+ # self.data,
308
+ # x=self.time_variable_name,
309
+ # y=self.outcome_variable_name,
310
+ # hue=self.group_variable_name,
311
+ # # units="unit",
312
+ # estimator=None,
313
+ # alpha=0.25,
314
+ # ax=ax,
315
+ # )
315
316
# Plot model fit to control group
317
+ # NOTE: This will not work when there is just ONE unit in each group
316
318
parts = ax .violinplot (
317
319
az .extract (
318
320
self .y_pred_control , group = "posterior_predictive" , var_names = "mu"
@@ -328,6 +330,7 @@ def plot(self):
328
330
pc .set_alpha (0.5 )
329
331
330
332
# Plot model fit to treatment group
333
+ # NOTE: This will not work when there is just ONE unit in each group
331
334
parts = ax .violinplot (
332
335
az .extract (
333
336
self .y_pred_treatment , group = "posterior_predictive" , var_names = "mu"
@@ -337,18 +340,19 @@ def plot(self):
337
340
showmedians = False ,
338
341
widths = 0.2 ,
339
342
)
340
- # Plot counterfactual - post-test for treatment group IF no treatment had occurred.
341
- parts = ax .violinplot (
342
- az .extract (
343
- self .y_pred_counterfactual ,
344
- group = "posterior_predictive" ,
345
- var_names = "mu" ,
346
- ).values .T ,
347
- positions = self .x_pred_counterfactual [self .time_variable_name ].values ,
348
- showmeans = False ,
349
- showmedians = False ,
350
- widths = 0.2 ,
351
- )
343
+ # # Plot counterfactual - post-test for treatment group IF no treatment had occurred.
344
+ # # NOTE: This will not work when there is just ONE unit in each group
345
+ # parts = ax.violinplot(
346
+ # az.extract(
347
+ # self.y_pred_counterfactual,
348
+ # group="posterior_predictive",
349
+ # var_names="mu",
350
+ # ).values.T,
351
+ # positions=self.x_pred_counterfactual[self.time_variable_name].values,
352
+ # showmeans=False,
353
+ # showmedians=False,
354
+ # widths=0.2,
355
+ # )
352
356
# arrow to label the causal impact
353
357
y_pred_treatment = (
354
358
self .y_pred_treatment ["posterior_predictive" ]
@@ -378,9 +382,9 @@ def plot(self):
378
382
)
379
383
# formatting
380
384
ax .set (
381
- xlim = [- 0.15 , 1.25 ],
382
- xticks = [ 0 , 1 ] ,
383
- xticklabels = ["pre" , "post" ],
385
+ # xlim=[-0.15, 1.25],
386
+ xticks = self . x_pred_treatment [ self . time_variable_name ]. values ,
387
+ # xticklabels=["pre", "post"],
384
388
title = self ._causal_impact_summary_stat (),
385
389
)
386
390
ax .legend (fontsize = LEGEND_FONT_SIZE )
0 commit comments