Skip to content

Commit b6a88b6

Browse files
committed
more tests
1 parent ddc455e commit b6a88b6

File tree

1 file changed

+296
-0
lines changed

1 file changed

+296
-0
lines changed

tests/mmm/test_mmm.py

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,6 +1054,138 @@ def test_mmm_missing_dag_or_nodes(self):
10541054
assert mmm.treatment_nodes is None, "Treatment nodes should default to None."
10551055
assert mmm.outcome_node is None, "Outcome node should default to None."
10561056

1057+
def test_scaling_dict_with_missing_channel(
    self, toy_X: pd.DataFrame, toy_y: pd.Series
):
    """Build a model whose scaling dict sets only the ``target`` entry.

    Passes when the build leaves a ``model`` attribute on the instance and the
    channel scaling method reads ``"max"`` afterwards.
    """
    from pymc_marketing.mmm.scaling import VariableScaling

    init_kwargs = {
        "date_column": "date",
        "channel_columns": ["channel_1", "channel_2"],
        "adstock": GeometricAdstock(l_max=4),
        "saturation": LogisticSaturation(),
        # Only "target" is given; no "channel" entry is supplied on purpose.
        "scaling": {"target": VariableScaling(method="mean", dims=())},
    }
    model_under_test = MMM(**init_kwargs)
    model_under_test.build_model(X=toy_X, y=toy_y)

    assert hasattr(model_under_test, "model")
    assert model_under_test.scaling.channel.method == "max"
1074+
1075+
def test_scaling_dict_with_missing_target(
    self, toy_X: pd.DataFrame, toy_y: pd.Series
):
    """Build a model whose scaling dict sets only the ``channel`` entry.

    Passes when the build leaves a ``model`` attribute on the instance and the
    target scaling method reads ``"max"`` afterwards.
    """
    from pymc_marketing.mmm.scaling import VariableScaling

    init_kwargs = {
        "date_column": "date",
        "channel_columns": ["channel_1", "channel_2"],
        "adstock": GeometricAdstock(l_max=4),
        "saturation": LogisticSaturation(),
        # Only "channel" is given; no "target" entry is supplied on purpose.
        "scaling": {"channel": VariableScaling(method="mean", dims=())},
    }
    model_under_test = MMM(**init_kwargs)
    model_under_test.build_model(X=toy_X, y=toy_y)

    assert hasattr(model_under_test, "model")
    assert model_under_test.scaling.target.method == "max"
1092+
1093+
def test_mean_scaling_method(self, toy_X: pd.DataFrame, toy_y: pd.Series):
    """Build a model with ``"mean"`` scaling for both target and channel.

    Passes when the build leaves a ``model`` attribute on the instance and
    ``target_scale`` is strictly positive.
    """
    from pymc_marketing.mmm.scaling import Scaling, VariableScaling

    init_kwargs = {
        "date_column": "date",
        "channel_columns": ["channel_1", "channel_2"],
        "adstock": GeometricAdstock(l_max=4),
        "saturation": LogisticSaturation(),
        "scaling": Scaling(
            target=VariableScaling(method="mean", dims=()),
            channel=VariableScaling(method="mean", dims=()),
        ),
    }
    model_under_test = MMM(**init_kwargs)
    model_under_test.build_model(X=toy_X, y=toy_y)

    assert hasattr(model_under_test, "model")
    assert model_under_test.target_scale > 0
1111+
1112+
def test_validation_disabled(self, toy_X: pd.DataFrame, toy_y: pd.Series):
    """Build a model constructed with ``validate_data=False``.

    Passes when the build leaves a ``model`` attribute on the instance.
    """
    init_kwargs = {
        "date_column": "date",
        "channel_columns": ["channel_1", "channel_2"],
        "adstock": GeometricAdstock(l_max=4),
        "saturation": LogisticSaturation(),
        "validate_data": False,
    }
    model_under_test = MMM(**init_kwargs)
    model_under_test.build_model(X=toy_X, y=toy_y)

    assert hasattr(model_under_test, "model")
1124+
1125+
def test_adstock_first_false(self, toy_X: pd.DataFrame, toy_y: pd.Series):
    """Build a model constructed with ``adstock_first=False``.

    Passes when the build leaves a ``model`` attribute on the instance and the
    ``adstock_first`` attribute is still exactly ``False``.
    """
    init_kwargs = {
        "date_column": "date",
        "channel_columns": ["channel_1", "channel_2"],
        "adstock": GeometricAdstock(l_max=4),
        "saturation": LogisticSaturation(),
        "adstock_first": False,
    }
    model_under_test = MMM(**init_kwargs)
    model_under_test.build_model(X=toy_X, y=toy_y)

    assert hasattr(model_under_test, "model")
    assert model_under_test.adstock_first is False
