-
Notifications
You must be signed in to change notification settings - Fork 326
Add time-varying coefficient #598
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
faaba0f
3ce5ac4
dac69da
ec316b4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -52,11 +52,17 @@ def __init__( | |
| sampler_config: Optional[Dict] = None, | ||
| **kwargs, | ||
| ) -> None: | ||
| self.X: Optional[pd.DataFrame] = None | ||
| self.y: Optional[Union[pd.Series, np.ndarray]] = None | ||
| self.date_column: str = date_column | ||
| self.channel_columns: Union[List[str], Tuple[str]] = channel_columns | ||
|
|
||
| self.n_channel: int = len(channel_columns) | ||
|
|
||
| self.X: Optional[pd.DataFrame] = None | ||
| self.y: Optional[Union[pd.Series, np.ndarray]] = None | ||
|
|
||
| self._time_resolution: Optional[int] = None | ||
| self._time_index: Optional[np.ndarray[int]] = None | ||
| self._time_index_mid: Optional[int] = None | ||
| self._fit_result: Optional[az.InferenceData] = None | ||
| self._posterior_predictive: Optional[az.InferenceData] = None | ||
| super().__init__(model_config=model_config, sampler_config=sampler_config) | ||
|
|
@@ -327,7 +333,7 @@ def plot_posterior_predictive( | |
| fig, ax = plt.subplots(**plt_kwargs) | ||
| if self.X is not None and self.y is not None: | ||
| ax.fill_between( | ||
| x=self.X[self.date_column], | ||
| x=posterior_predictive_data.date, | ||
| y1=likelihood_hdi_94[:, 0], | ||
| y2=likelihood_hdi_94[:, 1], | ||
| color="C0", | ||
|
|
@@ -336,19 +342,26 @@ def plot_posterior_predictive( | |
| ) | ||
|
|
||
| ax.fill_between( | ||
| x=self.X[self.date_column], | ||
| x=posterior_predictive_data.date, | ||
| y1=likelihood_hdi_50[:, 0], | ||
| y2=likelihood_hdi_50[:, 1], | ||
| color="C0", | ||
| alpha=0.3, | ||
| label="$50\%$ HDI", # noqa: W605 | ||
| ) | ||
|
|
||
| target_to_plot: np.ndarray = np.asarray( | ||
| self.y if original_scale else self.preprocessed_data["y"] # type: ignore | ||
| target_to_plot = np.asarray( | ||
| self.y if original_scale else self.get_target_transformer().transform(self.y[:, None]).flatten() # type: ignore | ||
| ) | ||
|
|
||
| assert len(target_to_plot) == len(posterior_predictive_data.date), ( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we keep There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mean set |
||
| "The length of the target variable doesn't match the length of the date column. " | ||
| "If you are predicting out-of-sample, please overwrite `self.y` with the " | ||
| "corresponding (non-transformed) target variable." | ||
| ) | ||
|
|
||
| ax.plot( | ||
| np.asarray(self.X[self.date_column]), | ||
| np.asarray(posterior_predictive_data.date), | ||
| target_to_plot, | ||
| color="black", | ||
| ) | ||
|
|
@@ -425,11 +438,18 @@ def plot_components_contributions(self, **plt_kwargs: Any) -> plt.Figure: | |
| intercept = az.extract( | ||
| self.fit_result, var_names=["intercept"], combined=False | ||
| ) | ||
| intercept_hdi = np.repeat( | ||
| a=az.hdi(intercept).intercept.data[None, ...], | ||
| repeats=self.X[self.date_column].shape[0], | ||
| axis=0, | ||
| ) | ||
|
|
||
| if intercept.ndim == 2: | ||
| # Intercept has a stationary prior | ||
| intercept_hdi = np.repeat( | ||
| a=az.hdi(intercept).intercept.data[None, ...], | ||
| repeats=self.X[self.date_column].shape[0], | ||
| axis=0, | ||
| ) | ||
| elif intercept.ndim == 3: | ||
| # Intercept has a time-varying prior | ||
| intercept_hdi = az.hdi(intercept).intercept.data | ||
|
|
||
| ax.plot( | ||
| np.asarray(self.X[self.date_column]), | ||
| np.full(len(self.X[self.date_column]), intercept.mean().data), | ||
|
|
@@ -1000,6 +1020,7 @@ def label_func(channel): | |
|
|
||
| def legend_title_func(channel): | ||
| return "Legend" | ||
|
|
||
| else: | ||
| nrows = len(channels_to_plot) | ||
| figsize = (12, 4 * len(channels_to_plot)) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
| from pymc_marketing.mmm.base import MMM | ||
| from pymc_marketing.mmm.preprocessing import MaxAbsScaleChannels, MaxAbsScaleTarget | ||
| from pymc_marketing.mmm.transformers import geometric_adstock, logistic_saturation | ||
| from pymc_marketing.mmm.tvp import time_varying_prior | ||
| from pymc_marketing.mmm.utils import ( | ||
| apply_sklearn_transformer_across_dim, | ||
| create_new_spend_data, | ||
|
|
@@ -24,6 +25,7 @@ | |
|
|
||
| __all__ = ["DelayedSaturatedMMM"] | ||
|
|
||
| DAYS_IN_YEAR = 365.25 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can move this into a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perfect. Will do. |
||
|
|
||
| class BaseDelayedSaturatedMMM(MMM): | ||
| _model_type = "DelayedSaturatedMMM" | ||
|
|
@@ -34,6 +36,8 @@ def __init__( | |
| date_column: str, | ||
| channel_columns: List[str], | ||
| adstock_max_lag: int, | ||
| time_varying_media_effect: bool = False, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey, small comment here... This is a matter of taste but I think having two parameters for the same component is a bit strange. What do you think if we evaluate with strings or tuples? I think it would be better inside for the API There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I quite like this. Maybe this gets too complicated but: time_varying='intercept'
...
time_varying='total_media'
...
time_varying=['intercept', 'total_media']
...
time_varying=['intercept', 'channel1', 'channe2']could work There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe even time_varying=['intercept', ('channel1', 'channel2')] # now channel1 and channel2 are summed and multiplied by a time varying coef |
||
| time_varying_intercept: bool = False, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be great if we could control the HSGP. Could we add an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe better to just support that in the model config? The chief thing you'd want to control is the covariance function, m and L. But for MMM this can even be simplified to just lengthscale and m, I believe since optimal L can be calculated form the length of the data and the lengthscale. |
||
| model_config: Optional[Dict] = None, | ||
| sampler_config: Optional[Dict] = None, | ||
| validate_data: bool = True, | ||
|
|
@@ -49,6 +53,12 @@ def __init__( | |
| Column name of the date variable. | ||
| channel_columns : List[str] | ||
| Column names of the media channel variables. | ||
| adstock_max_lag : int | ||
| Number of lags to consider in the adstock transformation. | ||
| time_varying_media_effect : bool, optional | ||
| Whether to consider time-varying media effects, by default False. | ||
| time_varying_intercept : bool, optional | ||
| Whether to consider time-varying intercept, by default False. | ||
| model_config : Dictionary, optional | ||
| dictionary of parameters that initialise model configuration. Class-default defined by the user default_model_config method. | ||
| sampler_config : Dictionary, optional | ||
|
|
@@ -68,6 +78,8 @@ def __init__( | |
| """ | ||
| self.control_columns = control_columns | ||
| self.adstock_max_lag = adstock_max_lag | ||
| self.time_varying_media_effect = time_varying_media_effect | ||
| self.time_varying_intercept = time_varying_intercept | ||
| self.yearly_seasonality = yearly_seasonality | ||
| self.date_column = date_column | ||
| self.validate_data = validate_data | ||
|
|
@@ -92,15 +104,34 @@ def output_var(self): | |
| def _generate_and_preprocess_model_data( # type: ignore | ||
| self, X: Union[pd.DataFrame, pd.Series], y: Union[pd.Series, np.ndarray] | ||
| ) -> None: | ||
| """ | ||
| Applies preprocessing to the data before fitting the model. | ||
| if validate is True, it will check if the data is valid for the model. | ||
| sets self.model_coords based on provided dataset | ||
| """Preprocess data and set model state variables. | ||
|
|
||
| Applies preprocessing to the data before fitting the model. If validate | ||
| is True, it will check if the data is valid for the model. *Only* gets | ||
| called before fitting the model. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| X : Union[pd.DataFrame, pd.Series], shape (n_obs, n_features) | ||
| y : Union[pd.Series, np.ndarray], shape (n_obs,) | ||
|
|
||
| Sets | ||
| ---- | ||
| preprocessed_data : Dict[str, Union[pd.DataFrame, pd.Series]] | ||
| Preprocessed data for the model. | ||
| X : pd.DataFrame | ||
| A filtered version of the input `X`, such that it is guaranteed that | ||
| it contains only the `date_column`, the columns that are specified | ||
| in the `channel_columns` and `control_columns`, and fourier features | ||
| if `yearly_seasonality=True`. | ||
| y : Union[pd.Series, np.ndarray] | ||
| The target variable for the model (as provided). | ||
| _time_index : np.ndarray | ||
| The index of the date column. Used by TVP | ||
| _time_index_mid : int | ||
| The middle index of the date index. Used by TVP. | ||
| _time_resolution: int | ||
| The time resolution of the date index. Used by TVP. | ||
| """ | ||
| date_data = X[self.date_column] | ||
| channel_data = X[self.channel_columns] | ||
|
|
@@ -140,6 +171,11 @@ def _generate_and_preprocess_model_data( # type: ignore | |
| } | ||
| self.X: pd.DataFrame = X_data | ||
| self.y: Union[pd.Series, np.ndarray] = y | ||
| self._time_index = np.arange(0, X.shape[0]) | ||
| self._time_index_mid = X.shape[0] // 2 | ||
| self._time_resolution = ( | ||
| self.X[self.date_column].iloc[1] - self.X[self.date_column].iloc[0] | ||
| ).days | ||
|
|
||
| def _save_input_params(self, idata) -> None: | ||
| """Saves input parameters to the attrs of idata.""" | ||
|
|
@@ -338,9 +374,52 @@ def build_model( | |
| dims="date", | ||
| ) | ||
|
|
||
| intercept = self.intercept_dist( | ||
| name="intercept", **self.model_config["intercept"]["kwargs"] | ||
| ) | ||
| if self.time_varying_intercept or self.time_varying_media_effect: | ||
| time_index = pm.MutableData( | ||
| "time_index", | ||
| self._time_index, | ||
| dims="date", | ||
| ) | ||
|
|
||
| if self.time_varying_intercept: | ||
| tv_multiplier_intercept = time_varying_prior( | ||
| name="tv_multiplier_intercept", | ||
| X=time_index, | ||
| X_mid=self._time_index_mid, | ||
| positive=True, | ||
| m=200, | ||
| L=[self._time_index_mid + DAYS_IN_YEAR / self._time_resolution], | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it might be cleaner to use the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unfortunately this causes some trouble with predicting out of sample. @bwengals did you find the cause for this? |
||
| ls_mu=DAYS_IN_YEAR / self._time_resolution * 2, | ||
| ls_sigma=10, | ||
| eta_lam=1, | ||
| dims="date", | ||
| ) | ||
| intercept_base = self.intercept_dist( | ||
| name="intercept_base", **self.model_config["intercept"]["kwargs"] | ||
| ) | ||
| intercept = pm.Deterministic( | ||
| name="intercept", | ||
| var=intercept_base * tv_multiplier_intercept, | ||
| dims="date", | ||
| ) | ||
| else: | ||
| intercept = self.intercept_dist( | ||
| name="intercept", **self.model_config["intercept"]["kwargs"] | ||
| ) | ||
|
|
||
| if self.time_varying_media_effect: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of checking twice, would it be better to check by the end? You load the data and apply the multiplier only if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. Can I refactor this whole function actually? Fingertips itching. |
||
| tv_multiplier_media = time_varying_prior( | ||
| name="tv_multiplier_media", | ||
| X=time_index, | ||
| X_mid=self._time_index_mid, | ||
| positive=True, | ||
| m=200, | ||
| L=[self._time_index_mid + DAYS_IN_YEAR / self._time_resolution], | ||
| ls_mu=DAYS_IN_YEAR / self._time_resolution * 2, | ||
| ls_sigma=10, | ||
| eta_lam=1, | ||
| dims="date", | ||
| ) | ||
|
|
||
| beta_channel = self.beta_channel_dist( | ||
| name="beta_channel", | ||
|
|
@@ -374,10 +453,14 @@ def build_model( | |
| var=logistic_saturation(x=channel_adstock, lam=lam), | ||
| dims=("date", "channel"), | ||
| ) | ||
|
|
||
| channel_contributions_var = channel_adstock_saturated * beta_channel | ||
| if self.time_varying_media_effect: | ||
| channel_contributions_var *= tv_multiplier_media[:, None] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is not a strong opinion, but I think that if we are using the logic of a base contribution, which is then modified, it would be great to have this join saved in another e.g: pm.Deterministic(
var = channel_contributions_var * tv_multiplier_media[:, None]
name='varying_contribution' #or something like it
)There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah this is a nice idea. I think I might actually change the time_varying_prior function so it always swings in the positive range. Then when using it, it always works as a multiplier on base contributions. Will also make joining this logic with logic for hierarchical parameters easy. I'll make this change, it's nice. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree :) |
||
| channel_contributions = pm.Deterministic( | ||
| name="channel_contributions", | ||
| var=channel_adstock_saturated * beta_channel, | ||
| dims=("date", "channel"), | ||
| name="channel_contributions", | ||
| var=channel_contributions_var, | ||
| dims=("date", "channel"), | ||
| ) | ||
|
|
||
| mu_var = intercept + channel_contributions.sum(axis=-1) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the HSGP is in dimension "date" from the beginning, I feel that we are first transforming said HSGP to a form that is compatible with It's just an opinion, I think it would be more appropriate the other way and we can avoid a somewhat unnecessary transformation, what do you think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't fully understand. You can just commit the change if you believe it is best. In general this time_varying_prior function supports two dimensions but I'm only using the 1D case here, maybe that's confusing. I should actually add support for individually varying parameters in this PR too... not complicated There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
@@ -475,7 +558,7 @@ def _get_fourier_models_data(self, X) -> pd.DataFrame: | |
| date_data: pd.Series = pd.to_datetime( | ||
| arg=X[self.date_column], format="%Y-%m-%d" | ||
| ) | ||
| periods: npt.NDArray[np.float_] = date_data.dt.dayofyear.to_numpy() / 365.25 | ||
| periods: npt.NDArray[np.float_] = date_data.dt.dayofyear.to_numpy() / DAYS_IN_YEAR | ||
| return generate_fourier_modes( | ||
| periods=periods, | ||
| n_order=self.yearly_seasonality, | ||
|
|
@@ -657,6 +740,11 @@ def identity(x): | |
| if hasattr(self, "fourier_columns"): | ||
| data["fourier_data"] = self._get_fourier_models_data(X) | ||
|
|
||
| if self.time_varying_intercept or self.time_varying_media_effect: | ||
| data["time_index"] = np.arange( | ||
| self._time_index[-1], self._time_index[-1] + X.shape[0] | ||
| ) | ||
|
|
||
| if y is not None: | ||
| if isinstance(y, pd.Series): | ||
| data["target"] = ( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| from typing import Optional | ||
|
|
||
| import pymc as pm | ||
| from pymc_marketing.mmm.utils import softplus | ||
|
|
||
|
|
||
| def time_varying_prior( | ||
| name: str, | ||
| X: pm.Deterministic, | ||
| X_mid: int | float, | ||
| positive: bool = False, | ||
| dims: Optional[tuple[str, str] | str] = None, | ||
| m: int = 40, | ||
| L: int = 100, | ||
| eta_lam: float = 1, | ||
| ls_mu: float = 5, | ||
| ls_sigma: float = 5, | ||
| cov_func: Optional[pm.gp.cov.Prod] = None, | ||
| model: Optional[pm.Model] = None, | ||
| ) -> pm.Deterministic: | ||
| """Time varying prior, based the Hilbert Space Gaussian Process (HSGP). | ||
| Parameters | ||
| ---------- | ||
| name : str | ||
| Name of the prior. | ||
| X : 1d array-like of int or float | ||
| Time points. | ||
| X_mid : int or float | ||
| Midpoint of the time points. | ||
| positive : bool | ||
| Whether the prior should be positive. | ||
| dims : tuple of str or str | ||
| Dimensions of the prior. | ||
| m : int | ||
| Number of basis functions. | ||
| L : int | ||
| Boundary of the HSGP approximation domain. | ||
| eta_lam : float | ||
| Exponential prior for the variance. | ||
| ls_mu : float | ||
| Mean of the inverse gamma prior for the lengthscale. | ||
| ls_sigma : float | ||
| Standard deviation of the inverse gamma prior for the lengthscale. | ||
| cov_func : pm.gp.cov.Prod | ||
| Covariance function. | ||
| model : pm.Model | ||
| PyMC model. | ||
| Returns | ||
| ------- | ||
| pm.Deterministic | ||
| Time-varying prior. | ||
| """ # noqa: W605 | ||
|
|
||
| with pm.modelcontext(model) as model: | ||
| if cov_func is None: | ||
| eta = pm.Exponential(f"eta_{name}", lam=eta_lam) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think using this current implementation, we could move these distributions into the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh this would be really clever, it just becomes another supported distribution? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that's better than an hsgp_config I think. I'm trying to work out what's the right level of rigidity to build into this thing. There are some things here you wouldn't want to change as a user I think. And maybe this is just me, but as we put distributions into the config, we pay the price of more obscure code. Opinions? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice to allow these priors to be defined in the config as GPs are quite sensitive to priors. |
||
| ls = pm.InverseGamma(f"ls_{name}", mu=ls_mu, sigma=ls_sigma) | ||
| cov_func = eta**2 * pm.gp.cov.Matern52(1, ls=ls) | ||
|
|
||
| if type(dims) is tuple: | ||
| n_columns = len(model.coords[dims[1]]) | ||
| hsgp_size = (n_columns, m) | ||
| else: | ||
| hsgp_size = m | ||
| gp = pm.gp.HSGP(m=[m], L=[L], cov_func=cov_func) | ||
| phi, sqrt_psd = gp.prior_linearized(Xs=X[:, None] - X_mid) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note for future review: We need to be careful we center the data with respect to the training set (even when we are doing out of sample prediction) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See doctstrings in https://github.com/pymc-devs/pymc/blob/main/pymc/gp/hsgp_approx.py#L243 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you elaborate? What is the concern? |
||
| hsgp_coefs = pm.Normal(f"_hsgp_coefs_{name}", size=hsgp_size) | ||
| f = phi @ (hsgp_coefs * sqrt_psd).T | ||
| if positive: | ||
| f = softplus(f) | ||
| return pm.Deterministic(name, f, dims=dims) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think would be great to create a plot for the varying parameters, showing the recovered latent pattern which is affecting the channels and/or the intercept, what do you think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree. Will add! |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,8 @@ | |
| import numpy as np | ||
| import numpy.typing as npt | ||
| import pandas as pd | ||
| import pymc as pm | ||
| import pytensor.tensor as pt | ||
| import xarray as xr | ||
| from scipy.optimize import curve_fit, minimize_scalar | ||
|
|
||
|
|
@@ -329,6 +331,10 @@ def apply_sklearn_transformer_across_dim( | |
|
|
||
| return data | ||
|
|
||
|
|
||
| def softplus(x: pt.TensorVariable) -> pt.TensorVariable: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this exists in |
||
| return pm.math.log(1 + pm.math.exp(x)) | ||
|
|
||
|
|
||
| def create_new_spend_data( | ||
| spend: np.ndarray, | ||
|
|
@@ -415,4 +421,4 @@ def create_new_spend_data( | |
| spend_leading_up, | ||
| spend, | ||
| ] | ||
| ) | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the reason for this change? The date column can have a different name.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This enables OOS posterior predictive.
self.Xdoes not get updated uponself._data_setter(X_new). So if the posterior predictive is for new OOS data, thenself.X[self.date_column]will still be the in-sample dates.