Skip to content

Commit e7349ea

Browse files
committed
created private function to add HDI of posterior predictive to plot
1 parent b069eae commit e7349ea

File tree

1 file changed

+31
-11
lines changed

1 file changed

+31
-11
lines changed

pymc_marketing/mmm/base.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -405,17 +405,8 @@ def plot_posterior_predictive(
405405
fig = ax.figure
406406

407407
for hdi_prob, alpha in zip((0.94, 0.50), (0.2, 0.4), strict=True):
408-
likelihood_hdi: DataArray = az.hdi(
409-
ary=posterior_predictive_data, hdi_prob=hdi_prob
410-
)[self.output_var]
411-
412-
ax.fill_between(
413-
x=posterior_predictive_data.date,
414-
y1=likelihood_hdi[:, 0],
415-
y2=likelihood_hdi[:, 1],
416-
color="C0",
417-
alpha=alpha,
418-
label=f"{hdi_prob:.0%} HDI",
408+
ax = self._add_hdi_to_plot(
409+
ax=ax, original_scale=original_scale, hdi_prob=hdi_prob, alpha=alpha
419410
)
420411

421412
if add_mean:
@@ -477,6 +468,35 @@ def _add_mean_to_plot(
477468
)
478469
return ax
479470

471+
def _add_hdi_to_plot(
472+
self,
473+
ax: plt.Axes,
474+
original_scale: bool = False,
475+
hdi_prob: float = 0.94,
476+
color: str = "C0",
477+
alpha: float = 0.2,
478+
**kwargs,
479+
) -> plt.Axes:
480+
"""Add HDI to existing plot."""
481+
posterior_predictive_data: Dataset = self._get_posterior_predictive_data(
482+
original_scale=original_scale
483+
)
484+
485+
likelihood_hdi: DataArray = az.hdi(
486+
ary=posterior_predictive_data, hdi_prob=hdi_prob
487+
)[self.output_var]
488+
489+
ax.fill_between(
490+
x=posterior_predictive_data.date,
491+
y1=likelihood_hdi[:, 0],
492+
y2=likelihood_hdi[:, 1],
493+
color=color,
494+
alpha=alpha,
495+
label=f"{hdi_prob:.0%} HDI",
496+
**kwargs,
497+
)
498+
return ax
499+
480500
def get_errors(self, original_scale: bool = False) -> DataArray:
481501
"""Get model errors posterior distribution.
482502

0 commit comments

Comments
 (0)