1138+
1139+
def test_serializable_model_config_with_ndarray(self, mmm_fitted: MMM):
    """Check that ``_serializable_model_config`` turns ndarrays into lists.

    Two temporary entries are injected into ``mmm_fitted.model_config``: a
    top-level ndarray and a dict holding an ndarray. Both must come back as
    ``list`` in the serializable view. The injected keys are removed again in
    a ``finally`` block so the fixture is left without the extra entries.
    """
    injected = {
        "test_array": np.array([1, 2, 3]),
        "nested": {"array": np.array([4, 5])},
    }

    # NOTE(review): the scope of the ``mmm_fitted`` fixture is not visible
    # here. If it is shared between tests, leaving these keys behind would
    # leak into every later test that reads ``model_config`` -- confirm.
    for key, value in injected.items():
        mmm_fitted.model_config[key] = value

    try:
        serializable = mmm_fitted._serializable_model_config

        assert isinstance(serializable["test_array"], list)
        assert isinstance(serializable["nested"]["array"], list)
    finally:
        # Always undo the injection, even when an assertion above fails.
        for key in injected:
            mmm_fitted.model_config.pop(key, None)
1148+
1149+
def test_model_config_formatting_with_nested_dicts(self):
    """Check ``_model_config_formatting`` on a config with nested dicts.

    The assertions require ``"dims"`` lists (one and two levels deep) to come
    back as tuples, and the other list values checked here to come back as
    ``np.ndarray``.
    """
    raw_config = {
        "param1": [1, 2, 3],
        "param2": {"nested": [4, 5, 6], "dims": ["a", "b"]},
        "param3": {"deep": {"nested": [7, 8], "dims": ["c"]}},
    }

    result = MMM._model_config_formatting(raw_config)

    # "dims" entries, at either nesting depth, are expected to be tuples.
    assert isinstance(result["param2"]["dims"], tuple)
    assert isinstance(result["param3"]["deep"]["dims"], tuple)
    # Other list values are expected to be arrays.
    assert isinstance(result["param1"], np.ndarray)
    assert isinstance(result["param2"]["nested"], np.ndarray)
1163+
1164+
def test_deserialize_scaling_none(self):
    """Check that ``_deserialize_scaling(None)`` returns ``None``."""
    assert MMM._deserialize_scaling(None) is None
1168+
1169+
def test_attrs_to_init_kwargs_with_defaults(self):
    """Check the defaults ``attrs_to_init_kwargs`` fills in for absent keys.

    The attrs mapping below carries no ``adstock_first``,
    ``time_varying_intercept`` or ``scaling`` entry; the returned kwargs are
    expected to hold ``True``, ``False`` and ``None`` for them respectively.
    """
    # Every value is a JSON-encoded string.
    serialized_attrs = {
        "model_config": "{}",
        "date_column": '"date"',
        "control_columns": "null",
        "channel_columns": '["ch1"]',
        "adstock": '{"lookup_name": "geometric", "l_max": 4, "normalize": true, "mode": "After"}',
        "saturation": '{"lookup_name": "logistic"}',
        "yearly_seasonality": "null",
        "validate_data": "true",
        "sampler_config": "{}",
    }

    init_kwargs = MMM.attrs_to_init_kwargs(serialized_attrs)

    assert init_kwargs["adstock_first"] is True
    assert init_kwargs["time_varying_intercept"] is False
    assert init_kwargs["scaling"] is None
1188+
10571189

