@@ -2382,3 +2382,151 @@ def test_specify_time_varying_configuration(
2382
2382
mmm .model [expected_rv ["name" ]].owner .op .__class__ .__name__
2383
2383
== expected_rv ["kind" ]
2384
2384
)
2385
+
2386
+
2387
+ def test_multidimensional_mmm_serializes_and_deserializes_dag_and_nodes (
2388
+ single_dim_data , mock_pymc_sample
2389
+ ):
2390
+ dag = """
2391
+ digraph {
2392
+ channel_1 -> y;
2393
+ control_1 -> channel_1;
2394
+ control_1 -> y;
2395
+ }
2396
+ """
2397
+ treatment_nodes = ["channel_1" ]
2398
+ outcome_node = "y"
2399
+
2400
+ X , y = single_dim_data
2401
+ y = y .rename ("y" )
2402
+
2403
+ mmm = MMM (
2404
+ date_column = "date" ,
2405
+ target_column = "y" ,
2406
+ channel_columns = ["channel_1" , "channel_2" ],
2407
+ adstock = GeometricAdstock (l_max = 2 ),
2408
+ saturation = LogisticSaturation (),
2409
+ dag = dag ,
2410
+ treatment_nodes = treatment_nodes ,
2411
+ outcome_node = outcome_node ,
2412
+ )
2413
+
2414
+ mmm .fit (X = X , y = y )
2415
+
2416
+ mmm .save ("test_model_multi" )
2417
+ loaded_mmm = MMM .load ("test_model_multi" )
2418
+
2419
+ assert loaded_mmm .dag == dag
2420
+ assert loaded_mmm .treatment_nodes == treatment_nodes
2421
+ assert loaded_mmm .outcome_node == outcome_node
2422
+
2423
+
2424
+ def test_multidimensional_mmm_causal_attributes_initialization ():
2425
+ dag = """
2426
+ digraph {
2427
+ channel_1 -> target;
2428
+ control_1 -> channel_1;
2429
+ control_1 -> target;
2430
+ }
2431
+ """
2432
+ treatment_nodes = ["channel_1" ]
2433
+ outcome_node = "target"
2434
+
2435
+ mmm = MMM (
2436
+ date_column = "date" ,
2437
+ target_column = "target" ,
2438
+ channel_columns = ["channel_1" , "channel_2" ],
2439
+ control_columns = ["control_1" , "control_2" ],
2440
+ adstock = GeometricAdstock (l_max = 2 ),
2441
+ saturation = LogisticSaturation (),
2442
+ dag = dag ,
2443
+ treatment_nodes = treatment_nodes ,
2444
+ outcome_node = outcome_node ,
2445
+ )
2446
+
2447
+ assert mmm .dag == dag
2448
+ assert mmm .treatment_nodes == treatment_nodes
2449
+ assert mmm .outcome_node == outcome_node
2450
+
2451
+
2452
+ def test_multidimensional_mmm_causal_attributes_default_treatment_nodes ():
2453
+ dag = """
2454
+ digraph {
2455
+ channel_1 -> target;
2456
+ channel_2 -> target;
2457
+ control_1 -> channel_1;
2458
+ control_1 -> target;
2459
+ }
2460
+ """
2461
+ outcome_node = "target"
2462
+
2463
+ with pytest .warns (
2464
+ UserWarning , match = "No treatment nodes provided, using channel columns"
2465
+ ):
2466
+ mmm = MMM (
2467
+ date_column = "date" ,
2468
+ target_column = "target" ,
2469
+ channel_columns = ["channel_1" , "channel_2" ],
2470
+ control_columns = ["control_1" , "control_2" ],
2471
+ adstock = GeometricAdstock (l_max = 2 ),
2472
+ saturation = LogisticSaturation (),
2473
+ dag = dag ,
2474
+ outcome_node = outcome_node ,
2475
+ )
2476
+
2477
+ assert mmm .treatment_nodes == ["channel_1" , "channel_2" ]
2478
+ assert mmm .outcome_node == "target"
2479
+
2480
+
2481
+ def test_multidimensional_mmm_adjustment_set_updates_control_columns ():
2482
+ dag = """
2483
+ digraph {
2484
+ channel_1 -> target;
2485
+ control_1 -> channel_1;
2486
+ control_1 -> target;
2487
+ }
2488
+ """
2489
+ treatment_nodes = ["channel_1" ]
2490
+ outcome_node = "target"
2491
+
2492
+ mmm = MMM (
2493
+ date_column = "date" ,
2494
+ target_column = "target" ,
2495
+ channel_columns = ["channel_1" , "channel_2" ],
2496
+ control_columns = ["control_1" , "control_2" ],
2497
+ adstock = GeometricAdstock (l_max = 2 ),
2498
+ saturation = LogisticSaturation (),
2499
+ dag = dag ,
2500
+ treatment_nodes = treatment_nodes ,
2501
+ outcome_node = outcome_node ,
2502
+ )
2503
+
2504
+ assert mmm .control_columns == ["control_1" ]
2505
+
2506
+
2507
+ def test_multidimensional_mmm_missing_dag_does_not_initialize_causal_graph ():
2508
+ mmm = MMM (
2509
+ date_column = "date" ,
2510
+ target_column = "target" ,
2511
+ channel_columns = ["channel_1" , "channel_2" ],
2512
+ adstock = GeometricAdstock (l_max = 2 ),
2513
+ saturation = LogisticSaturation (),
2514
+ )
2515
+
2516
+ assert mmm .dag is None
2517
+ assert not hasattr (mmm , "causal_graphical_model" )
2518
+
2519
+
2520
+ def test_multidimensional_mmm_only_dag_provided_does_not_initialize_graph ():
2521
+ mmm = MMM (
2522
+ date_column = "date" ,
2523
+ target_column = "target" ,
2524
+ channel_columns = ["channel_1" , "channel_2" ],
2525
+ adstock = GeometricAdstock (l_max = 2 ),
2526
+ saturation = LogisticSaturation (),
2527
+ dag = "digraph {channel_1 -> target;}" ,
2528
+ )
2529
+
2530
+ assert mmm .treatment_nodes is None
2531
+ assert mmm .outcome_node is None
2532
+ assert not hasattr (mmm , "causal_graphical_model" )
0 commit comments