Skip to content

Commit 488ac3c

Browse files
drbenvincent, cetagostini, williambdean, juanitorduz
authored
Sensitivity analysis and marginal effects (#1673)
* initial stab at CounterfactualSweep class + associated example notebook * attempt to add the new notebook to the examples gallery * delete commented code * fix example in docs and re-run notebook with some hidden inputs/outputs * add some TODOs to the notebook * Update pymc_marketing/mmm/marginal_effects.py Co-authored-by: Will Dean <[email protected]> * improve type hinting * update docstring of plot_marginal_effects method * Use Literal in type hint * change to use pymc_marketing.mmm.multidimensional.MMM * scaling of the marginal effects plot to not put undue emphasis on numerical imprecision * Results now returned as self-contained xr.Dataset. Plot methods are now static methods * X no longer required as an input to CounterfactualSweep * remove redundant sweep_values index * rename to SensitivityAnalysis * compute gradient with xarray instead of numpy * add MMM.sensitivity_analysis as wrapper to call SensitivityAnalysis * formatting * rename notebook * remove commented code in notebook * fix scaling + add crosshairs on plots * combine into a single plot function * api change, results now stored in idata, and fix crosshairs * minor tweaks * better sweep values for additive sweep example * move plot_sensitivity_analysis into MMMPlotSuite * rename example in the gallery view. Docs updated * add functionality to plot y-axis in percentage terms * add a check for presence of idata.sensitivity_analysis * update API according to Carlos' suggestions * predictors -> var_names * Add tests for SensitivityAnalysis class * Add tests for plot_sensitivity_analysis in sensitivity analysis * more tests to increase code coverage of plot code * update notebook --------- Co-authored-by: Carlos Trujillo <[email protected]> Co-authored-by: Will Dean <[email protected]> Co-authored-by: Juan Orduz <[email protected]>
1 parent bf00dfd commit 488ac3c

File tree

7 files changed

+10389
-0
lines changed

7 files changed

+10389
-0
lines changed

docs/source/gallery/gallery.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ Welcome to the PyMC-Marketing example gallery! This gallery provides visual navi
3030
:img-top: ../gallery/images/mmm_multidimensional_example.png
3131
:link: ../notebooks/mmm/mmm_multidimensional_example.html
3232
:::
33+
34+
:::{grid-item-card} Sensitivity Analysis and Marginal Effects
35+
:img-top: ../gallery/images/mmm_sensitivity_analysis.png
36+
:link: ../notebooks/mmm/mmm_sensitivity_analysis.html
37+
:::
3338
::::
3439

3540
### Budget Allocation
@@ -110,6 +115,7 @@ Welcome to the PyMC-Marketing example gallery! This gallery provides visual navi
110115
:img-top: ../gallery/images/mmm_counterfactuals.png
111116
:link: ../notebooks/mmm/mmm_counterfactuals.html
112117
:::
118+
113119
::::
114120

115121
### Case Studies

docs/source/notebooks/mmm/mmm_sensitivity_analysis.ipynb

Lines changed: 9449 additions & 0 deletions
Large diffs are not rendered by default.

pymc_marketing/mmm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
preprocessing_method_X,
6363
preprocessing_method_y,
6464
)
65+
from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis
6566
from pymc_marketing.mmm.validating import validation_method_X, validation_method_y
6667

