@@ -362,7 +362,9 @@ def plot_prior_predictive(self, **plt_kwargs: Any) -> plt.Figure:
362
362
def plot_posterior_predictive (
363
363
self ,
364
364
original_scale : bool = False ,
365
+ add_hdi : bool = True ,
365
366
add_mean : bool = True ,
367
+ add_gradient : bool = False ,
366
368
ax : plt .Axes = None ,
367
369
** plt_kwargs : Any ,
368
370
) -> plt .Figure :
@@ -404,16 +406,22 @@ def plot_posterior_predictive(
404
406
else :
405
407
fig = ax .figure
406
408
407
- for hdi_prob , alpha in zip ((0.94 , 0.50 ), (0.2 , 0.4 ), strict = True ):
408
- ax = self ._add_hdi_to_plot (
409
- ax = ax , original_scale = original_scale , hdi_prob = hdi_prob , alpha = alpha
410
- )
409
+ if add_hdi :
410
+ for hdi_prob , alpha in zip ((0.94 , 0.50 ), (0.2 , 0.4 ), strict = True ):
411
+ ax = self ._add_hdi_to_plot (
412
+ ax = ax , original_scale = original_scale , hdi_prob = hdi_prob , alpha = alpha
413
+ )
411
414
412
415
if add_mean :
413
416
ax = self ._add_mean_to_plot (
414
417
ax = ax , original_scale = original_scale , color = "red"
415
418
)
416
419
420
+ if add_gradient :
421
+ ax = self ._add_gradient_to_plot (
422
+ ax = ax , original_scale = original_scale , n_percentiles = 30 , palette = "Blues"
423
+ )
424
+
417
425
ax .plot (
418
426
np .asarray (posterior_predictive_data .date ),
419
427
target_to_plot ,
@@ -497,6 +505,78 @@ def _add_hdi_to_plot(
497
505
)
498
506
return ax
499
507
508
+ def _add_gradient_to_plot (
509
+ self ,
510
+ ax : plt .Axes ,
511
+ original_scale : bool = False ,
512
+ n_percentiles : int = 30 ,
513
+ palette : str = "Blues" ,
514
+ ** kwargs ,
515
+ ) -> plt .Axes :
516
+ """
517
+ Add a gradient representation of the posterior predictive distribution to an existing plot.
518
+
519
+ This method creates a shaded area plot where the color intensity represents
520
+ the density of the posterior predictive distribution.
521
+
522
+ Parameters
523
+ ----------
524
+ ax : plt.Axes
525
+ The matplotlib axes object to add the gradient to.
526
+ original_scale : bool, optional
527
+ If True, use the original scale of the data. Default is False.
528
+ n_percentiles : int, optional
529
+ Number of percentile ranges to use for the gradient. Default is 30.
530
+ palette : str, optional
531
+ Color palette to use for the gradient. Default is "Blues".
532
+ **kwargs
533
+ Additional keyword arguments passed to ax.fill_between().
534
+
535
+ Returns
536
+ -------
537
+ plt.Axes
538
+ The matplotlib axes object with the gradient added.
539
+ """
540
+ # Get posterior predictive data and flatten it
541
+ posterior_predictive = self ._get_posterior_predictive_data (
542
+ original_scale = original_scale
543
+ )
544
+ posterior_predictive_flattened = posterior_predictive .stack (
545
+ sample = ("chain" , "draw" )
546
+ ).to_dataarray ()
547
+ dates = posterior_predictive .date .values
548
+
549
+ # Set up color map and ranges
550
+ cmap = plt .get_cmap (palette )
551
+ color_range = np .linspace (0.3 , 1.0 , n_percentiles // 2 )
552
+ percentile_ranges = np .linspace (3 , 97 , n_percentiles )
553
+
554
+ # Create gradient by filling between percentile ranges
555
+ for i in range (len (percentile_ranges ) - 1 ):
556
+ lower_percentile = np .percentile (
557
+ posterior_predictive_flattened , percentile_ranges [i ], axis = 2
558
+ ).squeeze ()
559
+ upper_percentile = np .percentile (
560
+ posterior_predictive_flattened , percentile_ranges [i + 1 ], axis = 2
561
+ ).squeeze ()
562
+ if i < n_percentiles // 2 :
563
+ color_val = color_range [i ]
564
+ else :
565
+ color_val = color_range [n_percentiles - i - 2 ]
566
+ alpha_val = 0.2 + 0.8 * (
567
+ 1 - abs (2 * i / n_percentiles - 1 )
568
+ ) # Higher alpha in the middle
569
+ ax .fill_between (
570
+ x = dates ,
571
+ y1 = lower_percentile ,
572
+ y2 = upper_percentile ,
573
+ color = cmap (color_val ),
574
+ alpha = alpha_val ,
575
+ ** kwargs ,
576
+ )
577
+
578
+ return ax
579
+
500
580
def get_errors (self , original_scale : bool = False ) -> DataArray :
501
581
"""Get model errors posterior distribution.
502
582
0 commit comments