Skip to content

Commit be685bd

Browse files
authored
Add causal graph support to MMM with DAG serialization (#1882)
Introduces optional causal graph configuration to the MMM class, allowing users to specify a DAG, treatment nodes, and outcome node for causal identification. The model now serializes and deserializes these attributes, updates control columns based on computed adjustment sets, and includes tests for the new causal graph functionality.
1 parent 5924a65 commit be685bd

File tree

2 files changed

+204
-0
lines changed

2 files changed

+204
-0
lines changed

pymc_marketing/mmm/multidimensional.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from pymc_marketing.mmm import SoftPlusHSGP
3838
from pymc_marketing.mmm.additive_effect import EventAdditiveEffect, MuEffect
3939
from pymc_marketing.mmm.budget_optimizer import OptimizerCompatibleModelWrapper
40+
from pymc_marketing.mmm.causal import CausalGraphModel
4041
from pymc_marketing.mmm.components.adstock import (
4142
AdstockTransformation,
4243
adstock_from_dict,
@@ -176,6 +177,17 @@ def __init__(
176177
bool,
177178
Field(strict=True, description="Apply adstock before saturation?"),
178179
] = True,
180+
dag: str | None = Field(
181+
None,
182+
description="Optional DAG provided as a string Dot format for causal identification.",
183+
),
184+
treatment_nodes: list[str] | tuple[str] | None = Field(
185+
None,
186+
description="Column names of the variables of interest to identify causal effects on outcome.",
187+
),
188+
outcome_node: str | None = Field(
189+
None, description="Name of the outcome variable."
190+
),
179191
) -> None:
180192
"""Define the constructor method."""
181193
# Your existing initialization logic
@@ -234,6 +246,44 @@ def __init__(
234246
self.channel_columns = channel_columns
235247
self.yearly_seasonality = yearly_seasonality
236248

249+
# Causal graph configuration
250+
self.dag = dag
251+
self.treatment_nodes = treatment_nodes
252+
self.outcome_node = outcome_node
253+
254+
# Initialize causal graph if provided
255+
if self.dag is not None and self.outcome_node is not None:
256+
if self.treatment_nodes is None:
257+
self.treatment_nodes = self.channel_columns
258+
warnings.warn(
259+
"No treatment nodes provided, using channel columns as treatment nodes.",
260+
stacklevel=2,
261+
)
262+
self.causal_graphical_model = CausalGraphModel.build_graphical_model(
263+
graph=self.dag,
264+
treatment=self.treatment_nodes,
265+
outcome=self.outcome_node,
266+
)
267+
268+
self.control_columns = self.causal_graphical_model.compute_adjustment_sets(
269+
control_columns=self.control_columns,
270+
channel_columns=self.channel_columns,
271+
)
272+
273+
# Only apply yearly seasonality adjustment if an adjustment set was computed
274+
if hasattr(self.causal_graphical_model, "adjustment_set") and (
275+
self.causal_graphical_model.adjustment_set is not None
276+
):
277+
if (
278+
"yearly_seasonality"
279+
not in self.causal_graphical_model.adjustment_set
280+
):
281+
warnings.warn(
282+
"Yearly seasonality excluded as it's not required for adjustment.",
283+
stacklevel=2,
284+
)
285+
self.yearly_seasonality = None
286+
237287
super().__init__(model_config=model_config, sampler_config=sampler_config)
238288

239289
if self.yearly_seasonality is not None:
@@ -335,6 +385,9 @@ def create_idata_attrs(self) -> dict[str, str]:
335385
attrs["time_varying_media"] = json.dumps(self.time_varying_media)
336386
attrs["target_column"] = self.target_column
337387
attrs["scaling"] = json.dumps(self.scaling.model_dump(mode="json"))
388+
attrs["dag"] = json.dumps(getattr(self, "dag", None))
389+
attrs["treatment_nodes"] = json.dumps(getattr(self, "treatment_nodes", None))
390+
attrs["outcome_node"] = json.dumps(getattr(self, "outcome_node", None))
338391

339392
return attrs
340393

@@ -395,6 +448,9 @@ def attrs_to_init_kwargs(cls, attrs: dict[str, str]) -> dict[str, Any]:
395448
"sampler_config": json.loads(attrs["sampler_config"]),
396449
"dims": tuple(json.loads(attrs.get("dims", "[]"))),
397450
"scaling": json.loads(attrs.get("scaling", "null")),
451+
"dag": json.loads(attrs.get("dag", "null")),
452+
"treatment_nodes": json.loads(attrs.get("treatment_nodes", "null")),
453+
"outcome_node": json.loads(attrs.get("outcome_node", "null")),
398454
}
399455

400456
@property

tests/mmm/test_multidimensional.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2382,3 +2382,151 @@ def test_specify_time_varying_configuration(
23822382
mmm.model[expected_rv["name"]].owner.op.__class__.__name__
23832383
== expected_rv["kind"]
23842384
)
2385+
2386+
2387+
def test_multidimensional_mmm_serializes_and_deserializes_dag_and_nodes(
    single_dim_data, mock_pymc_sample, tmp_path
):
    """Check that the causal-graph attributes survive a save/load round trip.

    The model is built with an explicit DAG, treatment nodes and outcome
    node, fitted, saved, and loaded again; the loaded instance must expose
    the same ``dag``, ``treatment_nodes`` and ``outcome_node`` values.

    The artifact is written under pytest's ``tmp_path`` so the test does not
    leave a ``test_model_multi`` file behind in the working directory.
    """
    dag = """
    digraph {
        channel_1 -> y;
        control_1 -> channel_1;
        control_1 -> y;
    }
    """
    treatment_nodes = ["channel_1"]
    outcome_node = "y"

    X, y = single_dim_data
    # The DAG names the outcome "y", so the target series is renamed to match.
    y = y.rename("y")

    mmm = MMM(
        date_column="date",
        target_column="y",
        channel_columns=["channel_1", "channel_2"],
        adstock=GeometricAdstock(l_max=2),
        saturation=LogisticSaturation(),
        dag=dag,
        treatment_nodes=treatment_nodes,
        outcome_node=outcome_node,
    )

    mmm.fit(X=X, y=y)

    # Per-test temporary location: nothing to clean up, no cross-test clashes.
    model_path = str(tmp_path / "test_model_multi")
    mmm.save(model_path)
    loaded_mmm = MMM.load(model_path)

    assert loaded_mmm.dag == dag
    assert loaded_mmm.treatment_nodes == treatment_nodes
    assert loaded_mmm.outcome_node == outcome_node
2422+
2423+
2424+
def test_multidimensional_mmm_causal_attributes_initialization():
    """Constructor arguments for the causal graph are stored on the instance."""
    causal_dag = """
    digraph {
        channel_1 -> target;
        control_1 -> channel_1;
        control_1 -> target;
    }
    """
    treatments = ["channel_1"]
    outcome = "target"

    init_kwargs = {
        "date_column": "date",
        "target_column": "target",
        "channel_columns": ["channel_1", "channel_2"],
        "control_columns": ["control_1", "control_2"],
        "adstock": GeometricAdstock(l_max=2),
        "saturation": LogisticSaturation(),
        "dag": causal_dag,
        "treatment_nodes": treatments,
        "outcome_node": outcome,
    }
    model = MMM(**init_kwargs)

    # Each attribute must echo exactly what was passed in.
    assert model.dag == causal_dag
    assert model.treatment_nodes == treatments
    assert model.outcome_node == outcome
2450+
2451+
2452+
def test_multidimensional_mmm_causal_attributes_default_treatment_nodes():
    """Omitting ``treatment_nodes`` warns and falls back to the channel columns."""
    causal_dag = """
    digraph {
        channel_1 -> target;
        channel_2 -> target;
        control_1 -> channel_1;
        control_1 -> target;
    }
    """
    outcome = "target"

    init_kwargs = {
        "date_column": "date",
        "target_column": "target",
        "channel_columns": ["channel_1", "channel_2"],
        "control_columns": ["control_1", "control_2"],
        "adstock": GeometricAdstock(l_max=2),
        "saturation": LogisticSaturation(),
        "dag": causal_dag,
        "outcome_node": outcome,
    }

    # No treatment_nodes given: construction must emit the fallback warning.
    with pytest.warns(
        UserWarning, match="No treatment nodes provided, using channel columns"
    ):
        model = MMM(**init_kwargs)

    assert model.treatment_nodes == ["channel_1", "channel_2"]
    assert model.outcome_node == "target"
2479+
2480+
2481+
def test_multidimensional_mmm_adjustment_set_updates_control_columns():
    """Control columns are replaced by the result of the adjustment-set step.

    Two controls are passed in, but only ``control_1`` appears in the DAG;
    after construction the model is expected to keep ``control_1`` only.
    """
    causal_dag = """
    digraph {
        channel_1 -> target;
        control_1 -> channel_1;
        control_1 -> target;
    }
    """
    treatments = ["channel_1"]
    outcome = "target"

    init_kwargs = {
        "date_column": "date",
        "target_column": "target",
        "channel_columns": ["channel_1", "channel_2"],
        "control_columns": ["control_1", "control_2"],
        "adstock": GeometricAdstock(l_max=2),
        "saturation": LogisticSaturation(),
        "dag": causal_dag,
        "treatment_nodes": treatments,
        "outcome_node": outcome,
    }
    model = MMM(**init_kwargs)

    assert model.control_columns == ["control_1"]
2505+
2506+
2507+
def test_multidimensional_mmm_missing_dag_does_not_initialize_causal_graph():
    """Without a DAG the model carries no ``causal_graphical_model`` attribute."""
    init_kwargs = {
        "date_column": "date",
        "target_column": "target",
        "channel_columns": ["channel_1", "channel_2"],
        "adstock": GeometricAdstock(l_max=2),
        "saturation": LogisticSaturation(),
    }
    model = MMM(**init_kwargs)

    assert model.dag is None
    # The causal graph is only built when a DAG is supplied.
    assert not hasattr(model, "causal_graphical_model")
2518+
2519+
2520+
def test_multidimensional_mmm_only_dag_provided_does_not_initialize_graph():
    """A DAG without an outcome node leaves the causal graph unbuilt."""
    init_kwargs = {
        "date_column": "date",
        "target_column": "target",
        "channel_columns": ["channel_1", "channel_2"],
        "adstock": GeometricAdstock(l_max=2),
        "saturation": LogisticSaturation(),
        "dag": "digraph {channel_1 -> target;}",
    }
    model = MMM(**init_kwargs)

    # Neither node argument was given, so both stay at their None default.
    assert model.treatment_nodes is None
    assert model.outcome_node is None
    assert not hasattr(model, "causal_graphical_model")

0 commit comments

Comments
 (0)