Skip to content

Commit 98cfeaf

Browse files
committed
updated plot_posterior_predictive docstring
1 parent 38bc101 commit 98cfeaf

File tree

1 file changed

+35
-5
lines changed

1 file changed

+35
-5
lines changed

pymc_marketing/mmm/base.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -368,21 +368,51 @@ def plot_posterior_predictive(
368368
ax: plt.Axes = None,
369369
**plt_kwargs: Any,
370370
) -> plt.Figure:
371-
"""Plot posterior distribution from the model fit.
371+
"""
372+
Plot the posterior predictive distribution from the model fit.
373+
374+
This function creates a visualization of the model's posterior predictive distribution,
375+
allowing for comparison with observed data. It can include highest density intervals (HDI),
376+
mean predictions, and a gradient representation of the full distribution.
372377
373378
Parameters
374379
----------
375380
original_scale : bool, optional
376-
Whether to plot in the original scale.
381+
If True, plot in the original scale of the target variable.
382+
If False, plot in the transformed scale used for modeling. Default is False.
383+
add_hdi : bool, optional
384+
If True, add highest density intervals to the plot. Default is True.
385+
add_mean : bool, optional
386+
If True, add the mean prediction to the plot. Default is True.
387+
add_gradient : bool, optional
388+
If True, add a gradient representation of the full posterior distribution. Default is False.
377389
ax : plt.Axes, optional
378-
Matplotlib axis object.
379-
**plt_kwargs
380-
Keyword arguments passed to `plt.subplots`.
390+
A matplotlib Axes object to plot on. If None, a new figure and axes will be created.
391+
**plt_kwargs : dict
392+
Additional keyword arguments to pass to plt.subplots() when creating a new figure.
381393
382394
Returns
383395
-------
384396
plt.Figure
397+
The matplotlib Figure object containing the plot.
385398
399+
Raises
400+
------
401+
ValueError
402+
If the length of the target variable doesn't match the length
403+
of the date column in the posterior predictive data.
404+
405+
Notes
406+
-----
407+
This function visualizes the model's predictions against the observed data.
408+
The observed data is always plotted as a black line.
409+
Depending on the parameters, it can also show:
410+
- HDI (Highest Density Intervals) at 94% and 50% levels
411+
- Mean prediction line
412+
- Gradient representation of the full posterior distribution
413+
414+
If predicting out-of-sample, ensure that `self.y` is overwritten with the
415+
corresponding non-transformed target variable.
386416
"""
387417
posterior_predictive_data: Dataset = self._get_posterior_predictive_data(
388418
original_scale=original_scale

0 commit comments

Comments
 (0)