@@ -1054,6 +1054,138 @@ def test_mmm_missing_dag_or_nodes(self):
1054
1054
assert mmm .treatment_nodes is None , "Treatment nodes should default to None."
1055
1055
assert mmm .outcome_node is None , "Outcome node should default to None."
1056
1056
1057
def test_scaling_dict_with_missing_channel(
    self, toy_X: pd.DataFrame, toy_y: pd.Series
):
    """Build a model from a scaling dict that configures only the target.

    The dict has no ``"channel"`` entry; after the build, the channel
    scaling method is expected to read ``"max"``.
    """
    from pymc_marketing.mmm.scaling import VariableScaling

    model_kwargs = dict(
        date_column="date",
        channel_columns=["channel_1", "channel_2"],
        adstock=GeometricAdstock(l_max=4),
        saturation=LogisticSaturation(),
        # Only the target is given an explicit scaling rule.
        scaling={"target": VariableScaling(method="mean", dims=())},
    )
    mmm = MMM(**model_kwargs)
    mmm.build_model(X=toy_X, y=toy_y)

    assert hasattr(mmm, "model")
    channel_method = mmm.scaling.channel.method
    assert channel_method == "max"
1074
+
1075
def test_scaling_dict_with_missing_target(
    self, toy_X: pd.DataFrame, toy_y: pd.Series
):
    """Build a model from a scaling dict that configures only the channel.

    The dict has no ``"target"`` entry; after the build, the target
    scaling method is expected to read ``"max"``.
    """
    from pymc_marketing.mmm.scaling import VariableScaling

    model_kwargs = dict(
        date_column="date",
        channel_columns=["channel_1", "channel_2"],
        adstock=GeometricAdstock(l_max=4),
        saturation=LogisticSaturation(),
        # Only the channel is given an explicit scaling rule.
        scaling={"channel": VariableScaling(method="mean", dims=())},
    )
    mmm = MMM(**model_kwargs)
    mmm.build_model(X=toy_X, y=toy_y)

    assert hasattr(mmm, "model")
    target_method = mmm.scaling.target.method
    assert target_method == "max"
1092
+
1093
def test_mean_scaling_method(self, toy_X: pd.DataFrame, toy_y: pd.Series):
    """Build a model with ``"mean"`` scaling on both target and channel.

    Checks that the model is created and that ``target_scale`` is positive.
    """
    from pymc_marketing.mmm.scaling import Scaling, VariableScaling

    model_kwargs = dict(
        date_column="date",
        channel_columns=["channel_1", "channel_2"],
        adstock=GeometricAdstock(l_max=4),
        saturation=LogisticSaturation(),
        scaling=Scaling(
            target=VariableScaling(method="mean", dims=()),
            channel=VariableScaling(method="mean", dims=()),
        ),
    )
    mmm = MMM(**model_kwargs)
    mmm.build_model(X=toy_X, y=toy_y)

    assert hasattr(mmm, "model")
    assert mmm.target_scale > 0
1111
+
1112
def test_validation_disabled(self, toy_X: pd.DataFrame, toy_y: pd.Series):
    """Build a model constructed with ``validate_data=False``."""
    model_kwargs = dict(
        date_column="date",
        channel_columns=["channel_1", "channel_2"],
        adstock=GeometricAdstock(l_max=4),
        saturation=LogisticSaturation(),
        validate_data=False,
    )
    mmm = MMM(**model_kwargs)
    mmm.build_model(X=toy_X, y=toy_y)

    # The only expectation is that a model object exists after the build.
    assert hasattr(mmm, "model")
1124
+
1125
def test_adstock_first_false(self, toy_X: pd.DataFrame, toy_y: pd.Series):
    """Build a model constructed with ``adstock_first=False``.

    Checks that the model is created and the flag is still ``False``.
    """
    model_kwargs = dict(
        date_column="date",
        channel_columns=["channel_1", "channel_2"],
        adstock=GeometricAdstock(l_max=4),
        saturation=LogisticSaturation(),
        adstock_first=False,
    )
    mmm = MMM(**model_kwargs)
    mmm.build_model(X=toy_X, y=toy_y)

    assert hasattr(mmm, "model")
    assert mmm.adstock_first is False
