-
Notifications
You must be signed in to change notification settings - Fork 326
Add time-varying coefficient #598
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
faaba0f
3ce5ac4
dac69da
ec316b4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -52,11 +52,17 @@ def __init__( | |
| sampler_config: Optional[Dict] = None, | ||
| **kwargs, | ||
| ) -> None: | ||
| self.X: Optional[pd.DataFrame] = None | ||
| self.y: Optional[Union[pd.Series, np.ndarray]] = None | ||
| self.date_column: str = date_column | ||
| self.channel_columns: Union[List[str], Tuple[str]] = channel_columns | ||
|
|
||
| self.n_channel: int = len(channel_columns) | ||
|
|
||
| self.X: Optional[pd.DataFrame] = None | ||
| self.y: Optional[Union[pd.Series, np.ndarray]] = None | ||
|
|
||
| self._time_resolution: Optional[int] = None | ||
| self._time_index: Optional[np.ndarray[int]] = None | ||
| self._time_index_mid: Optional[int] = None | ||
| self._fit_result: Optional[az.InferenceData] = None | ||
| self._posterior_predictive: Optional[az.InferenceData] = None | ||
| super().__init__(model_config=model_config, sampler_config=sampler_config) | ||
|
|
@@ -319,7 +325,7 @@ def plot_posterior_predictive( | |
| fig, ax = plt.subplots(**plt_kwargs) | ||
| if self.X is not None and self.y is not None: | ||
| ax.fill_between( | ||
| x=self.X[self.date_column], | ||
| x=posterior_predictive_data.date, | ||
| y1=likelihood_hdi_94[:, 0], | ||
| y2=likelihood_hdi_94[:, 1], | ||
| color="C0", | ||
|
|
@@ -328,19 +334,26 @@ def plot_posterior_predictive( | |
| ) | ||
|
|
||
| ax.fill_between( | ||
| x=self.X[self.date_column], | ||
| x=posterior_predictive_data.date, | ||
| y1=likelihood_hdi_50[:, 0], | ||
| y2=likelihood_hdi_50[:, 1], | ||
| color="C0", | ||
| alpha=0.3, | ||
| label="$50\%$ HDI", # noqa: W605 | ||
| ) | ||
|
|
||
| target_to_plot: np.ndarray = np.asarray( | ||
| self.y if original_scale else self.preprocessed_data["y"] # type: ignore | ||
| target_to_plot = np.asarray( | ||
| self.y if original_scale else self.get_target_transformer().transform(self.y[:, None]).flatten() # type: ignore | ||
| ) | ||
|
|
||
| assert len(target_to_plot) == len(posterior_predictive_data.date), ( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we keep There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mean set |
||
| "The length of the target variable doesn't match the length of the date column. " | ||
| "If you are predicting out-of-sample, please overwrite `self.y` with the " | ||
| "corresponding (non-transformed) target variable." | ||
| ) | ||
|
|
||
| ax.plot( | ||
| np.asarray(self.X[self.date_column]), | ||
| np.asarray(posterior_predictive_data.date), | ||
| target_to_plot, | ||
| color="black", | ||
| ) | ||
|
|
@@ -417,11 +430,18 @@ def plot_components_contributions(self, **plt_kwargs: Any) -> plt.Figure: | |
| intercept = az.extract( | ||
| self.fit_result, var_names=["intercept"], combined=False | ||
| ) | ||
| intercept_hdi = np.repeat( | ||
| a=az.hdi(intercept).intercept.data[None, ...], | ||
| repeats=self.X[self.date_column].shape[0], | ||
| axis=0, | ||
| ) | ||
|
|
||
| if intercept.ndim == 2: | ||
| # Intercept has a stationary prior | ||
| intercept_hdi = np.repeat( | ||
| a=az.hdi(intercept).intercept.data[None, ...], | ||
| repeats=self.X[self.date_column].shape[0], | ||
| axis=0, | ||
| ) | ||
| elif intercept.ndim == 3: | ||
| # Intercept has a time-varying prior | ||
| intercept_hdi = az.hdi(intercept).intercept.data | ||
|
|
||
| ax.plot( | ||
| np.asarray(self.X[self.date_column]), | ||
| np.full(len(self.X[self.date_column]), intercept.mean().data), | ||
|
|
@@ -992,6 +1012,7 @@ def label_func(channel): | |
|
|
||
| def legend_title_func(channel): | ||
| return "Legend" | ||
|
|
||
| else: | ||
| nrows = len(channels_to_plot) | ||
| figsize = (12, 4 * len(channels_to_plot)) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
| from pymc_marketing.mmm.base import MMM | ||
| from pymc_marketing.mmm.preprocessing import MaxAbsScaleChannels, MaxAbsScaleTarget | ||
| from pymc_marketing.mmm.transformers import geometric_adstock, logistic_saturation | ||
| from pymc_marketing.mmm.tvp import time_varying_prior | ||
| from pymc_marketing.mmm.utils import ( | ||
| apply_sklearn_transformer_across_date, | ||
| generate_fourier_modes, | ||
|
|
@@ -33,6 +34,8 @@ def __init__( | |
| date_column: str, | ||
| channel_columns: List[str], | ||
| adstock_max_lag: int, | ||
| time_varying_media_effect: bool = False, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey, small comment here... This is a matter of taste but I think having two parameters for the same component is a bit strange. What do you think if we evaluate with strings or tuples? I think it would be better inside for the API There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I quite like this. Maybe this gets too complicated but: time_varying='intercept'
...
time_varying='total_media'
...
time_varying=['intercept', 'total_media']
...
time_varying=['intercept', 'channel1', 'channel2'] could work There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe even time_varying=['intercept', ('channel1', 'channel2')] # now channel1 and channel2 are summed and multiplied by a time varying coef |
||
| time_varying_intercept: bool = False, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be great if we could control the HSGP. Could we add an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe better to just support that in the model config? The chief thing you'd want to control is the covariance function, m and L. But for MMM this can even be simplified to just lengthscale and m, I believe, since optimal L can be calculated from the length of the data and the lengthscale. |
||
| model_config: Optional[Dict] = None, | ||
| sampler_config: Optional[Dict] = None, | ||
| validate_data: bool = True, | ||
|
|
@@ -48,6 +51,12 @@ def __init__( | |
| Column name of the date variable. | ||
| channel_columns : List[str] | ||
| Column names of the media channel variables. | ||
| adstock_max_lag : int | ||
| Number of lags to consider in the adstock transformation. | ||
| time_varying_media_effect : bool, optional | ||
| Whether to consider time-varying media effects, by default False. | ||
| time_varying_intercept : bool, optional | ||
| Whether to consider time-varying intercept, by default False. | ||
| model_config : Dictionary, optional | ||
| dictionary of parameters that initialise model configuration. Class-default defined by the user default_model_config method. | ||
| sampler_config : Dictionary, optional | ||
|
|
@@ -67,6 +76,8 @@ def __init__( | |
| """ | ||
| self.control_columns = control_columns | ||
| self.adstock_max_lag = adstock_max_lag | ||
| self.time_varying_media_effect = time_varying_media_effect | ||
| self.time_varying_intercept = time_varying_intercept | ||
| self.yearly_seasonality = yearly_seasonality | ||
| self.date_column = date_column | ||
| self.validate_data = validate_data | ||
|
|
@@ -91,15 +102,34 @@ def output_var(self): | |
| def _generate_and_preprocess_model_data( # type: ignore | ||
| self, X: Union[pd.DataFrame, pd.Series], y: Union[pd.Series, np.ndarray] | ||
| ) -> None: | ||
| """ | ||
| Applies preprocessing to the data before fitting the model. | ||
| if validate is True, it will check if the data is valid for the model. | ||
| sets self.model_coords based on provided dataset | ||
| """Preprocess data and set model state variables. | ||
|
|
||
| Applies preprocessing to the data before fitting the model. If validate | ||
| is True, it will check if the data is valid for the model. *Only* gets | ||
| called before fitting the model. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| X : Union[pd.DataFrame, pd.Series], shape (n_obs, n_features) | ||
| y : Union[pd.Series, np.ndarray], shape (n_obs,) | ||
|
|
||
| Sets | ||
| ---- | ||
| preprocessed_data : Dict[str, Union[pd.DataFrame, pd.Series]] | ||
| Preprocessed data for the model. | ||
| X : pd.DataFrame | ||
| A filtered version of the input `X`, such that it is guaranteed that | ||
| it contains only the `date_column`, the columns that are specified | ||
| in the `channel_columns` and `control_columns`, and fourier features | ||
| if `yearly_seasonality=True`. | ||
| y : Union[pd.Series, np.ndarray] | ||
| The target variable for the model (as provided). | ||
| _time_index : np.ndarray | ||
| The index of the date column. Used by TVP | ||
| _time_index_mid : int | ||
| The middle index of the date index. Used by TVP. | ||
| _time_resolution: int | ||
| The time resolution of the date index. Used by TVP. | ||
| """ | ||
| date_data = X[self.date_column] | ||
| channel_data = X[self.channel_columns] | ||
|
|
@@ -139,6 +169,11 @@ def _generate_and_preprocess_model_data( # type: ignore | |
| } | ||
| self.X: pd.DataFrame = X_data | ||
| self.y: Union[pd.Series, np.ndarray] = y | ||
| self._time_index = np.arange(0, X.shape[0]) | ||
| self._time_index_mid = X.shape[0] // 2 | ||
| self._time_resolution = ( | ||
| self.X[self.date_column].iloc[1] - self.X[self.date_column].iloc[0] | ||
| ).days | ||
|
|
||
| def _save_input_params(self, idata) -> None: | ||
| """Saves input parameters to the attrs of idata.""" | ||
|
|
@@ -337,9 +372,52 @@ def build_model( | |
| dims="date", | ||
| ) | ||
|
|
||
| intercept = self.intercept_dist( | ||
| name="intercept", **self.model_config["intercept"]["kwargs"] | ||
| ) | ||
| if self.time_varying_intercept or self.time_varying_media_effect: | ||
| time_index = pm.MutableData( | ||
| "time_index", | ||
| self._time_index, | ||
| dims="date", | ||
| ) | ||
|
|
||
| if self.time_varying_intercept: | ||
| tv_multiplier_intercept = time_varying_prior( | ||
| name="tv_multiplier_intercept", | ||
| X=time_index, | ||
| X_mid=self._time_index_mid, | ||
| positive=True, | ||
| m=200, | ||
| L=[self._time_index_mid + 365 / self._time_resolution], | ||
| ls_mu=365 / self._time_resolution * 2, | ||
| ls_sigma=10, | ||
| eta_lam=1, | ||
| dims="date", | ||
| ) | ||
| intercept_base = self.intercept_dist( | ||
| name="intercept_base", **self.model_config["intercept"]["kwargs"] | ||
| ) | ||
| intercept = pm.Deterministic( | ||
| name="intercept", | ||
| var=intercept_base * tv_multiplier_intercept, | ||
| dims="date", | ||
| ) | ||
| else: | ||
| intercept = self.intercept_dist( | ||
| name="intercept", **self.model_config["intercept"]["kwargs"] | ||
| ) | ||
|
|
||
| if self.time_varying_media_effect: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of checking twice, would it be better to check by the end? You load the data and apply the multiplier only if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. Can I refactor this whole function actually? Fingertips itching. |
||
| tv_multiplier_media = time_varying_prior( | ||
| name="tv_multiplier_media", | ||
| X=time_index, | ||
| X_mid=self._time_index_mid, | ||
| positive=True, | ||
| m=200, | ||
| L=[self._time_index_mid + 365 / self._time_resolution], | ||
| ls_mu=365 / self._time_resolution * 2, | ||
ulfaslak marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ls_sigma=10, | ||
| eta_lam=1, | ||
| dims="date", | ||
| ) | ||
|
|
||
| beta_channel = self.beta_channel_dist( | ||
| name="beta_channel", | ||
|
|
@@ -373,11 +451,21 @@ def build_model( | |
| var=logistic_saturation(x=channel_adstock, lam=lam), | ||
| dims=("date", "channel"), | ||
| ) | ||
| channel_contributions = pm.Deterministic( | ||
| name="channel_contributions", | ||
| var=channel_adstock_saturated * beta_channel, | ||
| dims=("date", "channel"), | ||
| ) | ||
|
|
||
| if self.time_varying_media_effect: | ||
| channel_contributions = pm.Deterministic( | ||
| name="channel_contributions", | ||
| var=channel_adstock_saturated | ||
| * beta_channel | ||
| * tv_multiplier_media[:, None], | ||
| dims=("date", "channel"), | ||
| ) | ||
| else: | ||
| channel_contributions = pm.Deterministic( | ||
| name="channel_contributions", | ||
| var=channel_adstock_saturated * beta_channel, | ||
| dims=("date", "channel"), | ||
| ) | ||
ulfaslak marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| mu_var = intercept + channel_contributions.sum(axis=-1) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the HSGP is in dimension "date" from the beginning, I feel that we are first transforming said HSGP to a form that is compatible with It's just an opinion, I think it would be more appropriate the other way and we can avoid a somewhat unnecessary transformation, what do you think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't fully understand. You can just commit the change if you believe it is best. In general this time_varying_prior function supports two dimensions but I'm only using the 1D case here, maybe that's confusing. I should actually add support for individually varying parameters in this PR too... not complicated There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| if ( | ||
|
|
@@ -657,11 +745,16 @@ def identity(x): | |
| if hasattr(self, "fourier_columns"): | ||
| data["fourier_data"] = self._get_fourier_models_data(X) | ||
|
|
||
| if self.time_varying_intercept or self.time_varying_media_effect: | ||
| data["time_index"] = np.arange( | ||
| self._time_index[-1], self._time_index[-1] + X.shape[0] | ||
| ) | ||
|
|
||
| if y is not None: | ||
| if isinstance(y, pd.Series): | ||
| data[ | ||
| "target" | ||
| ] = y.to_numpy() # convert Series to numpy array explicitly | ||
| data["target"] = ( | ||
| y.to_numpy() | ||
| ) # convert Series to numpy array explicitly | ||
| elif isinstance(y, np.ndarray): | ||
| data["target"] = y | ||
| else: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| from typing import Optional | ||
|
|
||
| import pymc as pm | ||
| from pymc_marketing.mmm.utils import softplus | ||
|
|
||
|
|
||
| def time_varying_prior( | ||
| name: str, | ||
| X: pm.Deterministic, | ||
| X_mid: int | float, | ||
| positive: bool = False, | ||
| dims: Optional[tuple[str, str] | str] = None, | ||
| m: int = 40, | ||
| L: int = 100, | ||
| eta_lam: float = 1, | ||
| ls_mu: float = 5, | ||
| ls_sigma: float = 5, | ||
| cov_func: Optional[pm.gp.cov.Prod] = None, | ||
| model: Optional[pm.Model] = None, | ||
| ) -> pm.Deterministic: | ||
| """Time varying prior, based the Hilbert Space Gaussian Process (HSGP). | ||
| Parameters | ||
| ---------- | ||
| name : str | ||
| Name of the prior. | ||
| X : 1d array-like of int or float | ||
| Time points. | ||
| X_mid : int or float | ||
| Midpoint of the time points. | ||
| positive : bool | ||
| Whether the prior should be positive. | ||
| dims : tuple of str or str | ||
| Dimensions of the prior. | ||
| m : int | ||
| Number of basis functions. | ||
| L : int | ||
| Number of quadrature points. | ||
| eta_lam : float | ||
| Exponential prior for the variance. | ||
| ls_mu : float | ||
| Mean of the inverse gamma prior for the lengthscale. | ||
| ls_sigma : float | ||
| Standard deviation of the inverse gamma prior for the lengthscale. | ||
| cov_func : pm.gp.cov.Prod | ||
| Covariance function. | ||
| model : pm.Model | ||
| PyMC model. | ||
| Returns | ||
| ------- | ||
| pm.Deterministic | ||
| Time-varying prior. | ||
| """ # noqa: W605 | ||
| if cov_func is None: | ||
| eta = pm.Exponential(f"eta_{name}", lam=eta_lam) | ||
| ls = pm.InverseGamma(f"ls_{name}", mu=ls_mu, sigma=ls_sigma) | ||
| cov_func = eta**2 * pm.gp.cov.Matern52(1, ls=ls) | ||
ulfaslak marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| with pm.modelcontext(model) as model: | ||
| if type(dims) is tuple: | ||
| n_columns = len(model.coords[dims[1]]) | ||
| hsgp_size = (n_columns, m) | ||
| else: | ||
| hsgp_size = m | ||
| gp = pm.gp.HSGP(m=[m], L=[L], cov_func=cov_func) | ||
| phi, sqrt_psd = gp.prior_linearized(Xs=X[:, None] - X_mid) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note for future review: We need to be careful we center the data with respect to the training set (even when we are doing out of sample prediction) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See docstrings in https://github.com/pymc-devs/pymc/blob/main/pymc/gp/hsgp_approx.py#L243 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you elaborate? What is the concern? |
||
| hsgp_coefs = pm.Normal(f"_hsgp_coefs_{name}", size=hsgp_size) | ||
| f = phi @ (hsgp_coefs * sqrt_psd).T | ||
| if positive: | ||
| f = softplus(f) | ||
| return pm.Deterministic(name, f, dims=dims) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be great to create a plot for the varying parameters, showing the recovered latent pattern which is affecting the channels and/or the intercept, what do you think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree. Will add! |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the reason for this change? The date column can have a different name.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This enables OOS posterior predictive.
`self.X` does not get updated upon `self._data_setter(X_new)`. So if the posterior predictive is for new OOS data, then `self.X[self.date_column]` will still be the in-sample dates.