Skip to content

Commit 236f39d

Browse files
authored
allow for arbitrary additive effects in MVITS model (#1861)
* allow specifying date dim name and change name * add docstring example * implement arbitrary mu_effects * add simple run through test
1 parent c0fcde3 commit 236f39d

File tree

4 files changed

+128
-40
lines changed

4 files changed

+128
-40
lines changed

pymc_marketing/customer_choice/mv_its.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pymc_extras.prior import Prior
2727
from xarray import DataArray
2828

29+
from pymc_marketing.mmm.additive_effect import MuEffect
2930
from pymc_marketing.model_builder import ModelBuilder, create_idata_accessor
3031
from pymc_marketing.model_config import parse_model_config
3132

@@ -69,6 +70,10 @@ def __init__(
6970
model_config = model_config or {}
7071
model_config = parse_model_config(model_config)
7172

73+
self.mu_effects: list[MuEffect] = []
74+
# extras dims for the likelihood
75+
self.dims = ("existing_product",)
76+
7277
super().__init__(model_config=model_config, sampler_config=sampler_config)
7378

7479
self._distribution_checks()
@@ -264,7 +269,7 @@ def build_model(
264269
"""
265270
self._generate_and_preprocess_model_data(X, y) # type: ignore
266271

267-
with pm.Model(coords=self.coords) as model:
272+
with pm.Model(coords=self.coords) as self.model:
268273
# data
269274
_existing_sales = pm.Data(
270275
"existing_sales",
@@ -277,6 +282,9 @@ def build_model(
277282
dims="time",
278283
)
279284

285+
for mu_effect in self.mu_effects:
286+
mu_effect.create_data(self)
287+
280288
# priors
281289
intercept = self.model_config["intercept"].create_variable(name="intercept")
282290

@@ -301,11 +309,13 @@ def build_model(
301309
pm.Deterministic("new sales", beta_all[-1])
302310

303311
# expectation
304-
mu = pm.Deterministic(
305-
"mu",
306-
intercept[None, :] - y[:, None] * beta[None, :],
307-
dims=("time", "existing_product"),
308-
)
312+
313+
mu = intercept[None, :] - y[:, None] * beta[None, :]
314+
315+
for mu_effect in self.mu_effects:
316+
mu += mu_effect.create_effect(self)
317+
318+
mu = pm.Deterministic("mu", mu, dims=("time", "existing_product"))
309319

310320
# likelihood
311321
self.model_config["likelihood"].create_likelihood_variable(
@@ -314,8 +324,6 @@ def build_model(
314324
observed=_existing_sales,
315325
)
316326

317-
self.model = model
318-
319327
def _data_setter(
320328
self,
321329
X: np.ndarray | pd.DataFrame,

pymc_marketing/customer_choice/synthetic_data.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,33 @@ def generate_saturated_data(
5757
data: pd.DataFrame
5858
The synthetic data generated.
5959
60+
61+
Examples
62+
--------
63+
Generate some synthetic data for the MVITS model:
64+
65+
.. code-block:: python
66+
67+
import numpy as np
68+
69+
from pymc_marketing.customer_choice import generate_saturated_data
70+
71+
seed = sum(map(ord, "Saturated Market Data"))
72+
rng = np.random.default_rng(seed)
73+
74+
scenario = {
75+
"total_sales_mu": 1_000,
76+
"total_sales_sigma": 5,
77+
"treatment_time": 40,
78+
"n_observations": 100,
79+
"market_shares_before": [[0.7, 0.3, 0]],
80+
"market_shares_after": [[0.65, 0.25, 0.1]],
81+
"market_share_labels": ["competitor", "own", "new"],
82+
"random_seed": rng,
83+
}
84+
85+
data = generate_saturated_data(**scenario)
86+
6087
"""
6188
rng: np.random.Generator = (
6289
random_seed

pymc_marketing/mmm/additive_effect.py

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from pymc_marketing.mmm.utils import create_index
2929

3030

31-
class MMM(Protocol):
31+
class Model(Protocol):
3232
"""Protocol MMM."""
3333

3434
@property
@@ -43,30 +43,34 @@ def model(self) -> pm.Model:
4343
class MuEffect(Protocol):
4444
"""Protocol for arbitrary additive mu effect."""
4545

46-
def create_data(self, mmm: MMM) -> None:
46+
def create_data(self, mmm: Model) -> None:
4747
"""Create the required data in the model."""
4848

49-
def create_effect(self, mmm: MMM) -> pt.TensorVariable:
49+
def create_effect(self, mmm: Model) -> pt.TensorVariable:
5050
"""Create the additive effect in the model."""
5151

52-
def set_data(self, mmm: MMM, model: pm.Model, X: xr.Dataset) -> None:
52+
def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None:
5353
"""Set the data for new predictions."""
5454

5555

5656
class FourierEffect:
5757
"""Fourier seasonality additive effect for MMM."""
5858

59-
def __init__(self, fourier: FourierBase):
59+
def __init__(self, fourier: FourierBase, date_dim_name: str = "date"):
6060
"""Initialize the Fourier effect.
6161
6262
Parameters
6363
----------
6464
fourier : FourierBase
65+
The FourierBase instance to use for the effect.
66+
date_dim_name : str, optional
67+
The name of the date dimension in the model, by default "date".
6568
6669
"""
6770
self.fourier = fourier
71+
self.date_dim_name: str = date_dim_name
6872

69-
def create_data(self, mmm: MMM) -> None:
73+
def create_data(self, mmm: Model) -> None:
7074
"""Create the required data in the model.
7175
7276
Parameters
@@ -77,16 +81,16 @@ def create_data(self, mmm: MMM) -> None:
7781
model = mmm.model
7882

7983
# Get dates from model coordinates
80-
dates = pd.to_datetime(model.coords["date"])
84+
dates = pd.to_datetime(model.coords[self.date_dim_name])
8185

8286
# Add weekday data to the model
8387
pm.Data(
8488
f"{self.fourier.prefix}_day",
8589
self.fourier._get_days_in_period(dates).to_numpy(),
86-
dims="date",
90+
dims=self.date_dim_name,
8791
)
8892

89-
def create_effect(self, mmm: MMM) -> pt.TensorVariable:
93+
def create_effect(self, mmm: Model) -> pt.TensorVariable:
9094
"""Create the Fourier effect in the model.
9195
9296
Parameters
@@ -107,18 +111,18 @@ def create_effect(self, mmm: MMM) -> pt.TensorVariable:
107111

108112
# Create a deterministic variable for the effect
109113
dims = (dim for dim in mmm.dims if dim in self.fourier.prior.dims)
110-
fourier_dims = ("date", *dims)
114+
fourier_dims = (self.date_dim_name, *dims)
111115
fourier_effect_det = pm.Deterministic(
112116
f"{self.fourier.prefix}_effect",
113117
fourier_effect,
114118
dims=fourier_dims,
115119
)
116120

117121
# Handle dimensions for the MMM model
118-
dim_handler = create_dim_handler(("date", *mmm.dims))
122+
dim_handler = create_dim_handler((self.date_dim_name, *mmm.dims))
119123
return dim_handler(fourier_effect_det, fourier_dims)
120124

121-
def set_data(self, mmm: MMM, model: pm.Model, X: xr.Dataset) -> None:
125+
def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None:
122126
"""Set the data for new predictions.
123127
124128
Parameters
@@ -131,7 +135,7 @@ def set_data(self, mmm: MMM, model: pm.Model, X: xr.Dataset) -> None:
131135
The dataset for prediction
132136
"""
133137
# Get dates from the new dataset
134-
new_dates = pd.to_datetime(model.coords["date"])
138+
new_dates = pd.to_datetime(model.coords[self.date_dim_name])
135139

136140
# Update the data
137141
new_data = {
@@ -243,12 +247,13 @@ class MockMMM:
243247
244248
"""
245249

246-
def __init__(self, trend: LinearTrend, prefix: str):
250+
def __init__(self, trend: LinearTrend, prefix: str, date_dim_name: str = "date"):
247251
self.trend = trend
248252
self.prefix = prefix
249253
self.linear_trend_first_date: pd.Timestamp
254+
self.date_dim_name: str = date_dim_name
250255

251-
def create_data(self, mmm: MMM) -> None:
256+
def create_data(self, mmm: Model) -> None:
252257
"""Create the required data in the model.
253258
254259
Parameters
@@ -259,13 +264,13 @@ def create_data(self, mmm: MMM) -> None:
259264
model: pm.Model = mmm.model
260265

261266
# Create time index data (normalized between 0 and 1)
262-
dates = pd.to_datetime(model.coords["date"])
267+
dates = pd.to_datetime(model.coords[self.date_dim_name])
263268
self.linear_trend_first_date = dates[0]
264269
t = (dates - self.linear_trend_first_date).days.astype(float)
265270

266-
pm.Data(f"{self.prefix}_t", t, dims="date")
271+
pm.Data(f"{self.prefix}_t", t, dims=self.date_dim_name)
267272

268-
def create_effect(self, mmm: MMM) -> pt.TensorVariable:
273+
def create_effect(self, mmm: Model) -> pt.TensorVariable:
269274
"""Create the trend effect in the model.
270275
271276
Parameters
@@ -289,19 +294,22 @@ def create_effect(self, mmm: MMM) -> pt.TensorVariable:
289294
trend_effect = self.trend.apply(t)
290295

291296
# Create deterministic for the trend effect
292-
trend_dims = ("date", *self.trend.dims) # type: ignore
293-
trend_non_broadcastable_dims = ("date", *self.trend.non_broadcastable_dims)
297+
trend_dims = (self.date_dim_name, *self.trend.dims) # type: ignore
298+
trend_non_broadcastable_dims = (
299+
self.date_dim_name,
300+
*self.trend.non_broadcastable_dims,
301+
)
294302
trend_effect = pm.Deterministic(
295303
f"{self.prefix}_effect_contribution",
296304
trend_effect[create_index(trend_dims, trend_non_broadcastable_dims)],
297305
dims=trend_non_broadcastable_dims,
298306
)
299307

300308
# Return the trend effect
301-
dim_handler = create_dim_handler(("date", *mmm.dims))
309+
dim_handler = create_dim_handler((self.date_dim_name, *mmm.dims))
302310
return dim_handler(trend_effect, trend_non_broadcastable_dims)
303311

304-
def set_data(self, mmm: MMM, model: pm.Model, X: xr.Dataset) -> None:
312+
def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None:
305313
"""Set the data for new predictions.
306314
307315
Parameters
@@ -314,7 +322,7 @@ def set_data(self, mmm: MMM, model: pm.Model, X: xr.Dataset) -> None:
314322
The dataset for prediction.
315323
"""
316324
# Create normalized time index for new data
317-
new_dates = pd.to_datetime(model.coords["date"])
325+
new_dates = pd.to_datetime(model.coords[self.date_dim_name])
318326
t = (new_dates - self.linear_trend_first_date).days.astype(float)
319327

320328
# Update the data
@@ -338,13 +346,16 @@ class EventAdditiveEffect(BaseModel):
338346
reference_date : str
339347
The arbitrary reference date to calculate distance from events in days. Default
340348
is "2025-01-01".
349+
date_dim_name : str
350+
The name of the date dimension in the model. Default is "date".
341351
342352
"""
343353

344354
df_events: InstanceOf[pd.DataFrame]
345355
prefix: str
346356
effect: EventEffect
347357
reference_date: str = "2025-01-01"
358+
date_dim_name: str = "date"
348359

349360
def model_post_init(self, context: Any, /) -> None:
350361
"""Post initialization of the model."""
@@ -365,7 +376,7 @@ def end_dates(self) -> pd.Series:
365376
"""The end dates of the events."""
366377
return pd.to_datetime(self.df_events["end_date"])
367378

368-
def create_data(self, mmm: MMM) -> None:
379+
def create_data(self, mmm: Model) -> None:
369380
"""Create the required data in the model.
370381
371382
Parameters
@@ -376,15 +387,15 @@ def create_data(self, mmm: MMM) -> None:
376387
"""
377388
model: pm.Model = mmm.model
378389

379-
model_dates = pd.to_datetime(model.coords["date"])
390+
model_dates = pd.to_datetime(model.coords[self.date_dim_name])
380391

381392
model.add_coord(self.prefix, self.df_events["name"].to_numpy())
382393

383394
if "days" not in model:
384395
pm.Data(
385396
"days",
386397
days_from_reference(model_dates, self.reference_date),
387-
dims="date",
398+
dims=self.date_dim_name,
388399
)
389400

390401
pm.Data(
@@ -398,7 +409,7 @@ def create_data(self, mmm: MMM) -> None:
398409
dims=self.prefix,
399410
)
400411

401-
def create_effect(self, mmm: MMM) -> pt.TensorVariable:
412+
def create_effect(self, mmm: Model) -> pt.TensorVariable:
402413
"""Create the event effect in the model.
403414
404415
Parameters
@@ -430,15 +441,15 @@ def create_basis_matrix(start_ref, end_ref):
430441
total_effect = pm.Deterministic(
431442
f"{self.prefix}_total_effect",
432443
event_effect.sum(axis=1),
433-
dims="date",
444+
dims=self.date_dim_name,
434445
)
435446

436-
dim_handler = create_dim_handler(("date", *mmm.dims))
437-
return dim_handler(total_effect, "date")
447+
dim_handler = create_dim_handler((self.date_dim_name, *mmm.dims))
448+
return dim_handler(total_effect, self.date_dim_name)
438449

439-
def set_data(self, mmm: MMM, model: pm.Model, X: xr.Dataset) -> None:
450+
def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None:
440451
"""Set the data for new predictions."""
441-
new_dates = pd.to_datetime(model.coords["date"])
452+
new_dates = pd.to_datetime(model.coords[self.date_dim_name])
442453

443454
new_data = {
444455
"days": days_from_reference(new_dates, self.reference_date),

tests/customer_choice/test_mv_its.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
generate_saturated_data,
2828
generate_unsaturated_data,
2929
)
30+
from pymc_marketing.mmm.additive_effect import FourierEffect
31+
from pymc_marketing.mmm.fourier import YearlyFourier
3032

3133
seed = sum(map(ord, "CustomerChoice"))
3234
rng = np.random.default_rng(seed)
@@ -286,3 +288,43 @@ def test_calculate_counterfactual_raises() -> None:
286288
match = "Call the 'fit' method first."
287289
with pytest.raises(RuntimeError, match=match):
288290
model.calculate_counterfactual()
291+
292+
293+
def test_support_for_mu_effects(saturated_data, mock_pymc_sample) -> None:
294+
model = MVITS(existing_sales=["competitor", "own"])
295+
296+
n_order = 5
297+
fourier = YearlyFourier(
298+
n_order=n_order,
299+
prior=Prior(
300+
"Laplace",
301+
mu=0,
302+
b=1,
303+
dims=("fourier", "existing_product"),
304+
),
305+
)
306+
effect = FourierEffect(fourier=fourier, date_dim_name="time")
307+
308+
model.mu_effects.append(effect)
309+
model.sample(
310+
saturated_data.loc[:, ["competitor", "own"]],
311+
saturated_data["new"],
312+
random_seed=rng,
313+
sample_prior_predictive_kwargs={"samples": 10},
314+
)
315+
316+
n_time = len(saturated_data)
317+
n_existing_products = 2
318+
posterior_size = {"chain": 1, "draw": 10}
319+
320+
assert model.posterior["fourier_effect"].sizes == {
321+
**posterior_size,
322+
"time": n_time,
323+
"existing_product": n_existing_products,
324+
}
325+
assert model.posterior["fourier_beta"].sizes == {
326+
**posterior_size,
327+
"fourier": n_order * 2,
328+
"existing_product": n_existing_products,
329+
}
330+
assert model.posterior["fourier"].sizes == {"fourier": n_order * 2}

0 commit comments

Comments
 (0)