Skip to content

Commit 0b27a25

Browse files
authored
Change the event function to class (#1675)
* turn function into class * use the class in the tests * add docstrings * remove the original function version * add to the docstrig * rename the variables * change variable name in other documentation location
1 parent d232333 commit 0b27a25

File tree

5 files changed

+107
-94
lines changed

5 files changed

+107
-94
lines changed

pymc_marketing/mmm/additive_effect.py

Lines changed: 87 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
# limitations under the License.
1414
"""Additive effects for the multidimensional Marketing Mix Model."""
1515

16-
from typing import Protocol
16+
from typing import Any, Protocol
1717

1818
import pandas as pd
1919
import pymc as pm
2020
import xarray as xr
21+
from pydantic import BaseModel, InstanceOf
2122
from pytensor import tensor as pt
2223

2324
from pymc_marketing.mmm.events import EventEffect, days_from_reference
@@ -318,14 +319,8 @@ def set_data(self, mmm: MMM, model: pm.Model, X: xr.Dataset) -> None:
318319
pm.set_data({f"{self.prefix}_t": t}, model=model)
319320

320321

321-
def create_event_mu_effect(
322-
df_events: pd.DataFrame,
323-
prefix: str,
324-
effect: EventEffect,
325-
) -> MuEffect:
326-
"""Create an event effect for the MMM.
327-
328-
This class has the ability to create data and mean effects for the MMM model.
322+
class EventAdditiveEffect(BaseModel):
323+
"""Event effect class for the MMM.
329324
330325
Parameters
331326
----------
@@ -338,105 +333,112 @@ def create_event_mu_effect(
338333
The prefix to use for the event effect and associated variables.
339334
effect : EventEffect
340335
The event effect to apply.
341-
342-
Returns
343-
-------
344-
MuEffect
345-
The event effect which is used in the MMM.
336+
reference_date : str
337+
The arbitrary reference date to calculate distance from events in days. Default
338+
is "2025-01-01".
346339
347340
"""
348-
if missing_columns := set(["start_date", "end_date", "name"]).difference(
349-
df_events.columns,
350-
):
351-
raise ValueError(f"Columns {missing_columns} are missing in df_events.")
352341

353-
effect.basis.prefix = prefix
342+
df_events: InstanceOf[pd.DataFrame]
343+
prefix: str
344+
effect: EventEffect
345+
reference_date: str = "2025-01-01"
354346

355-
reference_date = "2025-01-01"
356-
start_dates = pd.to_datetime(df_events["start_date"])
357-
end_dates = pd.to_datetime(df_events["end_date"])
347+
def model_post_init(self, context: Any, /) -> None:
348+
"""Post initialization of the model."""
349+
if missing_columns := set(["start_date", "end_date", "name"]).difference(
350+
self.df_events.columns
351+
):
352+
raise ValueError(f"Columns {missing_columns} are missing in df_events.")
358353

359-
class Effect:
360-
"""Event effect class for the MMM."""
354+
self.effect.basis.prefix = self.prefix
361355

362-
def create_data(self, mmm: MMM) -> None:
363-
"""Create the required data in the model.
356+
@property
357+
def start_dates(self) -> pd.Series:
358+
"""The start dates of the events."""
359+
return pd.to_datetime(self.df_events["start_date"])
364360

365-
Parameters
366-
----------
367-
mmm : MMM
368-
The MMM model instance.
361+
@property
362+
def end_dates(self) -> pd.Series:
363+
"""The end dates of the events."""
364+
return pd.to_datetime(self.df_events["end_date"])
369365

370-
"""
371-
model: pm.Model = mmm.model
366+
def create_data(self, mmm: MMM) -> None:
367+
"""Create the required data in the model.
372368
373-
model_dates = pd.to_datetime(model.coords["date"])
369+
Parameters
370+
----------
371+
mmm : MMM
372+
The MMM model instance.
374373
375-
model.add_coord(prefix, df_events["name"].to_numpy())
374+
"""
375+
model: pm.Model = mmm.model
376376

377-
if "days" not in model:
378-
pm.Data(
379-
"days",
380-
days_from_reference(model_dates, reference_date),
381-
dims="date",
382-
)
377+
model_dates = pd.to_datetime(model.coords["date"])
383378

379+
model.add_coord(self.prefix, self.df_events["name"].to_numpy())
380+
381+
if "days" not in model:
384382
pm.Data(
385-
f"{prefix}_start_diff",
386-
days_from_reference(start_dates, reference_date),
387-
dims=prefix,
388-
)
389-
pm.Data(
390-
f"{prefix}_end_diff",
391-
days_from_reference(end_dates, reference_date),
392-
dims=prefix,
383+
"days",
384+
days_from_reference(model_dates, self.reference_date),
385+
dims="date",
393386
)
394387

395-
def create_effect(self, mmm: MMM) -> pt.TensorVariable:
396-
"""Create the event effect in the model.
397-
398-
Parameters
399-
----------
400-
mmm : MMM
401-
The MMM model instance.
388+
pm.Data(
389+
f"{self.prefix}_start_diff",
390+
days_from_reference(self.start_dates, self.reference_date),
391+
dims=self.prefix,
392+
)
393+
pm.Data(
394+
f"{self.prefix}_end_diff",
395+
days_from_reference(self.end_dates, self.reference_date),
396+
dims=self.prefix,
397+
)
402398

403-
Returns
404-
-------
405-
pt.TensorVariable
406-
The average event effect in the model.
399+
def create_effect(self, mmm: MMM) -> pt.TensorVariable:
400+
"""Create the event effect in the model.
407401
408-
"""
409-
model: pm.Model = mmm.model
402+
Parameters
403+
----------
404+
mmm : MMM
405+
The MMM model instance.
410406
411-
s_ref = model["days"][:, None] - model[f"{prefix}_start_diff"]
412-
e_ref = model["days"][:, None] - model[f"{prefix}_end_diff"]
407+
Returns
408+
-------
409+
pt.TensorVariable
410+
The average event effect in the model.
413411
414-
def create_basis_matrix(s_ref, e_ref):
415-
return pt.where(
416-
(s_ref >= 0) & (e_ref <= 0),
417-
0,
418-
pt.where(pt.abs(s_ref) < pt.abs(e_ref), s_ref, e_ref),
419-
)
412+
"""
413+
model: pm.Model = mmm.model
420414

421-
X = create_basis_matrix(s_ref, e_ref)
422-
event_effect = effect.apply(X, name=prefix)
415+
start_ref = model["days"][:, None] - model[f"{self.prefix}_start_diff"]
416+
end_ref = model["days"][:, None] - model[f"{self.prefix}_end_diff"]
423417

424-
total_effect = pm.Deterministic(
425-
f"{prefix}_total_effect",
426-
event_effect.sum(axis=1),
427-
dims="date",
418+
def create_basis_matrix(start_ref, end_ref):
419+
return pt.where(
420+
(start_ref >= 0) & (end_ref <= 0),
421+
0,
422+
pt.where(pt.abs(start_ref) < pt.abs(end_ref), start_ref, end_ref),
428423
)
429424

430-
dim_handler = create_dim_handler(("date", *mmm.dims))
431-
return dim_handler(total_effect, "date")
425+
X = create_basis_matrix(start_ref, end_ref)
426+
event_effect = self.effect.apply(X, name=self.prefix)
432427

433-
def set_data(self, mmm: MMM, model: pm.Model, X: xr.Dataset) -> None:
434-
"""Set the data for new predictions."""
435-
new_dates = pd.to_datetime(model.coords["date"])
428+
total_effect = pm.Deterministic(
429+
f"{self.prefix}_total_effect",
430+
event_effect.sum(axis=1),
431+
dims="date",
432+
)
436433

437-
new_data = {
438-
"days": days_from_reference(new_dates, reference_date),
439-
}
440-
pm.set_data(new_data=new_data, model=model)
434+
dim_handler = create_dim_handler(("date", *mmm.dims))
435+
return dim_handler(total_effect, "date")
436+
437+
def set_data(self, mmm: MMM, model: pm.Model, X: xr.Dataset) -> None:
438+
"""Set the data for new predictions."""
439+
new_dates = pd.to_datetime(model.coords["date"])
441440

442-
return Effect()
441+
new_data = {
442+
"days": days_from_reference(new_dates, self.reference_date),
443+
}
444+
pm.set_data(new_data=new_data, model=model)

pymc_marketing/mmm/events.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,13 @@ def create_basis_matrix(df_events: pd.DataFrame, model_dates: np.ndarray):
5353
start_dates = df_events["start_date"]
5454
end_dates = df_events["end_date"]
5555
56-
s_ref = difference_in_days(model_dates, start_dates)
57-
e_ref = difference_in_days(model_dates, end_dates)
56+
start_ref = difference_in_days(model_dates, start_dates)
57+
end_ref = difference_in_days(model_dates, end_dates)
5858
5959
return np.where(
60-
(s_ref >= 0) & (e_ref <= 0),
60+
(start_ref >= 0) & (end_ref <= 0),
6161
0,
62-
np.where(np.abs(s_ref) < np.abs(e_ref), s_ref, e_ref),
62+
np.where(np.abs(start_ref) < np.abs(end_ref), start_ref, end_ref),
6363
)
6464
6565

pymc_marketing/mmm/multidimensional.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from scipy.optimize import OptimizeResult
3434

3535
from pymc_marketing.mmm import SoftPlusHSGP
36-
from pymc_marketing.mmm.additive_effect import MuEffect, create_event_mu_effect
36+
from pymc_marketing.mmm.additive_effect import EventAdditiveEffect, MuEffect
3737
from pymc_marketing.mmm.budget_optimizer import OptimizerCompatibleModelWrapper
3838
from pymc_marketing.mmm.components.adstock import (
3939
AdstockTransformation,
@@ -247,7 +247,11 @@ def add_events(
247247
f"Event effect dims {effect.dims} must contain {prefix} and {self.dims}"
248248
)
249249

250-
event_effect = create_event_mu_effect(df_events, prefix, effect)
250+
event_effect = EventAdditiveEffect(
251+
df_events=df_events,
252+
prefix=prefix,
253+
effect=effect,
254+
)
251255
self.mu_effects.append(event_effect)
252256

253257
@property

tests/mmm/test_additive_effect.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
import pymc as pm
1717
import pytest
1818

19-
from pymc_marketing.mmm.additive_effect import FourierEffect, LinearTrendEffect
19+
from pymc_marketing.mmm.additive_effect import (
20+
FourierEffect,
21+
LinearTrendEffect,
22+
)
2023
from pymc_marketing.mmm.fourier import MonthlyFourier, WeeklyFourier, YearlyFourier
2124
from pymc_marketing.mmm.linear_trend import LinearTrend
2225
from pymc_marketing.prior import Prior

tests/mmm/test_multidimensional.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
from scipy.optimize import OptimizeResult
2525

2626
from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation
27+
from pymc_marketing.mmm.additive_effect import EventAdditiveEffect
2728
from pymc_marketing.mmm.events import EventEffect, GaussianBasis
2829
from pymc_marketing.mmm.multidimensional import (
2930
MMM,
3031
MultiDimensionalBudgetOptimizerWrapper,
31-
create_event_mu_effect,
3232
)
3333
from pymc_marketing.mmm.scaling import Scaling, VariableScaling
3434
from pymc_marketing.prior import Prior
@@ -469,7 +469,11 @@ def test_create_effect_mu_effect(
469469
df_events,
470470
event_effect,
471471
) -> None:
472-
effect = create_event_mu_effect(df_events, prefix="holiday", effect=event_effect)
472+
effect = EventAdditiveEffect(
473+
df_events=df_events,
474+
prefix="holiday",
475+
effect=event_effect,
476+
)
473477

474478
with mock_mmm.model:
475479
effect.create_data(mock_mmm)

0 commit comments

Comments
 (0)