6768
__all__ = [
@@ -90,6 +91,7 @@
9091
"PeriodicCovFunc",
9192
"RootSaturation",
9293
"SaturationTransformation",
94+
"SensitivityAnalysis",
9395
"SoftPlusHSGP",
9496
"TanhSaturation",
9597
"TanhSaturationBaselined",

pymc_marketing/mmm/multidimensional.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
)
5454
from pymc_marketing.mmm.plot import MMMPlotSuite
5555
from pymc_marketing.mmm.scaling import Scaling, VariableScaling
56+
from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis
5657
from pymc_marketing.mmm.tvp import infer_time_index
5758
from pymc_marketing.mmm.utility import UtilityFunctionType, average_response
5859
from pymc_marketing.mmm.utils import (
@@ -1433,6 +1434,28 @@ def sample_posterior_predictive(
14331434

14341435
return posterior_predictive_samples
14351436

1437+
@property
def sensitivity(self) -> SensitivityAnalysis:
    """Expose the sensitivity-analysis helper bound to this model.

    Every access constructs a new ``SensitivityAnalysis`` object that holds
    a reference to this MMM instance.

    Returns
    -------
    SensitivityAnalysis
        A helper created with ``mmm=self``.

    Examples
    --------
    >>> mmm.sensitivity.run_sweep(
    ...     var_names=["channel_1", "channel_2"],
    ...     sweep_values=np.linspace(0.5, 2.0, 10),
    ...     sweep_type="multiplicative",
    ... )
    """
    analysis = SensitivityAnalysis(mmm=self)
    return analysis
1458+
14361459
def _make_channel_transform(
14371460
self, df_lift_test: pd.DataFrame
14381461
) -> Callable[[np.ndarray], np.ndarray]:

pymc_marketing/mmm/plot.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,3 +1243,115 @@ def allocated_contribution_by_channel_over_time(
12431243

12441244
fig.tight_layout()
12451245
return fig, axes
1246+
1247+
def plot_sensitivity_analysis(
    self,
    hdi_prob: float = 0.94,
    ax: plt.Axes | None = None,
    marginal: bool = False,
    percentage: bool = False,
) -> plt.Axes:
    """Plot the counterfactual uplift or marginal effects curve.

    The sweep results are read from the ``sensitivity_analysis`` group of
    ``self.idata``; they are not passed in as an argument.

    Parameters
    ----------
    hdi_prob : float, optional
        The probability for computing the highest density interval (HDI).
        Default is 0.94.
    ax : plt.Axes | None, optional
        An optional matplotlib Axes on which to plot. If None, a new Axes is
        created.
    marginal : bool, optional
        If True, plot marginal effects. If False (default), plot uplift.
    percentage : bool, optional
        If True, plot the uplift on the y-axis as a fraction of the summed
        ``posterior_predictive["y"]`` (formatted as percentages) instead of
        absolute values. Default is False.

    Returns
    -------
    plt.Axes
        The Axes object with the plot.

    Raises
    ------
    ValueError
        If both ``percentage`` and ``marginal`` are True, or if ``self.idata``
        has no ``sensitivity_analysis`` group.
    """
    # Validate everything before touching matplotlib so that a failing call
    # does not leave an empty figure behind.
    if percentage and marginal:
        raise ValueError("Not implemented marginal effects in percentage scale.")

    if not hasattr(self.idata, "sensitivity_analysis"):
        raise ValueError(
            "No sensitivity analysis results found in 'self.idata'. "
            "Please run the sensitivity analysis first using 'mmm.sensitivity.run_sweep()' method."
        )

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    results = self.idata.sensitivity_analysis
    sweep = results.sweep.values
    sample_dims = ["chain", "draw"]

    if marginal:
        y = results.marginal_effects.mean(dim=sample_dims).sum(dim="date")
        y_hdi = results.marginal_effects.sum(dim="date")
        color = "C1"
        label = "Posterior mean marginal effect"
        title = "Marginal effects plot"
        ylabel = r"Marginal effect, $\frac{d\mathbb{E}[Y]}{dX}$"
    else:
        y = results.y.mean(dim=sample_dims).sum(dim="date")
        y_hdi = results.y.sum(dim="date")
        if percentage:
            # Express the uplift relative to the summed posterior predictive.
            actual = self.idata.posterior_predictive["y"]
            y = y / actual.mean(dim=sample_dims).sum(dim="date")
            y_hdi = y_hdi / actual.sum(dim="date")
        color = "C0"
        label = "Posterior mean"
        title = "Sensitivity analysis plot"
        ylabel = "Total uplift (sum over dates)"

    ax.plot(sweep, y, label=label, color=color)

    az.plot_hdi(
        sweep,
        y_hdi,
        hdi_prob=hdi_prob,
        color=color,
        fill_kwargs={"alpha": 0.5, "label": f"{hdi_prob * 100:.0f}% HDI"},
        plot_kwargs={"color": color, "alpha": 0.5},
        smooth=False,
        ax=ax,
    )

    ax.set(title=title)
    if results.sweep_type == "absolute":
        ax.set_xlabel(f"Absolute value of: {results.var_names}")
    else:
        ax.set_xlabel(
            f"{results.sweep_type.capitalize()} change of: {results.var_names}"
        )
    ax.set_ylabel(ylabel)
    # Use the Axes method: ``plt.legend()`` would act on pyplot's *current*
    # axes, which is not necessarily the ``ax`` supplied by the caller.
    ax.legend()

    # Anchor the y-axis at zero when the mean curve lies entirely on one side.
    y_values = np.asarray(y)
    if np.all(y_values < 0):
        ax.set_ylim(top=0)
    elif np.all(y_values > 0):
        ax.set_ylim(bottom=0)

    tick_format = "{:.1%}" if percentage else "{:,.1f}"
    ax.yaxis.set_major_formatter(
        plt.FuncFormatter(lambda value, _: tick_format.format(value))
    )

    # Reference lines marking the "no intervention" point of the sweep.
    if results.sweep_type == "multiplicative":
        ax.axvline(x=1, color="k", linestyle="--", alpha=0.5)
        if not marginal:
            ax.axhline(y=0, color="k", linestyle="--", alpha=0.5)
    elif results.sweep_type == "additive":
        ax.axvline(x=0, color="k", linestyle="--", alpha=0.5)

    return ax
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright 2022 - 2025 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Counterfactual sweeps for Marketing Mix Models (MMM)."""
16+
17+
from typing import Literal
18+
19+
import numpy as np
20+
import pandas as pd
21+
import xarray as xr
22+
23+
24+
class SensitivityAnalysis:
    """Run counterfactual sweeps over the predictors of an MMM.

    For each sweep value the selected predictor columns of ``mmm.X`` are
    modified, the model's ``predict`` method is called, and the difference to
    ``mmm.idata["posterior_predictive"]["y"]`` is stored as the uplift.
    """

    _SWEEP_TYPES = ("multiplicative", "additive", "absolute")

    def __init__(self, mmm) -> None:
        """
        Initialize the SensitivityAnalysis with a reference to the MMM instance.

        Parameters
        ----------
        mmm : MMM
            The marketing mix model instance used for predictions.
        """
        self.mmm = mmm

    def run_sweep(
        self,
        var_names: list[str],
        sweep_values: np.ndarray,
        sweep_type: Literal[
            "multiplicative", "additive", "absolute"
        ] = "multiplicative",
    ) -> xr.Dataset:
        """Run the model's predict function over the sweep grid and store results.

        Parameters
        ----------
        var_names : list[str]
            List of variable names to intervene on.
        sweep_values : np.ndarray
            One-dimensional array of sweep values. At least two values are
            needed because marginal effects are finite differences along
            the sweep.
        sweep_type : Literal["multiplicative", "additive", "absolute"], optional
            Type of intervention to apply, by default "multiplicative".
            - 'multiplicative': Multiply the original predictor values by each sweep value.
            - 'additive': Add each sweep value to the original predictor values.
            - 'absolute': Set the predictor values directly to each sweep value (ignoring original values).

        Returns
        -------
        xr.Dataset
            Dataset with the variables ``y`` (uplift) and
            ``marginal_effects`` and the attributes ``sweep_type`` and
            ``var_names``. The same dataset is stored in the
            ``sensitivity_analysis`` group of ``mmm.idata``, replacing any
            previous result.

        Raises
        ------
        ValueError
            If the model has no ``idata``, if ``sweep_values`` is not a 1-D
            array with at least two entries, or if ``sweep_type`` is not
            supported.
        KeyError
            If a name in ``var_names`` is not a column of ``mmm.X``.
        """
        # ``idata`` may be missing entirely or present but still ``None``;
        # both mean there is nothing to compare the counterfactuals against.
        idata = getattr(self.mmm, "idata", None)
        if idata is None:
            raise ValueError("idata does not exist. Build the model first and fit.")

        # Fail fast: the finite-difference step needs at least two points, and
        # discovering that only after all predictions have run is wasteful.
        sweep_values = np.asarray(sweep_values)
        if sweep_values.ndim != 1 or sweep_values.size < 2:
            raise ValueError(
                "sweep_values must be a one-dimensional array with at least two values."
            )

        # Store parameters for this run (read by ``create_intervention``).
        self.var_names = var_names
        self.sweep_values = sweep_values
        self.sweep_type = sweep_type

        # TODO: ideally use self.mmm._get_group_predictive_data(
        #     group="posterior_predictive", original_scale=True)["y"] here.
        actual = idata["posterior_predictive"]["y"]

        uplifts = [
            self.mmm.predict(
                self.create_intervention(sweep_value),
                extend_idata=False,
                progressbar=False,
            )
            - actual
            for sweep_value in self.sweep_values
        ]

        uplift = (
            xr.concat(uplifts, dim="sweep")
            .assign_coords(sweep=self.sweep_values)
            .transpose(..., "sweep")
        )
        marginal_effects = self.compute_marginal_effects(uplift, self.sweep_values)

        results = xr.Dataset({"y": uplift, "marginal_effects": marginal_effects})
        # Metadata consumed when labelling plots of these results.
        results.attrs["sweep_type"] = self.sweep_type
        results.attrs["var_names"] = self.var_names

        # Replace any previous result stored on the model's idata.
        if hasattr(idata, "sensitivity_analysis"):
            delattr(idata, "sensitivity_analysis")
        idata.add_groups({"sensitivity_analysis": results})  # type: ignore

        return results

    def create_intervention(self, sweep_value: float) -> pd.DataFrame:
        """Return a copy of ``mmm.X`` with the intervention applied.

        Uses ``self.var_names`` and ``self.sweep_type`` as set by
        ``run_sweep``. The model's own ``X`` is left unmodified.

        Parameters
        ----------
        sweep_value : float
            The factor, offset, or absolute value to apply, depending on
            ``self.sweep_type``.

        Returns
        -------
        pd.DataFrame
            The modified copy of the predictors.

        Raises
        ------
        ValueError
            If ``self.sweep_type`` is not supported.
        KeyError
            If any of ``self.var_names`` is not a column of ``mmm.X``.
        """
        if self.sweep_type not in self._SWEEP_TYPES:
            raise ValueError(f"Unsupported sweep_type: {self.sweep_type}")

        X_new = self.mmm.X.copy()

        # Plain assignment would silently create a new column for an unknown
        # name, so an "absolute" sweep with a typo would appear to succeed.
        missing = [name for name in self.var_names if name not in X_new.columns]
        if missing:
            raise KeyError(f"Variables not found in the model's predictors: {missing}")

        for var_name in self.var_names:
            if self.sweep_type == "multiplicative":
                X_new[var_name] = X_new[var_name] * sweep_value
            elif self.sweep_type == "additive":
                X_new[var_name] = X_new[var_name] + sweep_value
            else:  # "absolute"
                X_new[var_name] = sweep_value
        return X_new

    @staticmethod
    def compute_marginal_effects(results, sweep_values) -> xr.DataArray:
        """Compute marginal effects via finite differences from the sweep results.

        Parameters
        ----------
        results : xr.DataArray
            Sweep results with a ``sweep`` coordinate.
        sweep_values : np.ndarray
            Unused; the ``sweep`` coordinate of ``results`` is differentiated
            directly. Kept so existing callers keep working.

        Returns
        -------
        xr.DataArray
            The derivative of ``results`` along the ``sweep`` coordinate.
        """
        return results.differentiate(coord="sweep")

0 commit comments

Comments
 (0)