28
28
from pymc_marketing .mmm .utils import create_index
29
29
30
30
31
- class MMM (Protocol ):
31
+ class Model (Protocol ):
32
32
"""Protocol MMM."""
33
33
34
34
@property
@@ -43,30 +43,34 @@ def model(self) -> pm.Model:
43
43
class MuEffect (Protocol ):
44
44
"""Protocol for arbitrary additive mu effect."""
45
45
46
- def create_data (self , mmm : MMM ) -> None :
46
+ def create_data (self , mmm : Model ) -> None :
47
47
"""Create the required data in the model."""
48
48
49
- def create_effect (self , mmm : MMM ) -> pt .TensorVariable :
49
+ def create_effect (self , mmm : Model ) -> pt .TensorVariable :
50
50
"""Create the additive effect in the model."""
51
51
52
- def set_data (self , mmm : MMM , model : pm .Model , X : xr .Dataset ) -> None :
52
+ def set_data (self , mmm : Model , model : pm .Model , X : xr .Dataset ) -> None :
53
53
"""Set the data for new predictions."""
54
54
55
55
56
56
class FourierEffect :
57
57
"""Fourier seasonality additive effect for MMM."""
58
58
59
- def __init__ (self , fourier : FourierBase ):
59
+ def __init__ (self , fourier : FourierBase , date_dim_name : str = "date" ):
60
60
"""Initialize the Fourier effect.
61
61
62
62
Parameters
63
63
----------
64
64
fourier : FourierBase
65
+ The FourierBase instance to use for the effect.
66
+ date_dim_name : str, optional
67
+ The name of the date dimension in the model, by default "date".
65
68
66
69
"""
67
70
self .fourier = fourier
71
+ self .date_dim_name : str = date_dim_name
68
72
69
- def create_data (self , mmm : MMM ) -> None :
73
+ def create_data (self , mmm : Model ) -> None :
70
74
"""Create the required data in the model.
71
75
72
76
Parameters
@@ -77,16 +81,16 @@ def create_data(self, mmm: MMM) -> None:
77
81
model = mmm .model
78
82
79
83
# Get dates from model coordinates
80
- dates = pd .to_datetime (model .coords ["date" ])
84
+ dates = pd .to_datetime (model .coords [self . date_dim_name ])
81
85
82
86
# Add weekday data to the model
83
87
pm .Data (
84
88
f"{ self .fourier .prefix } _day" ,
85
89
self .fourier ._get_days_in_period (dates ).to_numpy (),
86
- dims = "date" ,
90
+ dims = self . date_dim_name ,
87
91
)
88
92
89
- def create_effect (self , mmm : MMM ) -> pt .TensorVariable :
93
+ def create_effect (self , mmm : Model ) -> pt .TensorVariable :
90
94
"""Create the Fourier effect in the model.
91
95
92
96
Parameters
@@ -107,18 +111,18 @@ def create_effect(self, mmm: MMM) -> pt.TensorVariable:
107
111
108
112
# Create a deterministic variable for the effect
109
113
dims = (dim for dim in mmm .dims if dim in self .fourier .prior .dims )
110
- fourier_dims = ("date" , * dims )
114
+ fourier_dims = (self . date_dim_name , * dims )
111
115
fourier_effect_det = pm .Deterministic (
112
116
f"{ self .fourier .prefix } _effect" ,
113
117
fourier_effect ,
114
118
dims = fourier_dims ,
115
119
)
116
120
117
121
# Handle dimensions for the MMM model
118
- dim_handler = create_dim_handler (("date" , * mmm .dims ))
122
+ dim_handler = create_dim_handler ((self . date_dim_name , * mmm .dims ))
119
123
return dim_handler (fourier_effect_det , fourier_dims )
120
124
121
- def set_data (self , mmm : MMM , model : pm .Model , X : xr .Dataset ) -> None :
125
+ def set_data (self , mmm : Model , model : pm .Model , X : xr .Dataset ) -> None :
122
126
"""Set the data for new predictions.
123
127
124
128
Parameters
@@ -131,7 +135,7 @@ def set_data(self, mmm: MMM, model: pm.Model, X: xr.Dataset) -> None:
131
135
The dataset for prediction
132
136
"""
133
137
# Get dates from the new dataset
134
- new_dates = pd .to_datetime (model .coords ["date" ])
138
+ new_dates = pd .to_datetime (model .coords [self . date_dim_name ])
135
139
136
140
# Update the data
137
141
new_data = {
@@ -243,12 +247,13 @@ class MockMMM:
243
247
244
248
"""
245
249
246
- def __init__ (self , trend : LinearTrend , prefix : str ):
250
+ def __init__ (self , trend : LinearTrend , prefix : str , date_dim_name : str = "date" ):
247
251
self .trend = trend
248
252
self .prefix = prefix
249
253
self .linear_trend_first_date : pd .Timestamp
254
+ self .date_dim_name : str = date_dim_name
250
255
251
- def create_data (self , mmm : MMM ) -> None :
256
+ def create_data (self , mmm : Model ) -> None :
252
257
"""Create the required data in the model.
253
258
254
259
Parameters
@@ -259,13 +264,13 @@ def create_data(self, mmm: MMM) -> None:
259
264
model : pm .Model = mmm .model
260
265
261
266
# Create time index data (normalized between 0 and 1)
262
- dates = pd .to_datetime (model .coords ["date" ])
267
+ dates = pd .to_datetime (model .coords [self . date_dim_name ])
263
268
self .linear_trend_first_date = dates [0 ]
264
269
t = (dates - self .linear_trend_first_date ).days .astype (float )
265
270
266
- pm .Data (f"{ self .prefix } _t" , t , dims = "date" )
271
+ pm .Data (f"{ self .prefix } _t" , t , dims = self . date_dim_name )
267
272
268
- def create_effect (self , mmm : MMM ) -> pt .TensorVariable :
273
+ def create_effect (self , mmm : Model ) -> pt .TensorVariable :
269
274
"""Create the trend effect in the model.
270
275
271
276
Parameters
@@ -289,19 +294,22 @@ def create_effect(self, mmm: MMM) -> pt.TensorVariable:
289
294
trend_effect = self .trend .apply (t )
290
295
291
296
# Create deterministic for the trend effect
292
- trend_dims = ("date" , * self .trend .dims ) # type: ignore
293
- trend_non_broadcastable_dims = ("date" , * self .trend .non_broadcastable_dims )
297
+ trend_dims = (self .date_dim_name , * self .trend .dims ) # type: ignore
298
+ trend_non_broadcastable_dims = (
299
+ self .date_dim_name ,
300
+ * self .trend .non_broadcastable_dims ,
301
+ )
294
302
trend_effect = pm .Deterministic (
295
303
f"{ self .prefix } _effect_contribution" ,
296
304
trend_effect [create_index (trend_dims , trend_non_broadcastable_dims )],
297
305
dims = trend_non_broadcastable_dims ,
298
306
)
299
307
300
308
# Return the trend effect
301
- dim_handler = create_dim_handler (("date" , * mmm .dims ))
309
+ dim_handler = create_dim_handler ((self . date_dim_name , * mmm .dims ))
302
310
return dim_handler (trend_effect , trend_non_broadcastable_dims )
303
311
304
- def set_data (self , mmm : MMM , model : pm .Model , X : xr .Dataset ) -> None :
312
+ def set_data (self , mmm : Model , model : pm .Model , X : xr .Dataset ) -> None :
305
313
"""Set the data for new predictions.
306
314
307
315
Parameters
@@ -314,7 +322,7 @@ def set_data(self, mmm: MMM, model: pm.Model, X: xr.Dataset) -> None:
314
322
The dataset for prediction.
315
323
"""
316
324
# Create normalized time index for new data
317
- new_dates = pd .to_datetime (model .coords ["date" ])
325
+ new_dates = pd .to_datetime (model .coords [self . date_dim_name ])
318
326
t = (new_dates - self .linear_trend_first_date ).days .astype (float )
319
327
320
328
# Update the data
@@ -338,13 +346,16 @@ class EventAdditiveEffect(BaseModel):
338
346
reference_date : str
339
347
The arbitrary reference date to calculate distance from events in days. Default
340
348
is "2025-01-01".
349
+ date_dim_name : str
350
+ The name of the date dimension in the model. Default is "date".
341
351
342
352
"""
343
353
344
354
df_events : InstanceOf [pd .DataFrame ]
345
355
prefix : str
346
356
effect : EventEffect
347
357
reference_date : str = "2025-01-01"
358
+ date_dim_name : str = "date"
348
359
349
360
def model_post_init (self , context : Any , / ) -> None :
350
361
"""Post initialization of the model."""
@@ -365,7 +376,7 @@ def end_dates(self) -> pd.Series:
365
376
"""The end dates of the events."""
366
377
return pd .to_datetime (self .df_events ["end_date" ])
367
378
368
- def create_data (self , mmm : MMM ) -> None :
379
+ def create_data (self , mmm : Model ) -> None :
369
380
"""Create the required data in the model.
370
381
371
382
Parameters
@@ -376,15 +387,15 @@ def create_data(self, mmm: MMM) -> None:
376
387
"""
377
388
model : pm .Model = mmm .model
378
389
379
- model_dates = pd .to_datetime (model .coords ["date" ])
390
+ model_dates = pd .to_datetime (model .coords [self . date_dim_name ])
380
391
381
392
model .add_coord (self .prefix , self .df_events ["name" ].to_numpy ())
382
393
383
394
if "days" not in model :
384
395
pm .Data (
385
396
"days" ,
386
397
days_from_reference (model_dates , self .reference_date ),
387
- dims = "date" ,
398
+ dims = self . date_dim_name ,
388
399
)
389
400
390
401
pm .Data (
@@ -398,7 +409,7 @@ def create_data(self, mmm: MMM) -> None:
398
409
dims = self .prefix ,
399
410
)
400
411
401
- def create_effect (self , mmm : MMM ) -> pt .TensorVariable :
412
+ def create_effect (self , mmm : Model ) -> pt .TensorVariable :
402
413
"""Create the event effect in the model.
403
414
404
415
Parameters
@@ -430,15 +441,15 @@ def create_basis_matrix(start_ref, end_ref):
430
441
total_effect = pm .Deterministic (
431
442
f"{ self .prefix } _total_effect" ,
432
443
event_effect .sum (axis = 1 ),
433
- dims = "date" ,
444
+ dims = self . date_dim_name ,
434
445
)
435
446
436
- dim_handler = create_dim_handler (("date" , * mmm .dims ))
437
- return dim_handler (total_effect , "date" )
447
+ dim_handler = create_dim_handler ((self . date_dim_name , * mmm .dims ))
448
+ return dim_handler (total_effect , self . date_dim_name )
438
449
439
- def set_data (self , mmm : MMM , model : pm .Model , X : xr .Dataset ) -> None :
450
+ def set_data (self , mmm : Model , model : pm .Model , X : xr .Dataset ) -> None :
440
451
"""Set the data for new predictions."""
441
- new_dates = pd .to_datetime (model .coords ["date" ])
452
+ new_dates = pd .to_datetime (model .coords [self . date_dim_name ])
442
453
443
454
new_data = {
444
455
"days" : days_from_reference (new_dates , self .reference_date ),
0 commit comments