diff --git a/pymc_marketing/clv/models/basic.py b/pymc_marketing/clv/models/basic.py
index 9f43148ee..07070bdf8 100644
--- a/pymc_marketing/clv/models/basic.py
+++ b/pymc_marketing/clv/models/basic.py
@@ -1,17 +1,14 @@
 import json
-import types
 import warnings
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Union
 
 import arviz as az
 import numpy as np
 import pandas as pd
 import pymc as pm
-from pymc import str_for_dist
 from pymc.backends import NDArray
 from pymc.backends.base import MultiTrace
-from pytensor.tensor import TensorVariable
 from xarray import Dataset
 
 from pymc_marketing.model_builder import ModelBuilder
@@ -192,34 +189,6 @@ def load(cls, fname: str):
 
         return model
 
-    @staticmethod
-    def _check_prior_ndim(prior, ndim: int = 0):
-        if prior.ndim != ndim:
-            raise ValueError(
-                f"Prior variable {prior} must be have {ndim} ndims, but it has {prior.ndim} ndims."
-            )
-
-    @staticmethod
-    def _create_distribution(dist: Dict, ndim: int = 0) -> TensorVariable:
-        try:
-            prior_distribution = getattr(pm, dist["dist"]).dist(**dist["kwargs"])
-            CLVModel._check_prior_ndim(prior_distribution, ndim)
-        except AttributeError:
-            raise ValueError(f"Distribution {dist['dist']} does not exist in PyMC")
-        return prior_distribution
-
-    @staticmethod
-    def _process_priors(
-        *priors: TensorVariable, check_ndim: bool = True
-    ) -> Tuple[TensorVariable, ...]:
-        """Check that each prior variable is unique and attach `str_repr` method."""
-        if len(priors) != len(set(priors)):
-            raise ValueError("Prior variables must be unique")
-        # Related to https://github.com/pymc-devs/pymc/issues/6311
-        for prior in priors:
-            prior.str_repr = types.MethodType(str_for_dist, prior)  # type: ignore
-        return priors
-
     @property
     def default_sampler_config(self) -> Dict:
         return {}
diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py
index 2ff94d71e..c9748efcd 100644
--- a/pymc_marketing/mmm/delayed_saturated_mmm.py
+++ b/pymc_marketing/mmm/delayed_saturated_mmm.py
@@ -57,6 +57,49 @@ def __init__(
         yearly_seasonality : Optional[int], optional
             Number of Fourier modes to model yearly seasonality, by default None.
 
+        Examples
+        --------
+        DelayedSaturatedMMM
+
+        .. code-block:: python
+
+            import pandas as pd
+            from pymc_marketing.mmm import DelayedSaturatedMMM
+
+            data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/datasets/mmm_example.csv"
+            data = pd.read_csv(data_url, parse_dates=['date_week'])
+
+            model = DelayedSaturatedMMM(
+                date_column="date_week",
+                channel_columns=["x1", "x2"],
+                control_columns=[
+                    "event_1",
+                    "event_2",
+                    "t",
+                ],
+                adstock_max_lag=8,
+                yearly_seasonality=2,
+                model_config={
+                    # priors
+                    "intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
+                    "beta_channel": {"dist": "HalfNormal", "kwargs": {"sigma": 2}, "dims": ("channel",)},
+                    "alpha": {"dist": "Beta", "kwargs": {"alpha": 1, "beta": 3}, "dims": ("channel",)},
+                    "lam": {"dist": "Gamma", "kwargs": {"alpha": 3, "beta": 1}, "dims": ("channel",)},
+                    "sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
+                    "gamma_control": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}, "dims": ("control",)},
+                    "gamma_fourier": {"dist": "Laplace", "kwargs": {"mu": 0, "b": 1}, "dims": "fourier_mode"},
+                    # params
+                    "mu": {"dims": ("date",)},
+                    "likelihood": {"dims": ("date",)},
+                },
+            )
+
+            X = data.drop('y', axis=1)
+            y = data['y']
+
+            model.fit(X, y)
+            model.plot_components_contributions();
+
         References
         ----------
         .. [1] Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017).
@@ -75,6 +118,28 @@ def __init__(
             adstock_max_lag=adstock_max_lag,
         )
 
+        # Define custom priors
+        self.intercept = self._create_distribution(self.model_config["intercept"])
+        self.beta_channel = self._create_distribution(self.model_config["beta_channel"])
+        self.alpha = self._create_distribution(self.model_config["alpha"])
+        self.lam = self._create_distribution(self.model_config["lam"])
+        self.sigma = self._create_distribution(self.model_config["sigma"])
+        self.gamma_control = self._create_distribution(
+            self.model_config["gamma_control"]
+        )
+        self.gamma_fourier = self._create_distribution(
+            self.model_config["gamma_fourier"]
+        )
+        self._process_priors(
+            self.intercept,
+            self.beta_channel,
+            self.alpha,
+            self.lam,
+            self.sigma,
+            self.gamma_control,
+            self.gamma_fourier,
+        )
+
     @property
     def default_sampler_config(self) -> Dict:
         return {}
@@ -174,6 +239,7 @@ def build_model(
         """
         model_config = self.model_config
         self._generate_and_preprocess_model_data(X, y)
+
         with pm.Model(coords=self.model_coords) as self.model:
             channel_data_ = pm.MutableData(
                 name="channel_data",
@@ -187,33 +253,16 @@ def build_model(
                 dims="date",
             )
 
-            intercept = pm.Normal(
-                name="intercept",
-                mu=model_config["intercept"]["mu"],
-                sigma=model_config["intercept"]["sigma"],
-            )
-
-            beta_channel = pm.HalfNormal(
-                name="beta_channel",
-                sigma=model_config["beta_channel"]["sigma"],
-                dims=model_config["beta_channel"]["dims"],
-            )
-            alpha = pm.Beta(
-                name="alpha",
-                alpha=model_config["alpha"]["alpha"],
-                beta=model_config["alpha"]["beta"],
-                dims=model_config["alpha"]["dims"],
-            )
-
-            lam = pm.Gamma(
-                name="lam",
-                alpha=model_config["lam"]["alpha"],
-                beta=model_config["lam"]["beta"],
-                dims=model_config["lam"]["dims"],
+            # FIXME: Need to add the correct dims to `beta_channel`, `alpha`, `lam`,
+            intercept = self.model.register_rv(self.intercept, name="intercept")
+            beta_channel = self.model.register_rv(
+                self.beta_channel, name="beta_channel"
             )
+            alpha = self.model.register_rv(self.alpha, name="alpha")
+            lam = self.model.register_rv(self.lam, name="lam")
+            sigma = self.model.register_rv(self.sigma, name="sigma")
 
-            sigma = pm.HalfNormal(name="sigma", sigma=model_config["sigma"]["sigma"])
-
+            # TODO: register the adstock transforms
             channel_adstock = pm.Deterministic(
                 name="channel_adstock",
                 var=geometric_adstock(
@@ -245,19 +294,16 @@ def build_model(
                     for column in self.control_columns
                 )
             ):
+                gamma_control = self.model.register_rv(
+                    self.gamma_control, name="gamma_control"
+                )
+
                 control_data_ = pm.MutableData(
                     name="control_data",
                     value=self.preprocessed_data["X"][self.control_columns],
                     dims=("date", "control"),
                 )
 
-                gamma_control = pm.Normal(
-                    name="gamma_control",
-                    mu=model_config["gamma_control"]["mu"],
-                    sigma=model_config["gamma_control"]["sigma"],
-                    dims=model_config["gamma_control"]["dims"],
-                )
-
                 control_contributions = pm.Deterministic(
                     name="control_contributions",
                     var=control_data_ * gamma_control,
@@ -274,19 +320,16 @@ def build_model(
                     for column in self.fourier_columns
                 )
             ):
+                gamma_fourier = self.model.register_rv(
+                    self.gamma_fourier, name="gamma_fourier"
+                )
+
                 fourier_data_ = pm.MutableData(
                     name="fourier_data",
                     value=self.preprocessed_data["X"][self.fourier_columns],
                     dims=("date", "fourier_mode"),
                 )
 
-                gamma_fourier = pm.Laplace(
-                    name="gamma_fourier",
-                    mu=model_config["gamma_fourier"]["mu"],
-                    b=model_config["gamma_fourier"]["b"],
-                    dims=model_config["gamma_fourier"]["dims"],
-                )
-
                 fourier_contribution = pm.Deterministic(
                     name="fourier_contributions",
                     var=fourier_data_ * gamma_fourier,
@@ -308,23 +351,41 @@ def build_model(
             )
 
     @property
-    def default_model_config(self) -> Dict:
-        model_config: Dict = {
-            "intercept": {"mu": 0, "sigma": 2},
-            "beta_channel": {"sigma": 2, "dims": ("channel",)},
-            "alpha": {"alpha": 1, "beta": 3, "dims": ("channel",)},
-            "lam": {"alpha": 3, "beta": 1, "dims": ("channel",)},
-            "sigma": {"sigma": 2},
+    def default_model_config(self) -> Dict[str, Dict]:
+        return {
+            # Prior
+            "intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
+            "beta_channel": {
+                "dist": "HalfNormal",
+                "kwargs": {"sigma": 2},
+                "dims": ("channel",),
+            },
+            "alpha": {
+                "dist": "Beta",
+                "kwargs": {"alpha": 1, "beta": 3},
+                "dims": ("channel",),
+            },
+            "lam": {
+                "dist": "Gamma",
+                "kwargs": {"alpha": 3, "beta": 1},
+                "dims": ("channel",),
+            },
+            "sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
             "gamma_control": {
-                "mu": 0,
-                "sigma": 2,
+                "dist": "Normal",
+                "kwargs": {"mu": 0, "sigma": 2},
                 "dims": ("control",),
             },
+            "gamma_fourier": {
+                "dist": "Laplace",
+                "kwargs": {"mu": 0, "b": 1},
+                "dims": "fourier_mode",
+            },
+            # Deterministic
             "mu": {"dims": ("date",)},
+            # Likelihood
             "likelihood": {"dims": ("date",)},
-            "gamma_fourier": {"mu": 0, "b": 1, "dims": "fourier_mode"},
         }
-        return model_config
 
     def _get_fourier_models_data(self, X) -> pd.DataFrame:
         """Generates fourier modes to model seasonality.
diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py
index b846d7fe4..1a0346113 100644
--- a/pymc_marketing/model_builder.py
+++ b/pymc_marketing/model_builder.py
@@ -15,17 +15,20 @@
 
 import hashlib
 import json
+import types
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import arviz as az
 import numpy as np
 import pandas as pd
 import pymc as pm
 import xarray as xr
+from pymc import str_for_dist
 from pymc.util import RandomState
+from pytensor.tensor import TensorVariable
 
 # If scikit-learn is available, use its data validator
 try:
@@ -46,6 +49,7 @@ class ModelBuilder(ABC):
     and help with deployment.
     """
 
+    model: pm.Model
     _model_type = "BaseClass"
     version = "None"
 
@@ -414,6 +418,35 @@ def load(cls, fname: str):
 
         return model
 
+    @staticmethod
+    def _check_prior_ndim(prior, ndim: int = 0):
+        if prior.ndim != ndim:
+            raise ValueError(
+                f"Prior variable {prior} must have {ndim} ndims, but it has {prior.ndim} ndims."
+            )
+
+    @staticmethod
+    def _create_distribution(dist: Dict, ndim: int = 0) -> TensorVariable:
+        try:
+            prior_distribution = getattr(pm, dist["dist"]).dist(**dist["kwargs"])
+        except AttributeError as e:
+            raise ValueError(
+                f"Distribution {dist['dist']} does not exist in PyMC"
+            ) from e
+        return prior_distribution
+
+    @staticmethod
+    def _process_priors(
+        *priors: TensorVariable, check_ndim: bool = True
+    ) -> Tuple[TensorVariable, ...]:
+        """Check that each prior variable is unique and attach `str_repr` method."""
+        if len(priors) != len(set(priors)):
+            raise ValueError("Prior variables must be unique")
+        # Related to https://github.com/pymc-devs/pymc/issues/6311
+        for prior in priors:
+            prior.str_repr = types.MethodType(str_for_dist, prior)  # type: ignore
+        return priors
+
     def fit(
         self,
         X: pd.DataFrame,
diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py
index f3d8dbdc7..1ccdf450d 100644
--- a/tests/mmm/test_delayed_saturated_mmm.py
+++ b/tests/mmm/test_delayed_saturated_mmm.py
@@ -41,26 +41,41 @@ def toy_X() -> pd.DataFrame:
 @pytest.fixture(scope="class")
 def model_config_requiring_serialization() -> dict:
     model_config = {
-        "intercept": {"mu": 0, "sigma": 2},
+        "intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
         "beta_channel": {
-            "sigma": np.array([0.4533017, 0.25488063]),
+            "dist": "HalfNormal",
+            "kwargs": {"sigma": np.array([0.4533017, 0.25488063])},
             "dims": ("channel",),
         },
         "alpha": {
-            "alpha": np.array([3, 3]),
-            "beta": np.array([3.55001301, 2.87092431]),
+            "dist": "Beta",
+            "kwargs": {
+                "alpha": np.array([3, 3]),
+                "beta": np.array([3.55001301, 2.87092431]),
+            },
             "dims": ("channel",),
         },
         "lam": {
-            "alpha": np.array([3, 3]),
-            "beta": np.array([4.12231653, 5.02896872]),
+            "dist": "Gamma",
+            "kwargs": {
+                "alpha": np.array([3, 3]),
+                "beta": np.array([4.12231653, 5.02896872]),
+            },
             "dims": ("channel",),
         },
-        "sigma": {"sigma": 2},
-        "gamma_control": {"mu": 0, "sigma": 2, "dims": ("control",)},
+        "sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
+        "gamma_control": {
+            "dist": "Normal",
+            "kwargs": {"mu": 0, "sigma": 2},
+            "dims": ("control",),
+        },
+        "gamma_fourier": {
+            "dist": "Laplace",
+            "kwargs": {"mu": 0, "b": 1},
+            "dims": "fourier_mode",
+        },
         "mu": {"dims": ("date",)},
         "likelihood": {"dims": ("date",)},
-        "gamma_fourier": {"mu": 0, "b": 1, "dims": "fourier_mode"},
     }
     return model_config
 
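Note (a minimal sketch, not part of the patch above): the new dict-based `model_config` entries are consumed in two steps. `ModelBuilder._create_distribution` resolves a `{"dist": ..., "kwargs": ...}` spec into an unregistered PyMC distribution via `getattr(pm, ...)`, and `DelayedSaturatedMMM.build_model` then attaches it to the model with `Model.register_rv`. The `spec` dict and the `"channel"` coordinate below are illustrative stand-ins, not values taken from the diff.

.. code-block:: python

    import pymc as pm

    # One hypothetical model_config entry in the new format: distribution name + kwargs (+ optional dims).
    spec = {"dist": "HalfNormal", "kwargs": {"sigma": 2}, "dims": ("channel",)}

    # Step 1 (mirrors ModelBuilder._create_distribution): build an unregistered variable from the spec.
    try:
        prior = getattr(pm, spec["dist"]).dist(**spec["kwargs"])
    except AttributeError as e:
        raise ValueError(f"Distribution {spec['dist']} does not exist in PyMC") from e

    # Step 2 (mirrors the register_rv calls in build_model): attach the variable to a model under its name.
    # Dims handling is still an open FIXME in the patch, so none are passed here.
    model = pm.Model(coords={"channel": ["x1", "x2"]})
    beta_channel = model.register_rv(prior, name="beta_channel")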