10581190
def new_date_ranges_to_test():
10591191
yield from [
@@ -2227,6 +2359,160 @@ def test_mixed_positive_negative_channels(self, toy_y: pd.Series):
22272359
mmm.build_model(X=X, y=toy_y)
22282360
assert hasattr(mmm, "model")
22292361

2362+
def test_date_conversion_error(self, toy_y: pd.Series):
    """Expect ``build_model`` to reject a date column of bare objects.

    The ``date`` column is filled with ``object()`` instances; the build must
    raise ``ValueError`` with a message matching ``"Could not convert"``.
    """
    n_rows = 50
    columns = {
        "date": [object() for _ in range(n_rows)],
        "channel_1": rng.integers(100, 1000, size=n_rows),
        "channel_2": rng.integers(100, 1000, size=n_rows),
    }
    bad_dates_X = pd.DataFrame(columns)

    model_under_test = MMM(
        date_column="date",
        channel_columns=["channel_1", "channel_2"],
        adstock=GeometricAdstock(l_max=4),
        saturation=LogisticSaturation(),
    )

    with pytest.raises(ValueError, match="Could not convert"):
        model_under_test.build_model(X=bad_dates_X, y=toy_y)
2381+
2382+
def test_transformer_fitting_failure(
    self, toy_X: pd.DataFrame, toy_y: pd.Series, monkeypatch
):
    """Expect a warning when the target scaler raises during the build.

    ``max_abs_scale_target_data`` is replaced on the instance with a callable
    that raises ``RuntimeError``; ``build_model`` must then emit a
    ``UserWarning`` matching ``"Failed to fit transformers"``.
    """

    def always_raise(*args, **kwargs):
        raise RuntimeError("Transformer failure")

    init_kwargs = {
        "date_column": "date",
        "channel_columns": ["channel_1", "channel_2"],
        "adstock": GeometricAdstock(l_max=4),
        "saturation": LogisticSaturation(),
    }
    model_under_test = MMM(**init_kwargs)

    monkeypatch.setattr(
        model_under_test, "max_abs_scale_target_data", always_raise
    )

    with pytest.warns(UserWarning, match="Failed to fit transformers"):
        model_under_test.build_model(X=toy_X, y=toy_y)
2401+
2402+
def test_causal_graph_with_warnings(self, toy_X: pd.DataFrame, toy_y: pd.Series):
    """Expect a warning when a DAG is given without treatment nodes.

    Constructing the model with ``dag`` and ``outcome_node`` but no
    ``treatment_nodes`` must emit a ``UserWarning`` matching
    ``"No treatment nodes provided"``; afterwards ``treatment_nodes`` is
    expected to equal the two channel names.
    """
    dag = """
digraph {
channel_1 -> y;
channel_2 -> y;
control_1 -> y;
}
"""

    init_kwargs = {
        "date_column": "date",
        "channel_columns": ["channel_1", "channel_2"],
        "control_columns": ["control_1"],
        "adstock": GeometricAdstock(l_max=4),
        "saturation": LogisticSaturation(),
        "dag": dag,
        "outcome_node": "y",
    }

    with pytest.warns(UserWarning, match="No treatment nodes provided"):
        model_under_test = MMM(**init_kwargs)

    assert model_under_test.treatment_nodes == ["channel_1", "channel_2"]
2424+
2425+
def test_causal_graph_excludes_seasonality(
    self, toy_X: pd.DataFrame, toy_y: pd.Series
):
    """Expect yearly seasonality to be dropped for a DAG that omits it.

    Constructing the model with ``yearly_seasonality=2`` and the DAG below
    must emit a ``UserWarning`` matching ``"Yearly seasonality excluded"``;
    afterwards ``yearly_seasonality`` is expected to be ``None``.
    """
    dag = """
digraph {
channel_1 -> y;
channel_2 -> y;
control_1 -> y;
}
"""

    init_kwargs = {
        "date_column": "date",
        "channel_columns": ["channel_1", "channel_2"],
        "control_columns": ["control_1"],
        "adstock": GeometricAdstock(l_max=4),
        "saturation": LogisticSaturation(),
        "yearly_seasonality": 2,
        "dag": dag,
        "treatment_nodes": ["channel_1", "channel_2"],
        "outcome_node": "y",
    }

    with pytest.warns(UserWarning, match="Yearly seasonality excluded"):
        model_under_test = MMM(**init_kwargs)

    assert model_under_test.yearly_seasonality is None
2451+
2452+
def test_x_data_not_dataframe_error(self, toy_y: pd.Series):
    """Expect ``_compute_scales`` to reject a non-DataFrame ``X``.

    ``preprocessed_data`` is assigned directly with a NumPy array under
    ``"X"``; ``_compute_scales`` must raise ``TypeError`` with a message
    matching ``"must be a DataFrame"``.
    """
    init_kwargs = {
        "date_column": "date",
        "channel_columns": ["channel_1", "channel_2"],
        "adstock": GeometricAdstock(l_max=4),
        "saturation": LogisticSaturation(),
    }
    model_under_test = MMM(**init_kwargs)

    not_a_frame = np.array([[1, 2], [3, 4]])
    model_under_test.preprocessed_data = {"X": not_a_frame, "y": toy_y}

    with pytest.raises(TypeError, match="must be a DataFrame"):
        model_under_test._compute_scales()
2465+
2466+
def test_invalid_synth_dataset_granularity(
    self, mmm_fitted: MMM, toy_X: pd.DataFrame
):
    """Expect ``_create_synth_dataset`` to reject an unknown granularity.

    Calling it with ``time_granularity="invalid_granularity"`` must raise
    ``ValueError`` with a message matching ``"Unsupported time granularity"``.
    """
    from xarray import DataArray

    budget = DataArray(
        [1000, 2000],
        dims=["channel"],
        coords={"channel": mmm_fitted.channel_columns},
    )

    with pytest.raises(ValueError, match="Unsupported time granularity"):
        call_kwargs = {
            "df": toy_X,
            "date_column": "date",
            "allocation_strategy": budget,
            "channels": mmm_fitted.channel_columns,
            "controls": None,
            "target_col": "y",
            "time_granularity": "invalid_granularity",
            "time_length": 10,
            "lag": 4,
        }
        mmm_fitted._create_synth_dataset(**call_kwargs)
2490+
2491+
def test_invalid_allocation_dims(self, mmm_fitted: MMM, toy_X: pd.DataFrame):
    """Expect ``_create_synth_dataset`` to reject a mis-dimensioned allocation.

    The allocation's only dimension is named ``"wrong_dim"``; the call must
    raise ``ValueError`` with a message matching
    ``"must have a single dimension named 'channel'"``.
    """
    from xarray import DataArray

    bad_budget = DataArray(
        [1000, 2000],
        dims=["wrong_dim"],
        coords={"wrong_dim": ["a", "b"]},
    )

    expected_message = "must have a single dimension named 'channel'"
    with pytest.raises(ValueError, match=expected_message):
        call_kwargs = {
            "df": toy_X,
            "date_column": "date",
            "allocation_strategy": bad_budget,
            "channels": mmm_fitted.channel_columns,
            "controls": None,
            "target_col": "y",
            "time_granularity": "weekly",
            "time_length": 10,
            "lag": 4,
        }
        mmm_fitted._create_synth_dataset(**call_kwargs)
2515+
22302516

22312517
class TestMMMHelperMethods:
22322518
"""Tests for internal helper methods added during refactoring."""
@@ -2630,3 +2916,13 @@ def test_add_original_scale_deterministics(
26302916
assert "control_contribution_original_scale" in deterministic_names
26312917
assert "yearly_seasonality_contribution_original_scale" in deterministic_names
26322918
assert "y_original_scale" in deterministic_names
2919+
2920+
def test_prepare_target_data_with_none_and_no_preprocessed(self, mmm_fitted: MMM):
    """Check ``_prepare_target_data`` when ``y`` is None and the stored target is unusable.

    ``mmm_fitted.preprocessed_data["y"]`` is temporarily replaced with a plain
    string, then ``_prepare_target_data(y=None, n_rows=10)`` is called. The
    result must contain a ``"target_data"`` or ``"target"`` key; when
    ``"target_data"`` is present it must have length 10. The previous ``"y"``
    entry is put back in a ``finally`` block so the fixture is not left
    holding the string.
    """
    missing = object()  # sentinel: distinguishes "no 'y' key" from any stored value
    previous_y = mmm_fitted.preprocessed_data.get("y", missing)

    # NOTE(review): the scope of the ``mmm_fitted`` fixture is not visible
    # here. If it is shared between tests, an unrestored string target would
    # break later tests that read ``preprocessed_data["y"]`` -- confirm.
    mmm_fitted.preprocessed_data["y"] = "not_a_series_or_array"  # type: ignore

    try:
        result = mmm_fitted._prepare_target_data(y=None, n_rows=10)

        # NOTE(review): the length check only runs when "target_data" is the
        # key present; if the real key is "target" nothing about its size is
        # verified. Tighten once the expected key is confirmed.
        assert "target_data" in result or "target" in result
        if "target_data" in result:
            assert len(result["target_data"]) == 10
    finally:
        # Always undo the override, even when an assertion above fails.
        if previous_y is missing:
            mmm_fitted.preprocessed_data.pop("y", None)
        else:
            mmm_fitted.preprocessed_data["y"] = previous_y

0 commit comments

Comments
 (0)