@@ -370,53 +370,52 @@ def plot(self):
370
370
alpha = 0.5 ,
371
371
ax = ax ,
372
372
)
373
+
373
374
# Plot model fit to control group
374
- parts = ax .violinplot (
375
- az .extract (
376
- self .y_pred_control , group = "posterior_predictive" , var_names = "mu"
377
- ).values .T ,
378
- positions = self .x_pred_control [self .time_variable_name ].values ,
379
- showmeans = False ,
380
- showmedians = False ,
381
- widths = 0.2 ,
382
- )
383
- for pc in parts ["bodies" ]:
384
- pc .set_facecolor ("C0" )
385
- pc .set_edgecolor ("None" )
386
- pc .set_alpha (0.5 )
375
+ time_points = self .x_pred_control [self .time_variable_name ].values
376
+ plot_xY (
377
+ time_points ,
378
+ self .y_pred_control .posterior_predictive .y_hat ,
379
+ ax = ax ,
380
+ plot_hdi_kwargs = {"color" : "C0" },
381
+ )
387
382
388
383
# Plot model fit to treatment group
389
- parts = ax .violinplot (
390
- az .extract (
391
- self .y_pred_treatment , group = "posterior_predictive" , var_names = "mu"
392
- ).values .T ,
393
- positions = self .x_pred_treatment [self .time_variable_name ].values ,
394
- showmeans = False ,
395
- showmedians = False ,
396
- widths = 0.2 ,
397
- )
398
-
399
- for pc in parts ["bodies" ]:
400
- pc .set_facecolor ("C1" )
401
- pc .set_edgecolor ("None" )
402
- pc .set_alpha (0.5 )
384
+ time_points = self .x_pred_control [self .time_variable_name ].values
385
+ plot_xY (
386
+ time_points ,
387
+ self .y_pred_treatment .posterior_predictive .y_hat ,
388
+ ax = ax ,
389
+ plot_hdi_kwargs = {"color" : "C1" },
390
+ )
391
+
403
392
# Plot counterfactual - post-test for treatment group IF no treatment
404
393
# had occurred.
405
- parts = ax .violinplot (
406
- az .extract (
407
- self .y_pred_counterfactual ,
408
- group = "posterior_predictive" ,
409
- var_names = "mu" ,
410
- ).values .T ,
411
- positions = self .x_pred_counterfactual [self .time_variable_name ].values ,
412
- showmeans = False ,
413
- showmedians = False ,
414
- widths = 0.2 ,
415
- )
416
- for pc in parts ["bodies" ]:
417
- pc .set_facecolor ("C2" )
418
- pc .set_edgecolor ("None" )
419
- pc .set_alpha (0.5 )
394
+ time_points = self .x_pred_counterfactual [self .time_variable_name ].values
395
+ if len (time_points ) == 1 :
396
+ parts = ax .violinplot (
397
+ az .extract (
398
+ self .y_pred_counterfactual ,
399
+ group = "posterior_predictive" ,
400
+ var_names = "mu" ,
401
+ ).values .T ,
402
+ positions = self .x_pred_counterfactual [self .time_variable_name ].values ,
403
+ showmeans = False ,
404
+ showmedians = False ,
405
+ widths = 0.2 ,
406
+ )
407
+ for pc in parts ["bodies" ]:
408
+ pc .set_facecolor ("C2" )
409
+ pc .set_edgecolor ("None" )
410
+ pc .set_alpha (0.5 )
411
+ else :
412
+ plot_xY (
413
+ time_points ,
414
+ self .y_pred_counterfactual .posterior_predictive .y_hat ,
415
+ ax = ax ,
416
+ plot_hdi_kwargs = {"color" : "C2" },
417
+ )
418
+
420
419
# arrow to label the causal impact
421
420
self ._plot_causal_impact_arrow (ax )
422
421
# formatting
0 commit comments