@@ -382,13 +382,9 @@ def plot_posterior_predictive(
382
382
plt.Figure
383
383
384
384
"""
385
- try :
386
- posterior_predictive_data : Dataset = self .posterior_predictive
387
-
388
- except Exception as e :
389
- raise RuntimeError (
390
- "Make sure the model has bin fitted and the posterior predictive has been sampled!"
391
- ) from e
385
+ posterior_predictive_data : Dataset = self ._get_posterior_predictive_data (
386
+ original_scale = original_scale
387
+ )
392
388
393
389
target_to_plot = np .asarray (
394
390
self .y
@@ -408,13 +404,6 @@ def plot_posterior_predictive(
408
404
else :
409
405
fig = ax .figure
410
406
411
- if original_scale :
412
- posterior_predictive_data = apply_sklearn_transformer_across_dim (
413
- data = posterior_predictive_data ,
414
- func = self .get_target_transformer ().inverse_transform ,
415
- dim_name = "date" ,
416
- )
417
-
418
407
for hdi_prob , alpha in zip ((0.94 , 0.50 ), (0.2 , 0.4 ), strict = True ):
419
408
likelihood_hdi : DataArray = az .hdi (
420
409
ary = posterior_predictive_data , hdi_prob = hdi_prob
@@ -430,15 +419,8 @@ def plot_posterior_predictive(
430
419
)
431
420
432
421
if add_mean :
433
- mean_prediction = posterior_predictive_data [self .output_var ].mean (
434
- dim = ["chain" , "draw" ]
435
- )
436
-
437
- ax .plot (
438
- np .asarray (posterior_predictive_data .date ),
439
- mean_prediction ,
440
- color = "C0" ,
441
- label = "Mean Prediction" ,
422
+ ax = self ._add_mean_to_plot (
423
+ ax = ax , original_scale = original_scale , color = "red"
442
424
)
443
425
444
426
ax .plot (
@@ -456,6 +438,45 @@ def plot_posterior_predictive(
456
438
457
439
return fig
458
440
441
+ def _get_posterior_predictive_data (self , original_scale : bool = False ) -> Dataset :
442
+ """Get the posterior predictive data."""
443
+ try :
444
+ posterior_predictive_data : Dataset = self .posterior_predictive
445
+
446
+ except Exception as e :
447
+ raise RuntimeError (
448
+ "Make sure the model has bin fitted and the posterior predictive has been sampled!"
449
+ ) from e
450
+
451
+ if original_scale :
452
+ posterior_predictive_data = apply_sklearn_transformer_across_dim (
453
+ data = posterior_predictive_data ,
454
+ func = self .get_target_transformer ().inverse_transform ,
455
+ dim_name = "date" ,
456
+ )
457
+ return posterior_predictive_data
458
+
459
+ def _add_mean_to_plot (
460
+ self , ax , original_scale : bool = False , color = "blue" , linestyle = "-" , ** kwargs
461
+ ) -> plt .Axes :
462
+ """Add mean prediction to existing plot."""
463
+ posterior_predictive_data : Dataset = self ._get_posterior_predictive_data (
464
+ original_scale = original_scale
465
+ )
466
+
467
+ mean_prediction = posterior_predictive_data [self .output_var ].mean (
468
+ dim = ["chain" , "draw" ]
469
+ )
470
+
471
+ ax .plot (
472
+ np .asarray (posterior_predictive_data .date ),
473
+ mean_prediction ,
474
+ color = color ,
475
+ linestyle = linestyle ,
476
+ label = "Mean Prediction" ,
477
+ )
478
+ return ax
479
+
459
480
def get_errors (self , original_scale : bool = False ) -> DataArray :
460
481
"""Get model errors posterior distribution.
461
482
0 commit comments