1138
+
1139
def test_serializable_model_config_with_ndarray(self, mmm_fitted: MMM):
    """Test _serializable_model_config converts ndarrays to lists.

    Two ndarray-valued entries (one top-level, one nested in a dict) are
    injected into ``mmm_fitted.model_config`` and the serializable view is
    checked to expose them as lists.

    The injected entries are removed again in ``finally``.
    NOTE(review): ``mmm_fitted`` may be a fixture shared across tests
    (its scope is not visible here); restoring the config keeps this test
    from leaking state either way.
    """
    config = mmm_fitted.model_config
    injected = {
        "test_array": np.array([1, 2, 3]),
        "nested": {"array": np.array([4, 5])},
    }

    # Remember what was there before so the config can be restored exactly.
    missing = object()
    previous = {key: config.get(key, missing) for key in injected}

    config.update(injected)
    try:
        serializable = mmm_fitted._serializable_model_config

        assert isinstance(serializable["test_array"], list)
        assert isinstance(serializable["nested"]["array"], list)
    finally:
        for key, old_value in previous.items():
            if old_value is missing:
                config.pop(key, None)
            else:
                config[key] = old_value
1148
+
1149
def test_model_config_formatting_with_nested_dicts(self):
    """Check ``MMM._model_config_formatting`` on nested dictionaries.

    ``"dims"`` entries (at one and two levels of nesting) are expected to
    come back as tuples; the other list values checked here are expected to
    come back as ``np.ndarray``.
    """
    raw_config = {
        "param1": [1, 2, 3],
        "param2": {"nested": [4, 5, 6], "dims": ["a", "b"]},
        "param3": {"deep": {"nested": [7, 8], "dims": ["c"]}},
    }

    formatted = MMM._model_config_formatting(raw_config)

    shallow_dims = formatted["param2"]["dims"]
    assert isinstance(shallow_dims, tuple)

    deep_dims = formatted["param3"]["deep"]["dims"]
    assert isinstance(deep_dims, tuple)

    top_level_values = formatted["param1"]
    assert isinstance(top_level_values, np.ndarray)

    nested_values = formatted["param2"]["nested"]
    assert isinstance(nested_values, np.ndarray)
1163
+
1164
def test_deserialize_scaling_none(self):
    """``MMM._deserialize_scaling(None)`` is expected to return ``None``."""
    assert MMM._deserialize_scaling(None) is None
1168
+
1169
def test_attrs_to_init_kwargs_with_defaults(self):
    """Check ``MMM.attrs_to_init_kwargs`` when optional attributes are absent.

    The attrs mapping below has no ``adstock_first``,
    ``time_varying_intercept`` or ``scaling`` entries; the resulting kwargs
    are expected to carry ``True``, ``False`` and ``None`` for them.
    """
    serialized_attrs = {
        "model_config": "{}",
        "date_column": '"date"',
        "control_columns": "null",
        "channel_columns": '["ch1"]',
        "adstock": '{"lookup_name": "geometric", "l_max": 4, "normalize": true, "mode": "After"}',
        "saturation": '{"lookup_name": "logistic"}',
        "yearly_seasonality": "null",
        "validate_data": "true",
        "sampler_config": "{}",
    }

    kwargs = MMM.attrs_to_init_kwargs(serialized_attrs)

    # Identity checks: each default must be exactly this singleton.
    expected_defaults = {
        "adstock_first": True,
        "time_varying_intercept": False,
        "scaling": None,
    }
    for name, expected in expected_defaults.items():
        assert kwargs[name] is expected
