Skip to content

Commit 9362709

Browse files
authored
Add user customization to plot_curve methods (#1018)
1 parent 40dee1d commit 9362709

File tree

5 files changed

+588
-135
lines changed

5 files changed

+588
-135
lines changed

pymc_marketing/mmm/components/base.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,23 @@
2121
"""
2222

2323
import warnings
24+
from collections.abc import Iterable
2425
from copy import deepcopy
2526
from inspect import signature
2627
from typing import Any
2728

28-
import matplotlib.pyplot as plt
2929
import numpy as np
3030
import numpy.typing as npt
3131
import pymc as pm
3232
import xarray as xr
33+
from matplotlib.axes import Axes
34+
from matplotlib.figure import Figure
3335
from pymc.distributions.shape_utils import Dims
3436
from pytensor import tensor as pt
37+
from pytensor.tensor.variable import TensorVariable
3538

3639
from pymc_marketing.mmm.plot import (
40+
SelToString,
3741
plot_curve,
3842
plot_hdi,
3943
plot_samples,
@@ -299,12 +303,10 @@ def variable_mapping(self) -> dict[str, str]:
299303

300304
def _create_distributions(
301305
self, dims: Dims | None = None
302-
) -> dict[str, pt.TensorVariable]:
306+
) -> dict[str, TensorVariable]:
303307
dim_handler: DimHandler = create_dim_handler(dims)
304308

305-
def create_variable(
306-
parameter_name: str, variable_name: str
307-
) -> pt.TensorVariable:
309+
def create_variable(parameter_name: str, variable_name: str) -> TensorVariable:
308310
dist = self.function_priors[parameter_name]
309311
var = dist.create_variable(variable_name)
310312
return dim_handler(var, dist.dims)
@@ -344,7 +346,12 @@ def plot_curve(
344346
subplot_kwargs: dict | None = None,
345347
sample_kwargs: dict | None = None,
346348
hdi_kwargs: dict | None = None,
347-
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
349+
axes: npt.NDArray[Axes] | None = None,
350+
same_axes: bool = False,
351+
colors: Iterable[str] | None = None,
352+
legend: bool | None = None,
353+
sel_to_string: SelToString | None = None,
354+
) -> tuple[Figure, npt.NDArray[Axes]]:
348355
"""Plot curve HDI and samples.
349356
350357
Parameters
@@ -357,6 +364,16 @@ def plot_curve(
357364
Keyword arguments for the plot_curve_sample function. Defaults to None.
358365
hdi_kwargs : dict, optional
359366
Keyword arguments for the plot_curve_hdi function. Defaults to None.
367+
axes : npt.NDArray[plt.Axes], optional
368+
The exact axes to plot on. Overrides any subplot_kwargs
369+
same_axes : bool, optional
370+
If the axes should be the same for all plots. Defaults to False.
371+
colors : Iterable[str], optional
372+
The colors to use for the plot. Defaults to None.
373+
legend : bool, optional
374+
If the legend should be shown. Defaults to None.
375+
sel_to_string : SelToString, optional
376+
The function to convert the selection to a string. Defaults to None.
360377
361378
Returns
362379
-------
@@ -369,6 +386,11 @@ def plot_curve(
369386
subplot_kwargs=subplot_kwargs,
370387
sample_kwargs=sample_kwargs,
371388
hdi_kwargs=hdi_kwargs,
389+
axes=axes,
390+
same_axes=same_axes,
391+
colors=colors,
392+
legend=legend,
393+
sel_to_string=sel_to_string,
372394
)
373395

374396
def _sample_curve(
@@ -424,8 +446,8 @@ def plot_curve_samples(
424446
rng: np.random.Generator | None = None,
425447
plot_kwargs: dict | None = None,
426448
subplot_kwargs: dict | None = None,
427-
axes: npt.NDArray[plt.Axes] | None = None,
428-
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
449+
axes: npt.NDArray[Axes] | None = None,
450+
) -> tuple[Figure, npt.NDArray[Axes]]:
429451
"""Plot samples from the curve.
430452
431453
Parameters
@@ -466,8 +488,8 @@ def plot_curve_hdi(
466488
hdi_kwargs: dict | None = None,
467489
plot_kwargs: dict | None = None,
468490
subplot_kwargs: dict | None = None,
469-
axes: npt.NDArray[plt.Axes] | None = None,
470-
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
491+
axes: npt.NDArray[Axes] | None = None,
492+
) -> tuple[Figure, npt.NDArray[Axes]]:
471493
"""Plot the HDI of the curve.
472494
473495
Parameters
@@ -494,9 +516,10 @@ def plot_curve_hdi(
494516
axes=axes,
495517
subplot_kwargs=subplot_kwargs,
496518
plot_kwargs=plot_kwargs,
519+
hdi_kwargs=hdi_kwargs,
497520
)
498521

499-
def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> pt.TensorVariable:
522+
def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> TensorVariable:
500523
"""Call within a model context.
501524
502525
Used internally of the MMM to apply the transformation to the data.

pymc_marketing/mmm/fourier.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@
205205
206206
"""
207207

208-
from collections.abc import Callable
208+
from collections.abc import Callable, Iterable
209209
from typing import Any
210210

211211
import arviz as az
@@ -219,7 +219,7 @@
219219
from typing_extensions import Self
220220

221221
from pymc_marketing.constants import DAYS_IN_MONTH, DAYS_IN_YEAR
222-
from pymc_marketing.mmm.plot import plot_curve, plot_hdi, plot_samples
222+
from pymc_marketing.mmm.plot import SelToString, plot_curve, plot_hdi, plot_samples
223223
from pymc_marketing.prior import Prior, create_dim_handler
224224

225225
X_NAME: str = "day"
@@ -465,6 +465,11 @@ def plot_curve(
465465
subplot_kwargs: dict | None = None,
466466
sample_kwargs: dict | None = None,
467467
hdi_kwargs: dict | None = None,
468+
axes: npt.NDArray[plt.Axes] | None = None,
469+
same_axes: bool = False,
470+
colors: Iterable[str] | None = None,
471+
legend: bool | None = None,
472+
sel_to_string: SelToString | None = None,
468473
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
469474
"""Plot the seasonality for one full period.
470475
@@ -478,6 +483,16 @@ def plot_curve(
478483
Keyword arguments for the plot_full_period_samples method, by default None
479484
hdi_kwargs : dict, optional
480485
Keyword arguments for the plot_full_period_hdi method, by default None
486+
axes : npt.NDArray[plt.Axes], optional
487+
Matplotlib axes, by default None
488+
same_axes : bool, optional
489+
Use the same axes for all plots, by default False
490+
colors : Iterable[str], optional
491+
Colors for the different plots, by default None
492+
legend : bool, optional
493+
Show the legend, by default None
494+
sel_to_string : SelToString, optional
495+
Function to convert the selection to a string, by default None
481496
482497
Returns
483498
-------
@@ -491,6 +506,11 @@ def plot_curve(
491506
subplot_kwargs=subplot_kwargs,
492507
sample_kwargs=sample_kwargs,
493508
hdi_kwargs=hdi_kwargs,
509+
axes=axes,
510+
same_axes=same_axes,
511+
colors=colors,
512+
legend=legend,
513+
sel_to_string=sel_to_string,
494514
)
495515

496516
def plot_curve_hdi(
@@ -596,9 +616,9 @@ class YearlyFourier(FourierBase):
596616
dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
597617
yearly = YearlyFourier(n_order=2, prior=dist)
598618
prior = yearly.sample_prior(random_seed=rng)
599-
curve = yearly.sample_full_period(prior)
619+
curve = yearly.sample_curve(prior)
600620
601-
_, axes = yearly.plot_full_period(curve)
621+
_, axes = yearly.plot_curve(curve)
602622
axes[0].set(title="Yearly Fourier Seasonality")
603623
plt.show()
604624
@@ -643,9 +663,9 @@ class MonthlyFourier(FourierBase):
643663
dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
644664
yearly = MonthlyFourier(n_order=2, prior=dist)
645665
prior = yearly.sample_prior(samples=100)
646-
curve = yearly.sample_full_period(prior)
666+
curve = yearly.sample_curve(prior)
647667
648-
_, axes = yearly.plot_full_period(curve)
668+
_, axes = yearly.plot_curve(curve)
649669
axes[0].set(title="Monthly Fourier Seasonality")
650670
plt.show()
651671

pymc_marketing/mmm/linear_trend.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,22 @@
5252
5353
"""
5454

55+
from collections.abc import Iterable
5556
from typing import Any, cast
5657

57-
import matplotlib.pyplot as plt
5858
import numpy as np
5959
import numpy.typing as npt
6060
import pymc as pm
6161
import pytensor.tensor as pt
6262
import xarray as xr
63+
from matplotlib.axes import Axes
64+
from matplotlib.figure import Figure
6365
from pydantic import BaseModel, Field, InstanceOf, model_validator
6466
from pymc.distributions.shape_utils import Dims
67+
from pytensor.tensor.variable import TensorVariable
6568
from typing_extensions import Self
6669

67-
from pymc_marketing.mmm.plot import plot_curve
70+
from pymc_marketing.mmm.plot import SelToString, plot_curve
6871
from pymc_marketing.prior import Prior, create_dim_handler
6972

7073

@@ -278,7 +281,7 @@ def default_priors(self) -> dict[str, Prior]:
278281

279282
return priors
280283

281-
def apply(self, t: pt.TensorLike) -> pt.TensorVariable:
284+
def apply(self, t: pt.TensorLike) -> TensorVariable:
282285
"""Create the linear trend for the given x values.
283286
284287
Parameters
@@ -409,7 +412,12 @@ def plot_curve(
409412
sample_kwargs: dict | None = None,
410413
hdi_kwargs: dict | None = None,
411414
include_changepoints: bool = True,
412-
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
415+
axes: npt.NDArray[Axes] | None = None,
416+
same_axes: bool = False,
417+
colors: Iterable[str] | None = None,
418+
legend: bool | None = None,
419+
sel_to_string: SelToString | None = None,
420+
) -> tuple[Figure, npt.NDArray[Axes]]:
413421
"""Plot the curve samples from the trend.
414422
415423
Parameters
@@ -424,6 +432,16 @@ def plot_curve(
424432
Keyword arguments for the HDI, by default None.
425433
include_changepoints : bool, optional
426434
Include the change points in the plot, by default True.
435+
axes : npt.NDArray[plt.Axes], optional
436+
Axes to plot the curve, by default None.
437+
same_axes : bool, optional
438+
Use the same axes for the samples, by default False.
439+
colors : Iterable[str], optional
440+
Colors for the samples, by default None.
441+
legend : bool, optional
442+
Include a legend in the plot, by default None.
443+
sel_to_string : SelToString, optional
444+
Function to convert the selection to a string, by default None.
427445
428446
Returns
429447
-------
@@ -437,6 +455,11 @@ def plot_curve(
437455
subplot_kwargs=subplot_kwargs,
438456
sample_kwargs=sample_kwargs,
439457
hdi_kwargs=hdi_kwargs,
458+
axes=axes,
459+
same_axes=same_axes,
460+
colors=colors,
461+
legend=legend,
462+
sel_to_string=sel_to_string,
440463
)
441464

442465
if not include_changepoints:

0 commit comments

Comments
 (0)