4,250 changes: 4,250 additions & 0 deletions docs/source/notebooks/mmm/mmm_tvp_example.ipynb

Large diffs are not rendered by default.

401 changes: 401 additions & 0 deletions docs/source/notebooks/mmm/mock_cgp_data-no-target.csv

Large diffs are not rendered by default.

45 changes: 33 additions & 12 deletions pymc_marketing/mmm/base.py
@@ -52,11 +52,17 @@ def __init__(
sampler_config: Optional[Dict] = None,
**kwargs,
) -> None:
- self.X: Optional[pd.DataFrame] = None
- self.y: Optional[Union[pd.Series, np.ndarray]] = None
self.date_column: str = date_column
self.channel_columns: Union[List[str], Tuple[str]] = channel_columns

self.n_channel: int = len(channel_columns)

self.X: Optional[pd.DataFrame] = None
self.y: Optional[Union[pd.Series, np.ndarray]] = None

self._time_resolution: Optional[int] = None
self._time_index: Optional[np.ndarray[int]] = None
self._time_index_mid: Optional[int] = None
self._fit_result: Optional[az.InferenceData] = None
self._posterior_predictive: Optional[az.InferenceData] = None
super().__init__(model_config=model_config, sampler_config=sampler_config)
@@ -319,7 +325,7 @@ def plot_posterior_predictive(
fig, ax = plt.subplots(**plt_kwargs)
if self.X is not None and self.y is not None:
ax.fill_between(
- x=self.X[self.date_column],
x=posterior_predictive_data.date,
y1=likelihood_hdi_94[:, 0],
y2=likelihood_hdi_94[:, 1],
color="C0",
@@ -328,19 +334,26 @@
)

ax.fill_between(
- x=self.X[self.date_column],
x=posterior_predictive_data.date,
Collaborator: What is the reason for this change? The date column can have a different name.

Contributor Author: This enables OOS posterior predictive. self.X does not get updated upon self._data_setter(X_new). So if the posterior predictive is for new OOS data, then self.X[self.date_column] will still be the in-sample dates.
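
A minimal sketch of the out-of-sample flow this enables (make_future_frame and y_new are assumed placeholders, not part of this PR; method names follow the ModelBuilder API):

X_new = make_future_frame(X, periods=8)  # assumed helper: future dates + planned channel spend
mmm.y = y_new  # overwrite with the non-transformed OOS target, as the assert below requires
mmm.sample_posterior_predictive(X_new, extend_idata=True)  # calls self._data_setter(X_new) internally
mmm.plot_posterior_predictive(original_scale=True)  # dates now come from the posterior predictive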

y1=likelihood_hdi_50[:, 0],
y2=likelihood_hdi_50[:, 1],
color="C0",
alpha=0.3,
label="$50\%$ HDI", # noqa: W605
)

- target_to_plot: np.ndarray = np.asarray(
- self.y if original_scale else self.preprocessed_data["y"] # type: ignore
target_to_plot = np.asarray(
self.y if original_scale else self.get_target_transformer().transform(self.y[:, None]).flatten() # type: ignore
)

assert len(target_to_plot) == len(posterior_predictive_data.date), (
Collaborator: Can we keep date_col as a generic date name?

Contributor Author: You mean set date_col = posterior_predictive_data.date somewhere earlier in this function?

"The length of the target variable doesn't match the length of the date column. "
"If you are predicting out-of-sample, please overwrite `self.y` with the "
"corresponding (non-transformed) target variable."
)

ax.plot(
- np.asarray(self.X[self.date_column]),
np.asarray(posterior_predictive_data.date),
target_to_plot,
color="black",
)
@@ -417,11 +430,18 @@ def plot_components_contributions(self, **plt_kwargs: Any) -> plt.Figure:
intercept = az.extract(
self.fit_result, var_names=["intercept"], combined=False
)
- intercept_hdi = np.repeat(
- a=az.hdi(intercept).intercept.data[None, ...],
- repeats=self.X[self.date_column].shape[0],
- axis=0,
- )

if intercept.ndim == 2:
# Intercept has a stationary prior
intercept_hdi = np.repeat(
a=az.hdi(intercept).intercept.data[None, ...],
repeats=self.X[self.date_column].shape[0],
axis=0,
)
elif intercept.ndim == 3:
# Intercept has a time-varying prior
intercept_hdi = az.hdi(intercept).intercept.data

ax.plot(
np.asarray(self.X[self.date_column]),
np.full(len(self.X[self.date_column]), intercept.mean().data),
@@ -992,6 +1012,7 @@ def label_func(channel):

def legend_title_func(channel):
return "Legend"

else:
nrows = len(channels_to_plot)
figsize = (12, 4 * len(channels_to_plot))
123 changes: 108 additions & 15 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
@@ -15,6 +15,7 @@
from pymc_marketing.mmm.base import MMM
from pymc_marketing.mmm.preprocessing import MaxAbsScaleChannels, MaxAbsScaleTarget
from pymc_marketing.mmm.transformers import geometric_adstock, logistic_saturation
from pymc_marketing.mmm.tvp import time_varying_prior
from pymc_marketing.mmm.utils import (
apply_sklearn_transformer_across_date,
generate_fourier_modes,
@@ -33,6 +34,8 @@ def __init__(
date_column: str,
channel_columns: List[str],
adstock_max_lag: int,
time_varying_media_effect: bool = False,
Contributor: Hey, small comment here... This is a matter of taste, but I think having two parameters for the same component is a bit strange. What do you think if we evaluate with strings or tuples?

time_varying = ('intercept', 'media')  # or
time_varying = 'intercept-media'

I think it would be a better API.

Contributor Author: I quite like this. Maybe this gets too complicated, but:

time_varying='intercept'
...
time_varying='total_media'
...
time_varying=['intercept', 'total_media']
...
time_varying=['intercept', 'channel1', 'channel2']

could work.

Contributor Author: Maybe even

time_varying=['intercept', ('channel1', 'channel2')]  # now channel1 and channel2 are summed and multiplied by a time-varying coef
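
A rough sketch of how such a spec could be normalized inside __init__ (hypothetical helper, not in this PR):

def _parse_time_varying(spec):
    # Accepts None, a string, or an iterable of names; returns
    # (intercept_is_time_varying, media entries to receive a multiplier).
    if spec is None:
        return False, []
    items = [spec] if isinstance(spec, str) else list(spec)
    return "intercept" in items, [item for item in items if item != "intercept"]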

time_varying_intercept: bool = False,
Contributor: I think it would be great if we could control the HSGP. Could we add an hsgp_config, similar to how we define the priors? Then more "expert" users could play with how much flexibility they want to give the HSGP.

Contributor Author: Maybe better to just support that in the model config? The chief things you'd want to control are the covariance function, m, and L. But for MMM this can even be simplified to just the lengthscale and m, I believe, since the optimal L can be calculated from the length of the data and the lengthscale.
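
For example, a hypothetical model_config entry could look like this (the key and defaults below are assumptions, not part of this PR):

model_config = {
    "intercept_tvp_kwargs": {
        "m": 200,  # number of HSGP basis functions
        "ls_mu": 104,  # lengthscale prior mean, in time steps (~2 years of weekly data)
        "ls_sigma": 10,  # lengthscale prior standard deviation
    },
}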

model_config: Optional[Dict] = None,
sampler_config: Optional[Dict] = None,
validate_data: bool = True,
@@ -48,6 +51,12 @@ def __init__(
Column name of the date variable.
channel_columns : List[str]
Column names of the media channel variables.
adstock_max_lag : int
Number of lags to consider in the adstock transformation.
time_varying_media_effect : bool, optional
Whether to consider time-varying media effects, by default False.
time_varying_intercept : bool, optional
Whether to consider time-varying intercept, by default False.
model_config : Dictionary, optional
dictionary of parameters that initialise model configuration. Class-default defined by the user default_model_config method.
sampler_config : Dictionary, optional
@@ -67,6 +76,8 @@
"""
self.control_columns = control_columns
self.adstock_max_lag = adstock_max_lag
self.time_varying_media_effect = time_varying_media_effect
self.time_varying_intercept = time_varying_intercept
self.yearly_seasonality = yearly_seasonality
self.date_column = date_column
self.validate_data = validate_data
@@ -91,15 +102,34 @@ def output_var(self):
def _generate_and_preprocess_model_data( # type: ignore
self, X: Union[pd.DataFrame, pd.Series], y: Union[pd.Series, np.ndarray]
) -> None:
"""
Applies preprocessing to the data before fitting the model.
if validate is True, it will check if the data is valid for the model.
sets self.model_coords based on provided dataset
"""Preprocess data and set model state variables.

Applies preprocessing to the data before fitting the model. If validate
is True, it will check if the data is valid for the model. *Only* gets
called before fitting the model.

Parameters
----------
X : Union[pd.DataFrame, pd.Series], shape (n_obs, n_features)
y : Union[pd.Series, np.ndarray], shape (n_obs,)

Sets
----
preprocessed_data : Dict[str, Union[pd.DataFrame, pd.Series]]
Preprocessed data for the model.
X : pd.DataFrame
A filtered version of the input `X`, such that it is guaranteed that
it contains only the `date_column`, the columns that are specified
in the `channel_columns` and `control_columns`, and fourier features
if `yearly_seasonality=True`.
y : Union[pd.Series, np.ndarray]
The target variable for the model (as provided).
_time_index : np.ndarray
The index of the date column. Used by TVP.
_time_index_mid : int
The middle index of the date index. Used by TVP.
_time_resolution : int
The time resolution (in days) of the date index. Used by TVP.
"""
date_data = X[self.date_column]
channel_data = X[self.channel_columns]
@@ -139,6 +169,11 @@
}
self.X: pd.DataFrame = X_data
self.y: Union[pd.Series, np.ndarray] = y
self._time_index = np.arange(0, X.shape[0])
self._time_index_mid = X.shape[0] // 2
self._time_resolution = (
self.X[self.date_column].iloc[1] - self.X[self.date_column].iloc[0]
).days

def _save_input_params(self, idata) -> None:
"""Saves input parameters to the attrs of idata."""
@@ -337,9 +372,52 @@ def build_model(
dims="date",
)

- intercept = self.intercept_dist(
- name="intercept", **self.model_config["intercept"]["kwargs"]
- )
if self.time_varying_intercept or self.time_varying_media_effect:
time_index = pm.MutableData(
"time_index",
self._time_index,
dims="date",
)

if self.time_varying_intercept:
tv_multiplier_intercept = time_varying_prior(
name="tv_multiplier_intercept",
X=time_index,
X_mid=self._time_index_mid,
positive=True,
m=200,
L=[self._time_index_mid + 365 / self._time_resolution],
ls_mu=365 / self._time_resolution * 2,
ls_sigma=10,
eta_lam=1,
dims="date",
)
intercept_base = self.intercept_dist(
name="intercept_base", **self.model_config["intercept"]["kwargs"]
)
intercept = pm.Deterministic(
name="intercept",
var=intercept_base * tv_multiplier_intercept,
dims="date",
)
else:
intercept = self.intercept_dist(
name="intercept", **self.model_config["intercept"]["kwargs"]
)

if self.time_varying_media_effect:
Contributor: Instead of checking twice, would it be better to check at the end? You load the data and apply the multiplier only if True, all at once, saving a few lines of code.

Contributor Author: Yes. Can I refactor this whole function, actually? Fingertips itching.
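
Roughly the "check once at the end" shape being discussed (a sketch of the refactor, not the PR's code):

multiplier = None
if self.time_varying_media_effect:
    multiplier = time_varying_prior(
        name="tv_multiplier_media",
        X=time_index,
        X_mid=self._time_index_mid,
        positive=True,
        dims="date",  # remaining kwargs as in the PR
    )
contributions = channel_adstock_saturated * beta_channel
if multiplier is not None:
    contributions = contributions * multiplier[:, None]
channel_contributions = pm.Deterministic(
    name="channel_contributions", var=contributions, dims=("date", "channel")
)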

tv_multiplier_media = time_varying_prior(
name="tv_multiplier_media",
X=time_index,
X_mid=self._time_index_mid,
positive=True,
m=200,
L=[self._time_index_mid + 365 / self._time_resolution],
ls_mu=365 / self._time_resolution * 2,
ls_sigma=10,
eta_lam=1,
dims="date",
)

beta_channel = self.beta_channel_dist(
name="beta_channel",
@@ -373,11 +451,21 @@
var=logistic_saturation(x=channel_adstock, lam=lam),
dims=("date", "channel"),
)
- channel_contributions = pm.Deterministic(
- name="channel_contributions",
- var=channel_adstock_saturated * beta_channel,
- dims=("date", "channel"),
- )

if self.time_varying_media_effect:
channel_contributions = pm.Deterministic(
name="channel_contributions",
var=channel_adstock_saturated
* beta_channel
* tv_multiplier_media[:, None],
dims=("date", "channel"),
)
else:
channel_contributions = pm.Deterministic(
name="channel_contributions",
var=channel_adstock_saturated * beta_channel,
dims=("date", "channel"),
)

mu_var = intercept + channel_contributions.sum(axis=-1)
Contributor: Since the HSGP is in dimension "date" from the beginning, I feel that we first transform it to a form compatible with ("date", "channel"), only to return to "date" at the end. Isn't it better to collapse the channels to "date" first and then multiply by the HSGP, without needing [:, None]?

It's just an opinion; I think it would be more appropriate the other way around, and we could avoid a somewhat unnecessary transformation. What do you think?

Contributor Author: Don't fully understand. You can just commit the change if you believe it is best.

In general this time_varying_prior function supports two dimensions, but I'm only using the 1D case here; maybe that's confusing.

I should actually add support for individually varying parameters in this PR too... not complicated.

Collaborator: > I should actually add support for individually varying parameters in this PR too... not complicated

I suggest we keep the scope of this PR small and then iterate.
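
The broadcasting in question, as a NumPy toy example (shapes assumed: 52 dates, 3 channels):

import numpy as np

multiplier = np.ones(52)  # HSGP output, dims ("date",)
contributions = np.ones((52, 3))  # dims ("date", "channel")
per_channel = contributions * multiplier[:, None]  # (52, 3): trailing axis broadcasts over channels
summed_first = contributions.sum(axis=-1) * multiplier  # (52,): the alternative suggested above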

if (
@@ -657,11 +745,16 @@ def identity(x):
if hasattr(self, "fourier_columns"):
data["fourier_data"] = self._get_fourier_models_data(X)

if self.time_varying_intercept or self.time_varying_media_effect:
data["time_index"] = np.arange(
self._time_index[-1], self._time_index[-1] + X.shape[0]
)

if y is not None:
if isinstance(y, pd.Series):
- data[
- "target"
- ] = y.to_numpy() # convert Series to numpy array explicitly
data["target"] = (
y.to_numpy()
) # convert Series to numpy array explicitly
elif isinstance(y, np.ndarray):
data["target"] = y
else:
72 changes: 72 additions & 0 deletions pymc_marketing/mmm/tvp.py
@@ -0,0 +1,72 @@
from typing import Optional

import pymc as pm
from pymc_marketing.mmm.utils import softplus


def time_varying_prior(
name: str,
X: pm.Deterministic,
X_mid: int | float,
positive: bool = False,
dims: Optional[tuple[str, str] | str] = None,
m: int = 40,
L: int = 100,
eta_lam: float = 1,
ls_mu: float = 5,
ls_sigma: float = 5,
cov_func: Optional[pm.gp.cov.Prod] = None,
model: Optional[pm.Model] = None,
) -> pm.Deterministic:
"""Time varying prior, based the Hilbert Space Gaussian Process (HSGP).
Parameters
----------
name : str
Name of the prior.
X : 1d array-like of int or float
Time points.
X_mid : int or float
Midpoint of the time points.
positive : bool
Whether the prior should be positive.
dims : tuple of str or str
Dimensions of the prior.
m : int
Number of basis functions.
L : int
Extent of the HSGP approximation domain: the prior is accurate on [-L, L] around the centered inputs.
eta_lam : float
Rate parameter (lam) of the exponential prior on the amplitude eta.
ls_mu : float
Mean of the inverse gamma prior for the lengthscale.
ls_sigma : float
Standard deviation of the inverse gamma prior for the lengthscale.
cov_func : pm.gp.cov.Prod
Covariance function.
model : pm.Model
PyMC model.

Returns
-------
pm.Deterministic
Time-varying prior.
""" # noqa: W605
if cov_func is None:
eta = pm.Exponential(f"eta_{name}", lam=eta_lam)
ls = pm.InverseGamma(f"ls_{name}", mu=ls_mu, sigma=ls_sigma)
cov_func = eta**2 * pm.gp.cov.Matern52(1, ls=ls)

with pm.modelcontext(model) as model:
if type(dims) is tuple:
n_columns = len(model.coords[dims[1]])
hsgp_size = (n_columns, m)
else:
hsgp_size = m
gp = pm.gp.HSGP(m=[m], L=[L], cov_func=cov_func)
phi, sqrt_psd = gp.prior_linearized(Xs=X[:, None] - X_mid)
Collaborator: Note for future review: we need to be careful to center the data with respect to the training set (even when we are doing out-of-sample prediction).

Contributor Author: Can you elaborate? What is the concern?
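
The concern, as a toy sketch (values assumed): X_mid is fixed at fit time, so out-of-sample inputs must be centered on the training midpoint rather than their own.

import numpy as np

X_train = np.arange(104)  # two years of weekly time indices
X_mid = X_train.shape[0] // 2  # fixed at fit time, as self._time_index_mid
X_new = np.arange(104, 112)  # OOS continuation, as produced by _data_setter
Xs = X_new - X_mid  # correct: same centering as during training
# X_new - X_new.mean() would evaluate the HSGP basis on a shifted domain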

hsgp_coefs = pm.Normal(f"_hsgp_coefs_{name}", size=hsgp_size)
f = phi @ (hsgp_coefs * sqrt_psd).T
if positive:
f = softplus(f)
return pm.Deterministic(name, f, dims=dims)
Contributor: I think it would be great to create a plot for the varying parameters, showing the recovered latent pattern affecting the channels and/or the intercept. What do you think?

Contributor Author: Agree. Will add!
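
A minimal, self-contained usage sketch of time_varying_prior (synthetic data; all values and coords assumed):

import numpy as np
import pymc as pm
from pymc_marketing.mmm.tvp import time_varying_prior

n = 104  # two years of weekly observations
with pm.Model(coords={"date": np.arange(n)}) as model:
    t = pm.MutableData("time_index", np.arange(n), dims="date")
    multiplier = time_varying_prior(
        name="tv_multiplier",
        X=t,
        X_mid=n // 2,
        positive=True,
        m=40,
        L=n,  # approximation domain covers the training window
        ls_mu=52,  # lengthscale prior mean: about one year of weekly steps
        ls_sigma=5,
        dims="date",
    )
    prior_draws = pm.sample_prior_predictive(samples=50)  # inspect the latent multiplier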

6 changes: 6 additions & 0 deletions pymc_marketing/mmm/utils.py
@@ -4,6 +4,8 @@
import numpy as np
import numpy.typing as npt
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from scipy.optimize import curve_fit, minimize_scalar

@@ -326,3 +328,7 @@ def apply_sklearn_transformer_across_date(
data.attrs = attrs

return data


def softplus(x: pt.TensorVariable) -> pt.TensorVariable:
"""Smooth positive transform: log(1 + exp(x))."""
return pm.math.log(1 + pm.math.exp(x))