|
21 | 21 | from sklearn.pipeline import Pipeline
|
22 | 22 | from sklearn.preprocessing import MaxAbsScaler
|
23 | 23 |
|
| 24 | +from pymc_marketing.mmm import MMM, GeometricAdstock |
24 | 25 | from pymc_marketing.mmm.components.saturation import (
|
25 | 26 | HillSaturation,
|
26 | 27 | LogisticSaturation,
|
@@ -485,3 +486,51 @@ def test_scale_lift_measurements(df_lift_test_with_numerics) -> None:
|
485 | 486 | expected,
|
486 | 487 | check_like=True,
|
487 | 488 | )
|
| 489 | + |
| 490 | + |
| 491 | +@pytest.fixture |
| 492 | +def dummy_mmm_model(): |
| 493 | + # Create sample data for dummy model |
| 494 | + df = pd.DataFrame( |
| 495 | + { |
| 496 | + "date": pd.to_datetime(pd.date_range("2024-01-01", periods=52)), |
| 497 | + "organic": np.random.rand(52) * 520, |
| 498 | + "paid": np.random.rand(52) * 200, |
| 499 | + "social": np.random.rand(52) * 50, |
| 500 | + "y": np.random.rand(52) * 500, # target variable |
| 501 | + } |
| 502 | + ) |
| 503 | + X = df[["date", "organic", "paid", "social"]] |
| 504 | + y = df["y"] |
| 505 | + # Initialize model |
| 506 | + model = MMM( |
| 507 | + date_column="date", |
| 508 | + adstock=GeometricAdstock(l_max=6), |
| 509 | + saturation=LogisticSaturation(), |
| 510 | + channel_columns=["organic", "paid", "social"], |
| 511 | + time_varying_media=True, # trigger the condition |
| 512 | + ) |
| 513 | + # Build the model |
| 514 | + model.build_model(X, y) |
| 515 | + return model |
| 516 | + |
| 517 | + |
| 518 | +def test_adds_date_column_if_missing(dummy_mmm_model): |
| 519 | + df_lift_test = pd.DataFrame( |
| 520 | + { |
| 521 | + "x": [1, 2, 3], |
| 522 | + "delta_x": [0.1, 0.2, 0.3], |
| 523 | + "sigma": [0.1, 0.2, 0.3], |
| 524 | + "delta_y": [0.1, 0.2, 0.3], |
| 525 | + "channel": ["organic", "paid", "social"], |
| 526 | + } |
| 527 | + ) |
| 528 | + |
| 529 | + # Make sure the column is missing initially |
| 530 | + assert "date" not in df_lift_test.columns |
| 531 | + |
| 532 | + # Run the method (it should handle date patching internally) |
| 533 | + dummy_mmm_model.add_lift_test_measurements(df_lift_test) |
| 534 | + |
| 535 | + # Check if the date was added inside the function |
| 536 | + assert dummy_mmm_model._last_lift_test_df["date"].notna().all() |
0 commit comments