diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index 2ff94d71e..5b49d9ca1 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -8,6 +8,7 @@ import numpy.typing as npt import pandas as pd import pymc as pm +import pytensor.tensor as pt import seaborn as sns from xarray import DataArray @@ -187,32 +188,15 @@ def build_model( dims="date", ) - intercept = pm.Normal( - name="intercept", - mu=model_config["intercept"]["mu"], - sigma=model_config["intercept"]["sigma"], - ) - - beta_channel = pm.HalfNormal( - name="beta_channel", - sigma=model_config["beta_channel"]["sigma"], - dims=model_config["beta_channel"]["dims"], - ) - alpha = pm.Beta( - name="alpha", - alpha=model_config["alpha"]["alpha"], - beta=model_config["alpha"]["beta"], - dims=model_config["alpha"]["dims"], - ) + #Building the priors + priors = self.create_priors_from_config(self.model_config) - lam = pm.Gamma( - name="lam", - alpha=model_config["lam"]["alpha"], - beta=model_config["lam"]["beta"], - dims=model_config["lam"]["dims"], - ) - - sigma = pm.HalfNormal(name="sigma", sigma=model_config["sigma"]["sigma"]) + #Specifying the variables + intercept = priors['intercept'] + beta_channel = priors['beta_channel'] + alpha = priors['alpha'] + lam = priors['lam'] + gamma_control = priors['gamma_control'] channel_adstock = pm.Deterministic( name="channel_adstock", @@ -230,6 +214,7 @@ def build_model( var=logistic_saturation(x=channel_adstock, lam=lam), dims=("date", "channel"), ) + channel_contributions = pm.Deterministic( name="channel_contributions", var=channel_adstock_saturated * beta_channel, @@ -251,12 +236,8 @@ def build_model( dims=("date", "control"), ) - gamma_control = pm.Normal( - name="gamma_control", - mu=model_config["gamma_control"]["mu"], - sigma=model_config["gamma_control"]["sigma"], - dims=model_config["gamma_control"]["dims"], - ) + print("Shape of control_data_:", control_data_.eval().shape) + print("Shape of gamma_control:", gamma_control.eval().shape) control_contributions = pm.Deterministic( name="control_contributions", @@ -299,33 +280,135 @@ def build_model( name="mu", var=mu_var, dims=model_config["mu"]["dims"] ) - pm.Normal( - name="likelihood", - mu=mu, - sigma=sigma, - observed=target_, - dims=model_config["likelihood"]["dims"], - ) + likelihood = self.create_likelihood(self.model_config, target_, mu) + + def create_priors_from_config(self, model_config): + priors, dimensions = {}, {"channel": len(self.channel_columns), "control": len(self.control_columns)} + stacked_priors = {} + + positive_params = {"intercept", "beta_channel", "alpha", "lam", "sigma"} # Set of params that need positive=True + + for param, config in model_config.items(): + if param == "likelihood": continue + + prior_type = config.get("type") + if prior_type is not None: + + # Initial value based on parameter name + is_positive = param in positive_params + + # Override if the config explicitly sets the 'positive' key + if 'positive' in config: + is_positive = config.get('positive') + + if prior_type == "tvp": + if param in ["intercept", "lam", "alpha", "sigma"]: + priors[param] = self.gp_wrapper(name=param, X=np.arange(len(self.X[self.date_column]))[:, None], config=config, positive=is_positive) + continue + + length = dimensions.get(config.get("dims", [None, None])[1], 1) + priors[param] = self.create_tvp_priors(param, config, length, positive=is_positive) + continue + + dist_func = getattr(pm, prior_type, None) + if not dist_func: raise ValueError(f"Invalid distribution type {prior_type}") + config_copy = {k: v for k, v in config.items() if k != "type"} + priors[param] = dist_func(name=param, **config_copy) + + return priors + + def create_likelihood(self, model_config, target_, mu): + likelihood_config = model_config.get("likelihood", {}) + likelihood_type = likelihood_config.get("type") + dims = likelihood_config.get("dims") + + if not likelihood_type: + raise ValueError("Likelihood type must be specified in the model config.") + + likelihood_func = getattr(pm, likelihood_type, None) + if likelihood_func is None: + raise ValueError(f"Invalid likelihood type {likelihood_type}") + + # Transform mu if the likelihood type is Lognormal or HurdleLognormal + if likelihood_type in ['LogNormal', 'HurdleLogNormal']: + mu = pt.log(mu) + + # Create sub-priors + sub_priors = {} + for param, config in likelihood_config.items(): + if param not in ['type', 'dims']: # Skip 'type' and 'dims' + if param == 'params': # Handle nested 'params' + for sub_param, sub_config in config.items(): + sub_priors[sub_param] = self.create_priors_from_config({sub_param: sub_config})[sub_param] + else: + sub_priors[param] = self.create_priors_from_config({param: config})[param] + + return likelihood_func(name="likelihood", mu=mu, observed=target_, dims=dims, **sub_priors) + + def create_tvp_priors(self, param, config, length, positive=False): + dims = config.get("dims", None) # Extracting dims from the config + print(dims) + gp_list = [self.gp_wrapper(name=f"{param}_{i}", X=np.arange(len(self.X[self.date_column]))[:, None], config=config, positive=positive) for i in range(length)] + stacked_gp = pt.stack(gp_list, axis=1) + return pm.Deterministic(f"{param}", stacked_gp, dims=dims) + + + def gp_wrapper(self, name, X, config, positive=False, **kwargs): + return self.gp_coeff(X, name, config=config, positive=positive, **kwargs) + + def gp_coeff(self, X, name, mean=0.0, positive=False, config=None): + params = pm.find_constrained_prior(pm.Gamma, 8, 12, init_guess={"alpha": 1, "beta": 1}, mass=0.8) + ell = pm.Gamma(f"ell_{name}", **params) + eta = pm.Exponential(f"_eta_{name}", lam=1) + # cov = eta ** 2 * pm.gp.cov.ExpQuad(1, ls=ell) + + cov = eta ** 2 * pm.gp.cov.Matern32(1, ls=ell) + + gp = pm.gp.HSGP(m=[40], c=4, cov_func=cov) + f_raw = gp.prior(f"{name}_tvp_raw", X=X) + + + # Inside your gp_coeff function + # Offset + offset_config = config.get('offset', None) if config else None + if offset_config: + offset_type = offset_config.get('type') + offset_params = {k: v for k, v in offset_config.items() if k != 'type'} + offset_prior = getattr(pm, offset_type)(name=f"{name}_offset", **offset_params) + else: + offset_prior = 0 + + if positive: + f_output = pm.Deterministic(f"{name}", (pt.exp(f_raw)) + offset_prior, dims=("date")) + else: + f_output = pm.Deterministic(f"{name}", f_raw + offset_prior, dims=("date")) + + return f_output + + @property def default_model_config(self) -> Dict: model_config: Dict = { - "intercept": {"mu": 0, "sigma": 2}, - "beta_channel": {"sigma": 2, "dims": ("channel",)}, - "alpha": {"alpha": 1, "beta": 3, "dims": ("channel",)}, - "lam": {"alpha": 3, "beta": 1, "dims": ("channel",)}, - "sigma": {"sigma": 2}, - "gamma_control": { - "mu": 0, - "sigma": 2, - "dims": ("control",), - }, + "intercept": {"type": "Normal", "mu": 0, "sigma": 2}, + "beta_channel": {"type": "HalfNormal", "sigma": 2, "dims": ("channel",)}, + "alpha": {"type": "Beta", "alpha": 1, "beta": 3, "dims": ("channel",)}, + "lam": {"type": "Gamma", "alpha": 3, "beta": 1, "dims": ("channel",)}, + "gamma_control": {'type': 'Gamma', 'alpha': 2, 'beta': 1, 'dims': ('control',)}, "mu": {"dims": ("date",)}, - "likelihood": {"dims": ("date",)}, + "likelihood": { + "type": "Normal", + "dims": ("date",), + "params": { + "sigma": {"type": "HalfNormal", "sigma": 1, 'dims': ('date',)}, + } + }, "gamma_fourier": {"mu": 0, "b": 1, "dims": "fourier_mode"}, } return model_config + + def _get_fourier_models_data(self, X) -> pd.DataFrame: """Generates fourier modes to model seasonality.