Skip to content

Commit c0fcde3

Browse files
lift test requires 'date' in df_lift_test Data Frame with time_varying_media=True (#1818)
* adds logic to include a placeholder date column when using time_varying_media * adds test to automatically add a date column into df_lift_test --------- Co-authored-by: Juan Orduz <[email protected]>
1 parent f2d4609 commit c0fcde3

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

pymc_marketing/mmm/mmm.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2094,6 +2094,16 @@ def add_lift_test_measurements(
20942094
"The 'channel' column is required to map the lift measurements to the model."
20952095
)
20962096

2097+
if self.time_varying_media and "date" not in df_lift_test.columns:
2098+
# `time_varying_media=True` parameter requires the date in the df_lift_test DataFrame.
2099+
# The `add_lift_test_measurements` method itself doesn't need a date
2100+
# We need to make sure the `date` coord is present in model_coords
2101+
# By adding this we make sure the model_coords match
2102+
df_lift_test["date"] = pd.to_datetime(self.model_coords["date"][0])
2103+
2104+
# Store df_lift_test for testing purposes
2105+
self._last_lift_test_df = df_lift_test
2106+
20972107
df_lift_test_scaled = scale_lift_measurements(
20982108
df_lift_test=df_lift_test,
20992109
channel_col="channel",

tests/mmm/test_lift_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sklearn.pipeline import Pipeline
2222
from sklearn.preprocessing import MaxAbsScaler
2323

24+
from pymc_marketing.mmm import MMM, GeometricAdstock
2425
from pymc_marketing.mmm.components.saturation import (
2526
HillSaturation,
2627
LogisticSaturation,
@@ -485,3 +486,51 @@ def test_scale_lift_measurements(df_lift_test_with_numerics) -> None:
485486
expected,
486487
check_like=True,
487488
)
489+
490+
491+
@pytest.fixture
def dummy_mmm_model():
    """Return a built ``MMM`` with ``time_varying_media=True`` on synthetic data.

    The synthetic frame has 52 consecutive daily rows starting 2024-01-01,
    three channel columns (``organic``, ``paid``, ``social``) and a target
    ``y``. Values are drawn from a seeded generator so the fixture produces
    the same data on every run instead of depending on global random state.
    """
    n_obs = 52
    rng = np.random.default_rng(seed=0)

    # Create sample data for dummy model
    df = pd.DataFrame(
        {
            "date": pd.date_range("2024-01-01", periods=n_obs),
            "organic": rng.random(n_obs) * 520,
            "paid": rng.random(n_obs) * 200,
            "social": rng.random(n_obs) * 50,
            "y": rng.random(n_obs) * 500,  # target variable
        }
    )
    X = df[["date", "organic", "paid", "social"]]
    y = df["y"]

    # Initialize model
    model = MMM(
        date_column="date",
        adstock=GeometricAdstock(l_max=6),
        saturation=LogisticSaturation(),
        channel_columns=["organic", "paid", "social"],
        time_varying_media=True,  # trigger the date-patching branch under test
    )

    # Build the model so that its coordinates are available to the test
    model.build_model(X, y)
    return model
516+
517+
518+
def test_adds_date_column_if_missing(dummy_mmm_model):
    """Check that a lift-test frame without ``date`` gets one filled in.

    The model fixture is built with ``time_varying_media=True``; the frame
    passed to ``add_lift_test_measurements`` deliberately has no ``date``
    column, and the frame the model stores afterwards must have a fully
    populated one.
    """
    lift_columns = {
        "x": [1, 2, 3],
        "delta_x": [0.1, 0.2, 0.3],
        "sigma": [0.1, 0.2, 0.3],
        "delta_y": [0.1, 0.2, 0.3],
        "channel": ["organic", "paid", "social"],
    }
    lift_df = pd.DataFrame(lift_columns)

    # Precondition: the input really lacks the column
    assert "date" not in lift_df.columns

    # The method is expected to patch the date in internally
    dummy_mmm_model.add_lift_test_measurements(lift_df)

    # The frame kept by the model must carry a non-null date in every row
    stored_df = dummy_mmm_model._last_lift_test_df
    assert not stored_df["date"].isna().any()

0 commit comments

Comments
 (0)