diff --git a/pymc_marketing/mmm/additive_effect.py b/pymc_marketing/mmm/additive_effect.py index 79ca27156..a9b10d45f 100644 --- a/pymc_marketing/mmm/additive_effect.py +++ b/pymc_marketing/mmm/additive_effect.py @@ -104,12 +104,13 @@ def set_data(self, mmm, model, X): - In `set_data`, update the data variables when dates/dims change. """ -from typing import Any, Protocol +from abc import ABC, abstractmethod +from typing import Annotated, Any, Protocol import pandas as pd import pymc as pm import xarray as xr -from pydantic import BaseModel, InstanceOf +from pydantic import BaseModel, Field, InstanceOf, PlainValidator, WithJsonSchema from pymc_extras.prior import create_dim_handler from pytensor import tensor as pt @@ -131,35 +132,31 @@ def model(self) -> pm.Model: """The PyMC model.""" -class MuEffect(Protocol): - """Protocol for arbitrary additive mu effect.""" +class MuEffect(ABC, BaseModel): + """Abstract base class for arbitrary additive mu effects. + All mu_effects must inherit from this Pydantic BaseModel to ensure proper + serialization and deserialization when saving/loading MMM models. + """ + + @abstractmethod def create_data(self, mmm: Model) -> None: """Create the required data in the model.""" + @abstractmethod def create_effect(self, mmm: Model) -> pt.TensorVariable: """Create the additive effect in the model.""" + @abstractmethod def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None: """Set the data for new predictions.""" -class FourierEffect: +class FourierEffect(MuEffect): """Fourier seasonality additive effect for MMM.""" - def __init__(self, fourier: FourierBase, date_dim_name: str = "date"): - """Initialize the Fourier effect. - - Parameters - ---------- - fourier : FourierBase - The FourierBase instance to use for the effect. - date_dim_name : str, optional - The name of the date dimension in the model, by default "date". - - """ - self.fourier = fourier - self.date_dim_name: str = date_dim_name + fourier: InstanceOf[FourierBase] + date_dim_name: str = Field("date") def create_data(self, mmm: Model) -> None: """Create the required data in the model. @@ -256,7 +253,14 @@ def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None: pm.set_data(new_data=new_data, model=model) -class LinearTrendEffect: +_Timestamp = Annotated[ + pd.Timestamp, + PlainValidator(lambda x: pd.Timestamp(x)), + WithJsonSchema({"type": "date-time"}), +] + + +class LinearTrendEffect(MuEffect): """Wrapper for LinearTrend to use with MMM's MuEffect protocol. This class adapts the LinearTrend component to be used as an additive effect @@ -268,6 +272,8 @@ class LinearTrendEffect: The LinearTrend instance to wrap. prefix : str The prefix to use for variables in the model. + date_dim_name : str + The name of the date dimension in the model. Examples -------- @@ -357,11 +363,10 @@ class MockMMM: """ - def __init__(self, trend: LinearTrend, prefix: str, date_dim_name: str = "date"): - self.trend = trend - self.prefix = prefix - self.linear_trend_first_date: pd.Timestamp - self.date_dim_name: str = date_dim_name + trend: InstanceOf[LinearTrend] + prefix: str + date_dim_name: str = Field("date") + linear_trend_first_date: _Timestamp | None = Field(None, init=False) def create_data(self, mmm: Model) -> None: """Create the required data in the model. @@ -439,7 +444,7 @@ def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None: pm.set_data({f"{self.prefix}_t": t}, model=model) -class EventAdditiveEffect(BaseModel): +class EventAdditiveEffect(MuEffect): """Event effect class for the MMM. Parameters diff --git a/tests/mmm/test_additive_effect.py b/tests/mmm/test_additive_effect.py index 901920582..b7efd5651 100644 --- a/tests/mmm/test_additive_effect.py +++ b/tests/mmm/test_additive_effect.py @@ -99,7 +99,7 @@ def test_fourier_effect( dims, coords, ) -> None: - effect = FourierEffect(fourier) + effect = FourierEffect(fourier=fourier) mmm = create_mock_mmm( dims=dims, @@ -168,7 +168,7 @@ def test_fourier_effect_multidimensional( prefix = "weekly" prior = Prior("Laplace", mu=0, b=0.1, dims=prior_dims) fourier = WeeklyFourier(n_order=10, prefix=prefix, prior=prior) - fourier_effect = FourierEffect(fourier) + fourier_effect = FourierEffect(fourier=fourier) with mmm.model: fourier_effect.create_data(mmm) @@ -252,7 +252,7 @@ def test_linear_trend_effect( ) -> None: prefix = "linear_trend" effect = LinearTrendEffect( - LinearTrend(priors=priors, dims=linear_trend_dims), + trend=LinearTrend(priors=priors, dims=linear_trend_dims), prefix=prefix, )