Skip to content

Commit 05abba7

Browse files
authored
Date Validation and MMM Model Harmonization (Pydantic) (#824)
* validate base mmm init class
* validate dateformat
* add comment about date
* remove ()
1 parent d766722 commit 05abba7

File tree

3 files changed

+49
-6
lines changed

3 files changed

+49
-6
lines changed

pymc_marketing/mmm/base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,23 @@
4646
from pymc_marketing.model_builder import ModelBuilder
4747

4848
__all__ = ["MMMModelBuilder", "BaseValidateMMM"]
49+
from pydantic import Field, validate_call
4950

5051

5152
class MMMModelBuilder(ModelBuilder):
5253
model: pm.Model
5354
_model_type = "BaseMMM"
5455
version = "0.0.2"
5556

57+
@validate_call
5658
def __init__(
5759
self,
58-
date_column: str,
59-
channel_columns: list[str] | tuple[str],
60-
model_config: dict | None = None,
61-
sampler_config: dict | None = None,
60+
date_column: str = Field(..., description="Column name of the date variable."),
61+
channel_columns: list[str] = Field(
62+
min_length=1, description="Column names of the media channel variables."
63+
),
64+
model_config: dict | None = Field(None, description="Model configuration."),
65+
sampler_config: dict | None = Field(None, description="Sampler configuration."),
6266
**kwargs,
6367
) -> None:
6468
self.date_column: str = date_column

pymc_marketing/mmm/delayed_saturated_mmm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def __init__(
125125
Parameter
126126
---------
127127
date_column : str
128-
Column name of the date variable.
128+
Column name of the date variable. Must be parsable using ~pandas.to_datetime.
129129
channel_columns : List[str]
130130
Column names of the media channel variables.
131131
adstock_max_lag : int, optional
@@ -236,7 +236,13 @@ def _generate_and_preprocess_model_data( # type: ignore
236236
_time_resolution: int
237237
The time resolution of the date index. Used by TVP.
238238
"""
239-
date_data = X[self.date_column]
239+
try:
240+
date_data = pd.to_datetime(X[self.date_column])
241+
except Exception as e:
242+
raise ValueError(
243+
f"Could not convert {self.date_column} to datetime. Please check the date format."
244+
) from e
245+
240246
channel_data = X[self.channel_columns]
241247

242248
coords: dict[str, Any] = {

tests/mmm/test_delayed_saturated_mmm.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,23 @@ def toy_X(generate_data) -> pd.DataFrame:
8282
return generate_data(date_data)
8383

8484

85+
@pytest.fixture(scope="module")
86+
def toy_X_with_bad_dates() -> pd.DataFrame:
87+
bad_date_data = ["a", "b", "c", "d", "e"]
88+
n: int = len(bad_date_data)
89+
return pd.DataFrame(
90+
data={
91+
"date": bad_date_data,
92+
"channel_1": rng.integers(low=0, high=400, size=n),
93+
"channel_2": rng.integers(low=0, high=50, size=n),
94+
"control_1": rng.gamma(shape=1000, scale=500, size=n),
95+
"control_2": rng.gamma(shape=100, scale=5, size=n),
96+
"other_column_1": rng.integers(low=0, high=100, size=n),
97+
"other_column_2": rng.normal(loc=0, scale=1, size=n),
98+
}
99+
)
100+
101+
85102
@pytest.fixture(scope="class")
86103
def model_config_requiring_serialization() -> dict:
87104
model_config = {
@@ -206,6 +223,22 @@ def deep_equal(dict1, dict2):
206223
assert model.sampler_config == model2.sampler_config
207224
os.remove("test_save_load")
208225

226+
def test_bad_date_column(self, toy_X_with_bad_dates) -> None:
227+
with pytest.raises(
228+
ValueError,
229+
match="Could not convert bad_date_column to datetime. Please check the date format.",
230+
):
231+
my_mmm = MMM(
232+
date_column="bad_date_column",
233+
channel_columns=["channel_1", "channel_2"],
234+
adstock_max_lag=4,
235+
control_columns=["control_1", "control_2"],
236+
adstock="geometric",
237+
saturation="logistic",
238+
)
239+
y = np.ones(toy_X_with_bad_dates.shape[0])
240+
my_mmm.build_model(X=toy_X_with_bad_dates, y=y)
241+
209242
@pytest.mark.parametrize(
210243
argnames="adstock_max_lag",
211244
argvalues=[1, 4],

0 commit comments

Comments
 (0)