@@ -405,17 +405,8 @@ def plot_posterior_predictive(
405
405
fig = ax .figure
406
406
407
407
for hdi_prob , alpha in zip ((0.94 , 0.50 ), (0.2 , 0.4 ), strict = True ):
408
- likelihood_hdi : DataArray = az .hdi (
409
- ary = posterior_predictive_data , hdi_prob = hdi_prob
410
- )[self .output_var ]
411
-
412
- ax .fill_between (
413
- x = posterior_predictive_data .date ,
414
- y1 = likelihood_hdi [:, 0 ],
415
- y2 = likelihood_hdi [:, 1 ],
416
- color = "C0" ,
417
- alpha = alpha ,
418
- label = f"{ hdi_prob :.0%} HDI" ,
408
+ ax = self ._add_hdi_to_plot (
409
+ ax = ax , original_scale = original_scale , hdi_prob = hdi_prob , alpha = alpha
419
410
)
420
411
421
412
if add_mean :
@@ -477,6 +468,35 @@ def _add_mean_to_plot(
477
468
)
478
469
return ax
479
470
471
+ def _add_hdi_to_plot (
472
+ self ,
473
+ ax : plt .Axes ,
474
+ original_scale : bool = False ,
475
+ hdi_prob : float = 0.94 ,
476
+ color : str = "C0" ,
477
+ alpha : float = 0.2 ,
478
+ ** kwargs ,
479
+ ) -> plt .Axes :
480
+ """Add HDI to existing plot."""
481
+ posterior_predictive_data : Dataset = self ._get_posterior_predictive_data (
482
+ original_scale = original_scale
483
+ )
484
+
485
+ likelihood_hdi : DataArray = az .hdi (
486
+ ary = posterior_predictive_data , hdi_prob = hdi_prob
487
+ )[self .output_var ]
488
+
489
+ ax .fill_between (
490
+ x = posterior_predictive_data .date ,
491
+ y1 = likelihood_hdi [:, 0 ],
492
+ y2 = likelihood_hdi [:, 1 ],
493
+ color = color ,
494
+ alpha = alpha ,
495
+ label = f"{ hdi_prob :.0%} HDI" ,
496
+ ** kwargs ,
497
+ )
498
+ return ax
499
+
480
500
def get_errors (self , original_scale : bool = False ) -> DataArray :
481
501
"""Get model errors posterior distribution.
482
502
0 commit comments