1188
+
1057
1189
1058
1190
def new_date_ranges_to_test ():
1059
1191
yield from [
@@ -2227,6 +2359,160 @@ def test_mixed_positive_negative_channels(self, toy_y: pd.Series):
2227
2359
mmm .build_model (X = X , y = toy_y )
2228
2360
assert hasattr (mmm , "model" )
2229
2361
2362
def test_date_conversion_error(self, toy_y: pd.Series):
    """Building with non-date objects in the date column raises ValueError.

    The date column is filled with bare ``object()`` instances and the
    build is expected to fail with a message matching "Could not convert".
    """
    n_rows = 50
    columns = {
        "date": [object() for _ in range(n_rows)],
        "channel_1": rng.integers(100, 1000, size=n_rows),
        "channel_2": rng.integers(100, 1000, size=n_rows),
    }
    X = pd.DataFrame(columns)

    mmm = MMM(
        date_column="date",
        channel_columns=["channel_1", "channel_2"],
        adstock=GeometricAdstock(l_max=4),
        saturation=LogisticSaturation(),
    )

    expected_error = pytest.raises(ValueError, match="Could not convert")
    with expected_error:
        mmm.build_model(X=X, y=toy_y)
2381
+
2382
def test_transformer_fitting_failure(
    self, toy_X: pd.DataFrame, toy_y: pd.Series, monkeypatch
):
    """A failing target transformer surfaces as a UserWarning on build.

    ``max_abs_scale_target_data`` is replaced on the instance with a stub
    that raises ``RuntimeError``; the build is expected to emit a warning
    matching "Failed to fit transformers".
    """

    def always_raise(*args, **kwargs):
        raise RuntimeError("Transformer failure")

    model_kwargs = dict(
        date_column="date",
        channel_columns=["channel_1", "channel_2"],
        adstock=GeometricAdstock(l_max=4),
        saturation=LogisticSaturation(),
    )
    mmm = MMM(**model_kwargs)

    # Patch the instance attribute; monkeypatch undoes it after the test.
    monkeypatch.setattr(mmm, "max_abs_scale_target_data", always_raise)

    with pytest.warns(UserWarning, match="Failed to fit transformers"):
        mmm.build_model(X=toy_X, y=toy_y)
2401
+
2402
def test_causal_graph_with_warnings(self, toy_X: pd.DataFrame, toy_y: pd.Series):
    """Constructing with a DAG but no treatment nodes emits a warning.

    The constructor is expected to warn with "No treatment nodes provided"
    and the resulting ``treatment_nodes`` is expected to equal the two
    channel names.
    """
    # NOTE: toy_X / toy_y are requested as fixtures but not used in the body.
    causal_dag = """
    digraph {
        channel_1 -> y;
        channel_2 -> y;
        control_1 -> y;
    }
    """
    model_kwargs = dict(
        date_column="date",
        channel_columns=["channel_1", "channel_2"],
        control_columns=["control_1"],
        adstock=GeometricAdstock(l_max=4),
        saturation=LogisticSaturation(),
        dag=causal_dag,
        outcome_node="y",
    )

    # The warning is raised during construction, so build inside the context.
    with pytest.warns(UserWarning, match="No treatment nodes provided"):
        mmm = MMM(**model_kwargs)

    expected_nodes = ["channel_1", "channel_2"]
    assert mmm.treatment_nodes == expected_nodes
2424
+
2425
def test_causal_graph_excludes_seasonality(
    self, toy_X: pd.DataFrame, toy_y: pd.Series
):
    """Requested yearly seasonality is dropped when a DAG is supplied.

    The constructor is expected to warn with "Yearly seasonality excluded"
    and ``yearly_seasonality`` is expected to be ``None`` afterwards, even
    though ``2`` was requested.
    """
    # NOTE: toy_X / toy_y are requested as fixtures but not used in the body.
    causal_dag = """
    digraph {
        channel_1 -> y;
        channel_2 -> y;
        control_1 -> y;
    }
    """
    model_kwargs = dict(
        date_column="date",
        channel_columns=["channel_1", "channel_2"],
        control_columns=["control_1"],
        adstock=GeometricAdstock(l_max=4),
        saturation=LogisticSaturation(),
        yearly_seasonality=2,
        dag=causal_dag,
        treatment_nodes=["channel_1", "channel_2"],
        outcome_node="y",
    )

    # The warning is raised during construction, so build inside the context.
    with pytest.warns(UserWarning, match="Yearly seasonality excluded"):
        mmm = MMM(**model_kwargs)

    assert mmm.yearly_seasonality is None
2451
+
2452
def test_x_data_not_dataframe_error(self, toy_y: pd.Series):
    """``_compute_scales`` raises TypeError when preprocessed X is an ndarray.

    ``preprocessed_data`` is set by hand with a NumPy array under ``"X"``;
    the call is expected to fail with a message matching
    "must be a DataFrame".
    """
    model_kwargs = dict(
        date_column="date",
        channel_columns=["channel_1", "channel_2"],
        adstock=GeometricAdstock(l_max=4),
        saturation=LogisticSaturation(),
    )
    mmm = MMM(**model_kwargs)

    not_a_frame = np.array([[1, 2], [3, 4]])
    mmm.preprocessed_data = {"X": not_a_frame, "y": toy_y}

    with pytest.raises(TypeError, match="must be a DataFrame"):
        mmm._compute_scales()
2465
+
2466
def test_invalid_synth_dataset_granularity(
    self, mmm_fitted: MMM, toy_X: pd.DataFrame
):
    """``_create_synth_dataset`` rejects an unknown ``time_granularity``.

    The call is expected to raise ValueError with a message matching
    "Unsupported time granularity".
    """
    from xarray import DataArray

    channels = mmm_fitted.channel_columns
    allocation = DataArray(
        [1000, 2000],
        dims=["channel"],
        coords={"channel": channels},
    )
    synth_kwargs = dict(
        df=toy_X,
        date_column="date",
        allocation_strategy=allocation,
        channels=channels,
        controls=None,
        target_col="y",
        time_granularity="invalid_granularity",
        time_length=10,
        lag=4,
    )

    with pytest.raises(ValueError, match="Unsupported time granularity"):
        mmm_fitted._create_synth_dataset(**synth_kwargs)
2490
+
2491
def test_invalid_allocation_dims(self, mmm_fitted: MMM, toy_X: pd.DataFrame):
    """``_create_synth_dataset`` rejects an allocation without a 'channel' dim.

    The allocation's only dimension is named ``"wrong_dim"``; the call is
    expected to raise ValueError with a message matching
    "must have a single dimension named 'channel'".
    """
    from xarray import DataArray

    bad_allocation = DataArray(
        [1000, 2000],
        dims=["wrong_dim"],
        coords={"wrong_dim": ["a", "b"]},
    )
    expected_message = "must have a single dimension named 'channel'"

    with pytest.raises(ValueError, match=expected_message):
        synth_kwargs = dict(
            df=toy_X,
            date_column="date",
            allocation_strategy=bad_allocation,
            channels=mmm_fitted.channel_columns,
            controls=None,
            target_col="y",
            time_granularity="weekly",
            time_length=10,
            lag=4,
        )
        mmm_fitted._create_synth_dataset(**synth_kwargs)
2515
+
2230
2516
2231
2517
class TestMMMHelperMethods :
2232
2518
"""Tests for internal helper methods added during refactoring."""
@@ -2630,3 +2916,13 @@ def test_add_original_scale_deterministics(
2630
2916
assert "control_contribution_original_scale" in deterministic_names
2631
2917
assert "yearly_seasonality_contribution_original_scale" in deterministic_names
2632
2918
assert "y_original_scale" in deterministic_names
2919
+
2920
def test_prepare_target_data_with_none_and_no_preprocessed(self, mmm_fitted: MMM):
    """Test _prepare_target_data when y is None but preprocessed_data doesn't match expected types.

    ``preprocessed_data["y"]`` is temporarily replaced with a plain string
    and ``_prepare_target_data(y=None, n_rows=10)`` is called. The result
    must contain a ``"target_data"`` or ``"target"`` key; when
    ``"target_data"`` is present it must have length 10.

    The original ``"y"`` entry is put back in ``finally``.
    NOTE(review): ``mmm_fitted`` may be a fixture shared across tests
    (its scope is not visible here); without the restore, later tests
    would see the string in place of the real target data.
    """
    preprocessed = mmm_fitted.preprocessed_data

    # Sentinel distinguishes "key absent" from "key present with value None".
    missing = object()
    previous_y = preprocessed.get("y", missing)

    preprocessed["y"] = "not_a_series_or_array"  # type: ignore
    try:
        result = mmm_fitted._prepare_target_data(y=None, n_rows=10)
    finally:
        if previous_y is missing:
            del preprocessed["y"]
        else:
            preprocessed["y"] = previous_y

    assert "target_data" in result or "target" in result
    if "target_data" in result:
        assert len(result["target_data"]) == 10
0 commit comments