-
Notifications
You must be signed in to change notification settings - Fork 326
Time varying intercept #628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 40 commits
faaba0f
3ce5ac4
dac69da
ec316b4
7e8aee7
b0aaa2d
16f9414
8ab0532
c06c06d
b70cd06
9b1287b
e01f9ac
59708cf
83e7103
196d720
c77ddce
f25bc6e
1b3e73e
32f11e1
b35df87
293e752
910d223
15308ab
19d13d8
0fad32b
973a921
95c7ee8
4158768
843ec21
1c90255
d5e1699
74c09c0
a6c5972
374303c
3ac1f6a
0bfbbe4
665d1d2
6defce8
5f5be67
33f8f7b
7677015
e0b8ad6
3ac0585
3d207b2
ba1de7b
85cf5aa
0f89bd4
23ea9ec
d4688be
d3ced36
f8c0b3c
3f88589
f0e59e6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| DAYS_IN_YEAR = 365.25 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,13 +14,15 @@ | |
| from pytensor.tensor import TensorVariable | ||
| from xarray import DataArray, Dataset | ||
|
|
||
| from pymc_marketing.constants import DAYS_IN_YEAR | ||
| from pymc_marketing.mmm.base import MMM | ||
| from pymc_marketing.mmm.lift_test import ( | ||
| add_logistic_empirical_lift_measurements_to_likelihood, | ||
| scale_lift_measurements, | ||
| ) | ||
| from pymc_marketing.mmm.preprocessing import MaxAbsScaleChannels, MaxAbsScaleTarget | ||
| from pymc_marketing.mmm.transformers import geometric_adstock, logistic_saturation | ||
| from pymc_marketing.mmm.tvp import create_time_varying_intercept, infer_time_index | ||
| from pymc_marketing.mmm.utils import ( | ||
| apply_sklearn_transformer_across_dim, | ||
| create_new_spend_data, | ||
|
|
@@ -47,6 +49,7 @@ def __init__( | |
| date_column: str, | ||
| channel_columns: list[str], | ||
| adstock_max_lag: int, | ||
| time_varying_intercept: bool = False, | ||
| model_config: dict | None = None, | ||
| sampler_config: dict | None = None, | ||
| validate_data: bool = True, | ||
|
|
@@ -62,6 +65,10 @@ def __init__( | |
| Column name of the date variable. | ||
| channel_columns : List[str] | ||
| Column names of the media channel variables. | ||
| adstock_max_lag : int | ||
| Number of lags to consider in the adstock transformation. | ||
| time_varying_intercept : bool, optional | ||
| Whether to consider time-varying intercept, by default False. | ||
| model_config : Dictionary, optional | ||
| dictionary of parameters that initialise model configuration. | ||
| Class-default defined by the user default_model_config method. | ||
|
|
@@ -79,6 +86,7 @@ def __init__( | |
| """ | ||
| self.control_columns = control_columns | ||
| self.adstock_max_lag = adstock_max_lag | ||
| self.time_varying_intercept = time_varying_intercept | ||
| self.yearly_seasonality = yearly_seasonality | ||
| self.date_column = date_column | ||
| self.validate_data = validate_data | ||
|
|
@@ -112,6 +120,24 @@ def _generate_and_preprocess_model_data( # type: ignore | |
| ---------- | ||
| X : Union[pd.DataFrame, pd.Series], shape (n_obs, n_features) | ||
| y : Union[pd.Series, np.ndarray], shape (n_obs,) | ||
|
|
||
| Sets | ||
| ---- | ||
| preprocessed_data : Dict[str, Union[pd.DataFrame, pd.Series]] | ||
| Preprocessed data for the model. | ||
| X : pd.DataFrame | ||
| A filtered version of the input `X`, such that it is guaranteed that | ||
| it contains only the `date_column`, the columns that are specified | ||
| in the `channel_columns` and `control_columns`, and fourier features | ||
| if `yearly_seasonality=True`. | ||
| y : Union[pd.Series, np.ndarray] | ||
| The target variable for the model (as provided). | ||
| _time_index : np.ndarray | ||
| The index of the date column. Used by TVP | ||
| _time_index_mid : int | ||
| The middle index of the date index. Used by TVP. | ||
| _time_resolution: int | ||
| The time resolution of the date index. Used by TVP. | ||
| """ | ||
| date_data = X[self.date_column] | ||
| channel_data = X[self.channel_columns] | ||
|
|
@@ -152,6 +178,13 @@ def _generate_and_preprocess_model_data( # type: ignore | |
| self.X: pd.DataFrame = X_data | ||
| self.y: pd.Series | np.ndarray = y | ||
|
|
||
| if self.time_varying_intercept: | ||
| self._time_index = np.arange(0, X.shape[0]) | ||
| self._time_index_mid = X.shape[0] // 2 | ||
| self._time_resolution = ( | ||
| self.X[self.date_column].iloc[1] - self.X[self.date_column].iloc[0] | ||
| ).days | ||
|
|
||
| def _save_input_params(self, idata) -> None: | ||
| """Saves input parameters to the attrs of idata.""" | ||
| idata.attrs["date_column"] = json.dumps(self.date_column) | ||
|
|
@@ -355,9 +388,23 @@ def build_model( | |
| dims="date", | ||
| ) | ||
|
|
||
| intercept = self.intercept_dist( | ||
| name="intercept", **self.model_config["intercept"]["kwargs"] | ||
| ) | ||
| if self.time_varying_intercept: | ||
ulfaslak marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| time_index = pm.Data( | ||
| "time_index", | ||
| self._time_index, | ||
| dims="date", | ||
| ) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe time_index = pm.MutableData(
"time_index",
value=np.arange(self.x_channel_data.shape[0]),
dims="date",
) |
||
| intercept = create_time_varying_intercept( | ||
| time_index, | ||
| self._time_index_mid, | ||
| self._time_resolution, | ||
| self.intercept_dist, | ||
| self.model_config, | ||
| ) | ||
| else: | ||
| intercept = self.intercept_dist( | ||
| name="intercept", **self.model_config["intercept"]["kwargs"] | ||
| ) | ||
|
|
||
| beta_channel = self.beta_channel_dist( | ||
| name="beta_channel", | ||
|
|
@@ -391,9 +438,11 @@ def build_model( | |
| var=logistic_saturation(x=channel_adstock, lam=lam), | ||
| dims=("date", "channel"), | ||
| ) | ||
|
|
||
| channel_contributions_var = channel_adstock_saturated * beta_channel | ||
| channel_contributions = pm.Deterministic( | ||
| name="channel_contributions", | ||
| var=channel_adstock_saturated * beta_channel, | ||
| var=channel_contributions_var, | ||
| dims=("date", "channel"), | ||
| ) | ||
|
|
||
|
|
@@ -468,7 +517,10 @@ def build_model( | |
| @property | ||
| def default_model_config(self) -> dict: | ||
| return { | ||
| "intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}}, | ||
| "intercept": { | ||
| "dist": "Normal", | ||
| "kwargs": {"mu": 0, "sigma": 2}, | ||
| }, | ||
| "beta_channel": {"dist": "HalfNormal", "kwargs": {"sigma": 2}}, | ||
| "alpha": {"dist": "Beta", "kwargs": {"alpha": 1, "beta": 3}}, | ||
| "lam": {"dist": "Gamma", "kwargs": {"alpha": 3, "beta": 1}}, | ||
|
|
@@ -480,6 +532,14 @@ def default_model_config(self) -> dict: | |
| }, | ||
| "gamma_control": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}}, | ||
| "gamma_fourier": {"dist": "Laplace", "kwargs": {"mu": 0, "b": 1}}, | ||
| "intercept_tvp_kwargs": { | ||
| "m": 200, | ||
| "L": None, | ||
| "eta_lam": 1, | ||
| "ls_mu": None, | ||
| "ls_sigma": 10, | ||
| "cov_func": None, | ||
|
Comment on lines
+536
to
+541
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do we expect to happen with these There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's are checks starting on like 399, which set them if they are None. The reason that we wan't directly set them here is that the best defaults are estimated from the data, and we can't know here. |
||
| }, | ||
| } | ||
|
|
||
| def _get_fourier_models_data(self, X) -> pd.DataFrame: | ||
|
|
@@ -494,7 +554,9 @@ def _get_fourier_models_data(self, X) -> pd.DataFrame: | |
| date_data: pd.Series = pd.to_datetime( | ||
| arg=X[self.date_column], format="%Y-%m-%d" | ||
| ) | ||
| periods: npt.NDArray[np.float_] = date_data.dt.dayofyear.to_numpy() / 365.25 | ||
| periods: npt.NDArray[np.float_] = ( | ||
| date_data.dt.dayofyear.to_numpy() / DAYS_IN_YEAR | ||
| ) | ||
| return generate_fourier_modes( | ||
| periods=periods, | ||
| n_order=self.yearly_seasonality, | ||
|
|
@@ -678,6 +740,11 @@ def identity(x): | |
| if hasattr(self, "fourier_columns"): | ||
| data["fourier_data"] = self._get_fourier_models_data(X) | ||
|
|
||
| if self.time_varying_intercept: | ||
| data["time_index"] = infer_time_index( | ||
| X[self.date_column], self.X[self.date_column], self._time_resolution | ||
| ) | ||
|
|
||
| if y is not None: | ||
| if isinstance(y, pd.Series): | ||
| data["target"] = ( | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.