diff --git a/pymc_marketing/mmm/base.py b/pymc_marketing/mmm/base.py index a690099ef..c71294a02 100644 --- a/pymc_marketing/mmm/base.py +++ b/pymc_marketing/mmm/base.py @@ -1001,6 +1001,34 @@ def plot_direct_contribution_curves( fig.suptitle("Direct response curves", fontsize=16) return fig + def _get_distribution(self, dist: Dict) -> Callable: + """ + Retrieve a PyMC distribution callable based on the provided dictionary. + + Parameters + ---------- + dist : Dict + A dictionary containing the key 'dist' which should correspond to the + name of a PyMC distribution. + + Returns + ------- + Callable + A PyMC distribution callable that can be used to instantiate a random + variable. + + Raises + ------ + ValueError + If the specified distribution name in the dictionary does not correspond + to any distribution in PyMC. + """ + try: + prior_distribution = getattr(pm, dist["dist"]) + except AttributeError: + raise ValueError(f"Distribution {dist['dist']} does not exist in PyMC") + return prior_distribution + def compute_mean_contributions_over_time( self, original_scale: bool = False ) -> pd.DataFrame: diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index 4d9c14ba9..8add01243 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -9,6 +9,8 @@ import pandas as pd import pymc as pm import seaborn as sns +from pytensor.compile.sharedvalue import SharedVariable +from pytensor.tensor import TensorVariable from xarray import DataArray from pymc_marketing.mmm.base import MMM @@ -142,6 +144,105 @@ def _save_input_params(self, idata) -> None: idata.attrs["validate_data"] = json.dumps(self.validate_data) idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality) + def _create_likelihood_distribution( + self, + dist: Dict, + mu: SharedVariable, + observed: Union[np.ndarray, pd.Series], + dims: str, + ) -> TensorVariable: + """ + Create and return a likelihood distribution for the model. + + This method prepares the distribution and its parameters as specified in the + configuration dictionary, validates them, and constructs the likelihood + distribution using PyMC. + + Parameters + ---------- + dist : Dict + A configuration dictionary that must contain a 'dist' key with the name of + the distribution and a 'kwargs' key with parameters for the distribution. + observed : Union[np.ndarray, pd.Series] + The observed data to which the likelihood distribution will be fitted. + dims : str + The dimensions of the data. + + Returns + ------- + TensorVariable + The likelihood distribution constructed with PyMC. + + Raises + ------ + ValueError + If 'kwargs' key is missing in `dist`, or the parameter configuration does + not contain 'dist' and 'kwargs' keys, or if 'mu' is present in the nested + 'kwargs' + """ + + allowed_distributions = [ + "Normal", + "StudentT", + "Laplace", + "Logistic", + "LogNormal", + "Wald", + "TruncatedNormal", + "Gamma", + "AsymmetricLaplace", + "VonMises", + ] + + if dist["dist"] not in allowed_distributions: + raise ValueError( + f"The distribution used for the likelihood is not allowed. Please, use one of the following distributions: {allowed_distributions}." + ) + + # Validate that 'kwargs' is present and is a dictionary + if "kwargs" not in dist or not isinstance(dist["kwargs"], dict): + raise ValueError( + "The 'kwargs' key must be present in the 'dist' dictionary and be a dictionary itself." + ) + + if "mu" in dist["kwargs"]: + raise ValueError( + "The 'mu' key is not allowed directly within 'kwargs' of the main distribution as it is reserved." + ) + + parameter_distributions = {} + for param, param_config in dist["kwargs"].items(): + # Check if param_config is a dictionary with a 'dist' key + if isinstance(param_config, dict) and "dist" in param_config: + # Prepare nested distribution + if "kwargs" not in param_config: + raise ValueError( + f"The parameter configuration for '{param}' must contain 'kwargs'." + ) + + parameter_distributions[param] = self._get_distribution( + dist=param_config + )(**param_config["kwargs"], name=f"likelihood_{param}") + elif isinstance(param_config, (int, float)): + # Use the value directly + parameter_distributions[param] = param_config + else: + raise ValueError( + f"Invalid parameter configuration for '{param}'. It must be either a dictionary with a 'dist' key or a numeric value." + ) + + # Extract the likelihood distribution name and instantiate it + likelihood_dist_name = dist["dist"] + likelihood_dist = self._get_distribution(dist={"dist": likelihood_dist_name}) + + return likelihood_dist( + name="likelihood", + mu=mu, + observed=observed, + dims=dims, + **parameter_distributions, + ) + def build_model( self, X: pd.DataFrame, @@ -171,13 +272,55 @@ def build_model( --------------- model : pm.Model The PyMC model object containing all the defined stochastic and deterministic variables. + + Examples + -------- + custom_config = { + 'intercept': {'dist': 'Normal', 'kwargs': {'mu': 0, 'sigma': 2}}, + 'beta_channel': {'dist': 'LogNormal', 'kwargs': {'mu': 1, 'sigma': 3}}, + 'alpha': {'dist': 'Beta', 'kwargs': {'alpha': 1, 'beta': 3}}, + 'lam': {'dist': 'Gamma', 'kwargs': {'alpha': 3, 'beta': 1}}, + 'likelihood': {'dist': 'Normal', + 'kwargs': {'sigma': {'dist': 'HalfNormal', 'kwargs': {'sigma': 2}}} + }, + 'gamma_control': {'dist': 'Normal', 'kwargs': {'mu': 0, 'sigma': 2}}, + 'gamma_fourier': {'dist': 'Laplace', 'kwargs': {'mu': 0, 'b': 1}} + } + + model = DelayedSaturatedMMM( + date_column="date_week", + channel_columns=["x1", "x2"], + control_columns=[ + "event_1", + "event_2", + "t", + ], + adstock_max_lag=8, + yearly_seasonality=2, + model_config=custom_config, + ) """ - model_config = self.model_config + + self.intercept_dist = self._get_distribution( + dist=self.model_config["intercept"] + ) + self.beta_channel_dist = self._get_distribution( + dist=self.model_config["beta_channel"] + ) + self.lam_dist = self._get_distribution(dist=self.model_config["lam"]) + self.alpha_dist = self._get_distribution(dist=self.model_config["alpha"]) + self.gamma_control_dist = self._get_distribution( + dist=self.model_config["gamma_control"] + ) + self.gamma_fourier_dist = self._get_distribution( + dist=self.model_config["gamma_fourier"] + ) + self._generate_and_preprocess_model_data(X, y) with pm.Model(coords=self.model_coords) as self.model: channel_data_ = pm.MutableData( name="channel_data", - value=self.preprocessed_data["X"][self.channel_columns].to_numpy(), + value=self.preprocessed_data["X"][self.channel_columns], dims=("date", "channel"), ) @@ -187,33 +330,26 @@ def build_model( dims="date", ) - intercept = pm.Normal( - name="intercept", - mu=model_config["intercept"]["mu"], - sigma=model_config["intercept"]["sigma"], + intercept = self.intercept_dist( + name="intercept", **self.model_config["intercept"]["kwargs"] ) - beta_channel = pm.HalfNormal( + beta_channel = self.beta_channel_dist( name="beta_channel", - sigma=model_config["beta_channel"]["sigma"], - dims=model_config["beta_channel"]["dims"], + **self.model_config["beta_channel"]["kwargs"], + dims=("channel",), ) - alpha = pm.Beta( + alpha = self.alpha_dist( name="alpha", - alpha=model_config["alpha"]["alpha"], - beta=model_config["alpha"]["beta"], - dims=model_config["alpha"]["dims"], + dims="channel", + **self.model_config["alpha"]["kwargs"], ) - - lam = pm.Gamma( + lam = self.lam_dist( name="lam", - alpha=model_config["lam"]["alpha"], - beta=model_config["lam"]["beta"], - dims=model_config["lam"]["dims"], + dims="channel", + **self.model_config["lam"]["kwargs"], ) - sigma = pm.HalfNormal(name="sigma", sigma=model_config["sigma"]["sigma"]) - channel_adstock = pm.Deterministic( name="channel_adstock", var=geometric_adstock( @@ -245,19 +381,18 @@ def build_model( for column in self.control_columns ) ): + gamma_control = self.gamma_control_dist( + name="gamma_control", + dims="control", + **self.model_config["gamma_control"]["kwargs"], + ) + control_data_ = pm.MutableData( name="control_data", value=self.preprocessed_data["X"][self.control_columns], dims=("date", "control"), ) - gamma_control = pm.Normal( - name="gamma_control", - mu=model_config["gamma_control"]["mu"], - sigma=model_config["gamma_control"]["sigma"], - dims=model_config["gamma_control"]["dims"], - ) - control_contributions = pm.Deterministic( name="control_contributions", var=control_data_ * gamma_control, @@ -280,11 +415,10 @@ def build_model( dims=("date", "fourier_mode"), ) - gamma_fourier = pm.Laplace( + gamma_fourier = self.gamma_fourier_dist( name="gamma_fourier", - mu=model_config["gamma_fourier"]["mu"], - b=model_config["gamma_fourier"]["b"], - dims=model_config["gamma_fourier"]["dims"], + dims="fourier_mode", + **self.model_config["gamma_fourier"]["kwargs"], ) fourier_contribution = pm.Deterministic( @@ -295,36 +429,31 @@ def build_model( mu_var += fourier_contribution.sum(axis=-1) - mu = pm.Deterministic( - name="mu", var=mu_var, dims=model_config["mu"]["dims"] - ) + mu = pm.Deterministic(name="mu", var=mu_var, dims="date") - pm.Normal( - name="likelihood", + self._create_likelihood_distribution( + dist=self.model_config["likelihood"], mu=mu, - sigma=sigma, observed=target_, - dims=model_config["likelihood"]["dims"], + dims="date", ) @property def default_model_config(self) -> Dict: - model_config: Dict = { - "intercept": {"mu": 0, "sigma": 2}, - "beta_channel": {"sigma": 2, "dims": ("channel",)}, - "alpha": {"alpha": 1, "beta": 3, "dims": ("channel",)}, - "lam": {"alpha": 3, "beta": 1, "dims": ("channel",)}, - "sigma": {"sigma": 2}, - "gamma_control": { - "mu": 0, - "sigma": 2, - "dims": ("control",), + return { + "intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}}, + "beta_channel": {"dist": "HalfNormal", "kwargs": {"sigma": 2}}, + "alpha": {"dist": "Beta", "kwargs": {"alpha": 1, "beta": 3}}, + "lam": {"dist": "Gamma", "kwargs": {"alpha": 3, "beta": 1}}, + "likelihood": { + "dist": "Normal", + "kwargs": { + "sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}}, + }, }, - "mu": {"dims": ("date",)}, - "likelihood": {"dims": ("date",)}, - "gamma_fourier": {"mu": 0, "b": 1, "dims": "fourier_mode"}, + "gamma_control": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}}, + "gamma_fourier": {"dist": "Laplace", "kwargs": {"mu": 0, "b": 1}}, } - return model_config def _get_fourier_models_data(self, X) -> pd.DataFrame: """Generates fourier modes to model seasonality. diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py index f3d8dbdc7..3abb8749e 100644 --- a/tests/mmm/test_delayed_saturated_mmm.py +++ b/tests/mmm/test_delayed_saturated_mmm.py @@ -1,5 +1,5 @@ import os -from typing import List, Optional +from typing import List, Optional, Dict import arviz as az import numpy as np @@ -39,28 +39,35 @@ def toy_X() -> pd.DataFrame: @pytest.fixture(scope="class") -def model_config_requiring_serialization() -> dict: +def model_config_requiring_serialization() -> Dict: model_config = { - "intercept": {"mu": 0, "sigma": 2}, + "intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}}, "beta_channel": { - "sigma": np.array([0.4533017, 0.25488063]), - "dims": ("channel",), + "dist": "HalfNormal", + "kwargs": {"sigma": np.array([0.4533017, 0.25488063])}, }, "alpha": { - "alpha": np.array([3, 3]), - "beta": np.array([3.55001301, 2.87092431]), - "dims": ("channel",), + "dist": "Beta", + "kwargs": { + "alpha": np.array([3, 3]), + "beta": np.array([3.55001301, 2.87092431]), + }, }, "lam": { - "alpha": np.array([3, 3]), - "beta": np.array([4.12231653, 5.02896872]), - "dims": ("channel",), + "dist": "Gamma", + "kwargs": { + "alpha": np.array([3, 3]), + "beta": np.array([4.12231653, 5.02896872]), + }, }, - "sigma": {"sigma": 2}, - "gamma_control": {"mu": 0, "sigma": 2, "dims": ("control",)}, - "mu": {"dims": ("date",)}, - "likelihood": {"dims": ("date",)}, - "gamma_fourier": {"mu": 0, "b": 1, "dims": "fourier_mode"}, + "likelihood": { + "dist": "Normal", + "kwargs": { + "sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}}, + }, + }, + "gamma_control": {"dist": "HalfNormal", "kwargs": {"mu": 0, "sigma": 2}}, + "gamma_fourier": {"dist": "HalfNormal", "kwargs": {"mu": 0, "b": 1}}, } return model_config @@ -506,3 +513,115 @@ def mock_property(self): ): DelayedSaturatedMMM.load("test_model") os.remove("test_model") + + @pytest.mark.parametrize( + argnames="model_config", + argvalues=[ + None, + { + "intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}}, + "beta_channel": { + "dist": "HalfNormal", + "kwargs": {"sigma": np.array([0.4533017, 0.25488063])}, + }, + "alpha": { + "dist": "Beta", + "kwargs": { + "alpha": np.array([3, 3]), + "beta": np.array([3.55001301, 2.87092431]), + }, + }, + "lam": { + "dist": "Gamma", + "kwargs": { + "alpha": np.array([3, 3]), + "beta": np.array([4.12231653, 5.02896872]), + }, + }, + "likelihood": { + "dist": "StudentT", + "kwargs": {"nu": 3, "sigma": 2}, + }, + "gamma_control": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}}, + "gamma_fourier": {"dist": "Laplace", "kwargs": {"mu": 0, "b": 1}}, + }, + ], + ids=["default_config", "custom_config"], + ) + def test_model_config( + self, model_config: Dict, toy_X: pd.DataFrame, toy_y: pd.Series + ): + # Create model instance with specified config + model = DelayedSaturatedMMM( + date_column="date", + channel_columns=["channel_1", "channel_2"], + adstock_max_lag=2, + yearly_seasonality=2, + model_config=model_config, + ) + + model.build_model(X=toy_X, y=toy_y.to_numpy()) + # Check for default configuration + if model_config is None: + # assert observed RV type, and priors of some/all free_RVs. + assert isinstance( + model.model.observed_RVs[0].owner.op, pm.Normal + ) # likelihood + # Add more asserts as needed for default configuration + + # Check for custom configuration + else: + # assert custom configuration is applied correctly + assert isinstance( + model.model.observed_RVs[0].owner.op, pm.StudentT + ) # likelihood + assert isinstance( + model.model["beta_channel"].owner.op, pm.HalfNormal + ) # beta_channel + + +def test_get_valid_distribution(mmm): + normal_dist = mmm._get_distribution({"dist": "Normal"}) + assert normal_dist is pm.Normal + + +def test_get_invalid_distribution(mmm): + with pytest.raises(ValueError, match="does not exist in PyMC"): + mmm._get_distribution({"dist": "NonExistentDist"}) + + +def test_invalid_likelihood_type(mmm): + with pytest.raises( + ValueError, + match="The distribution used for the likelihood is not allowed", + ): + mmm._create_likelihood_distribution( + dist={"dist": "Cauchy", "kwargs": {"alpha": 2, "beta": 4}}, + mu=np.array([0]), + observed=np.random.randn(100), + dims="obs_dim", + ) + + +def test_create_likelihood_invalid_kwargs_structure(mmm): + with pytest.raises( + ValueError, match="either a dictionary with a 'dist' key or a numeric value" + ): + mmm._create_likelihood_distribution( + dist={"dist": "Normal", "kwargs": {"sigma": "not a dictionary or numeric"}}, + mu=np.array([0]), + observed=np.random.randn(100), + dims="obs_dim", + ) + + +def test_create_likelihood_mu_in_top_level_kwargs(mmm): + with pytest.raises( + ValueError, match="'mu' key is not allowed directly within 'kwargs'" + ): + mmm._create_likelihood_distribution( + dist={"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}}, + mu=np.array([0]), + observed=np.random.randn(100), + dims="obs_dim", + )