Skip to content

Commit 38bc101

Browse files
committed
added _add_gradient_to_plot
1 parent e7349ea commit 38bc101

File tree

2 files changed

+122
-8
lines changed

2 files changed

+122
-8
lines changed

docs/source/notebooks/mmm/mmm_plotting_options.ipynb

Lines changed: 38 additions & 4 deletions
Large diffs are not rendered by default.

pymc_marketing/mmm/base.py

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,9 @@ def plot_prior_predictive(self, **plt_kwargs: Any) -> plt.Figure:
362362
def plot_posterior_predictive(
363363
self,
364364
original_scale: bool = False,
365+
add_hdi: bool = True,
365366
add_mean: bool = True,
367+
add_gradient: bool = False,
366368
ax: plt.Axes = None,
367369
**plt_kwargs: Any,
368370
) -> plt.Figure:
@@ -404,16 +406,22 @@ def plot_posterior_predictive(
404406
else:
405407
fig = ax.figure
406408

407-
for hdi_prob, alpha in zip((0.94, 0.50), (0.2, 0.4), strict=True):
408-
ax = self._add_hdi_to_plot(
409-
ax=ax, original_scale=original_scale, hdi_prob=hdi_prob, alpha=alpha
410-
)
409+
if add_hdi:
410+
for hdi_prob, alpha in zip((0.94, 0.50), (0.2, 0.4), strict=True):
411+
ax = self._add_hdi_to_plot(
412+
ax=ax, original_scale=original_scale, hdi_prob=hdi_prob, alpha=alpha
413+
)
411414

412415
if add_mean:
413416
ax = self._add_mean_to_plot(
414417
ax=ax, original_scale=original_scale, color="red"
415418
)
416419

420+
if add_gradient:
421+
ax = self._add_gradient_to_plot(
422+
ax=ax, original_scale=original_scale, n_percentiles=30, palette="Blues"
423+
)
424+
417425
ax.plot(
418426
np.asarray(posterior_predictive_data.date),
419427
target_to_plot,
@@ -497,6 +505,78 @@ def _add_hdi_to_plot(
497505
)
498506
return ax
499507

508+
def _add_gradient_to_plot(
509+
self,
510+
ax: plt.Axes,
511+
original_scale: bool = False,
512+
n_percentiles: int = 30,
513+
palette: str = "Blues",
514+
**kwargs,
515+
) -> plt.Axes:
516+
"""
517+
Add a gradient representation of the posterior predictive distribution to an existing plot.
518+
519+
This method creates a shaded area plot where the color intensity represents
520+
the density of the posterior predictive distribution.
521+
522+
Parameters
523+
----------
524+
ax : plt.Axes
525+
The matplotlib axes object to add the gradient to.
526+
original_scale : bool, optional
527+
If True, use the original scale of the data. Default is False.
528+
n_percentiles : int, optional
529+
Number of percentile ranges to use for the gradient. Default is 30.
530+
palette : str, optional
531+
Color palette to use for the gradient. Default is "Blues".
532+
**kwargs
533+
Additional keyword arguments passed to ax.fill_between().
534+
535+
Returns
536+
-------
537+
plt.Axes
538+
The matplotlib axes object with the gradient added.
539+
"""
540+
# Get posterior predictive data and flatten it
541+
posterior_predictive = self._get_posterior_predictive_data(
542+
original_scale=original_scale
543+
)
544+
posterior_predictive_flattened = posterior_predictive.stack(
545+
sample=("chain", "draw")
546+
).to_dataarray()
547+
dates = posterior_predictive.date.values
548+
549+
# Set up color map and ranges
550+
cmap = plt.get_cmap(palette)
551+
color_range = np.linspace(0.3, 1.0, n_percentiles // 2)
552+
percentile_ranges = np.linspace(3, 97, n_percentiles)
553+
554+
# Create gradient by filling between percentile ranges
555+
for i in range(len(percentile_ranges) - 1):
556+
lower_percentile = np.percentile(
557+
posterior_predictive_flattened, percentile_ranges[i], axis=2
558+
).squeeze()
559+
upper_percentile = np.percentile(
560+
posterior_predictive_flattened, percentile_ranges[i + 1], axis=2
561+
).squeeze()
562+
if i < n_percentiles // 2:
563+
color_val = color_range[i]
564+
else:
565+
color_val = color_range[n_percentiles - i - 2]
566+
alpha_val = 0.2 + 0.8 * (
567+
1 - abs(2 * i / n_percentiles - 1)
568+
) # Higher alpha in the middle
569+
ax.fill_between(
570+
x=dates,
571+
y1=lower_percentile,
572+
y2=upper_percentile,
573+
color=cmap(color_val),
574+
alpha=alpha_val,
575+
**kwargs,
576+
)
577+
578+
return ax
579+
500580
def get_errors(self, original_scale: bool = False) -> DataArray:
501581
"""Get model errors posterior distribution.
502582

0 commit comments

Comments
 (0)