Skip to content

Commit 26d1d2a

Browse files
PabloRoque, juanitorduz, cetagostini
authored
Multidimensional support in plot_sensitivity_analysis (#1886)
* Multidimensional support in plot_sensitivity_analysis
* Add tests for multidim sensitivity analysis plots
* Minor string change

---------

Co-authored-by: Juan Orduz <[email protected]>
Co-authored-by: Carlos Trujillo <[email protected]>
1 parent b3865fa commit 26d1d2a

File tree

3 files changed

+283
-84
lines changed

3 files changed

+283
-84
lines changed

pymc_marketing/mmm/plot.py

Lines changed: 182 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,108 +1252,207 @@ def plot_sensitivity_analysis(
12521252
ax: plt.Axes | None = None,
12531253
marginal: bool = False,
12541254
percentage: bool = False,
1255-
) -> plt.Axes:
1255+
sharey: bool = True,
1256+
) -> tuple[Figure, NDArray[Axes]] | plt.Axes:
12561257
"""
1257-
Plot the counterfactual uplift or marginal effects curve.
1258+
Plot counterfactual uplift or marginal effects curves.
1259+
1260+
Handles additional (non sweep/date/chain/draw) dimensions by creating one subplot
1261+
per combination of those dimensions - consistent with other plot_* methods.
12581262
12591263
Parameters
12601264
----------
1261-
results : xr.Dataset
1262-
The dataset containing the results of the sweep.
1263-
hdi_prob : float, optional
1264-
The probability for computing the highest density interval (HDI). Default is 0.94.
1265-
ax : Optional[plt.Axes], optional
1266-
An optional matplotlib Axes on which to plot. If None, a new Axes is created.
1267-
marginal : bool, optional
1268-
If True, plot marginal effects. If False (default), plot uplift.
1269-
percentage : bool, optional
1270-
If True, plot the results on the y-axis as percentages, instead of absolute
1271-
values. Default is False.
1265+
hdi_prob : float, default 0.94
1266+
HDI probability mass.
1267+
ax : plt.Axes, optional
1268+
Only used when there are no extra dimensions (single panel case).
1269+
marginal : bool, default False
1270+
Plot marginal effects instead of uplift.
1271+
percentage : bool, default False
1272+
Express uplift as a percentage of actual (not supported for marginal).
1273+
sharey : bool, default True
1274+
Share y-axis across subplots (only relevant for multi-panel case).
12721275
12731276
Returns
12741277
-------
1275-
plt.Axes
1276-
The Axes object with the plot.
1278+
(fig, axes) if multi-panel, else a single Axes (backwards compatible single-dim case).
12771279
"""
1278-
if ax is None:
1279-
_, ax = plt.subplots(figsize=(10, 6))
1280-
12811280
if percentage and marginal:
12821281
raise ValueError("Not implemented marginal effects in percentage scale.")
12831282

1284-
# Check if sensitivity analysis results exist in idata
12851283
if not hasattr(self.idata, "sensitivity_analysis"):
12861284
raise ValueError(
12871285
"No sensitivity analysis results found in 'self.idata'. "
1288-
"Please run the sensitivity analysis first using 'mmm.sensitivity.run_sweep()' method."
1286+
"Run 'mmm.sensitivity.run_sweep()' first."
1287+
)
1288+
1289+
results: xr.Dataset = self.idata.sensitivity_analysis # type: ignore
1290+
1291+
# Required variable presence checks
1292+
required_var = "marginal_effects" if marginal else "y"
1293+
if required_var not in results:
1294+
raise ValueError(
1295+
f"Expected '{required_var}' in sensitivity_analysis results, found: {list(results.data_vars)}"
1296+
)
1297+
if "sweep" not in results.dims:
1298+
raise ValueError(
1299+
"Sensitivity analysis results must contain 'sweep' dimension."
12891300
)
12901301

1291-
# grab sensitivity analysis results from idata
1292-
results = self.idata.sensitivity_analysis
1293-
1294-
x = results.sweep.values
1295-
if marginal:
1296-
y = results.marginal_effects.mean(dim=["chain", "draw"]).sum(dim="date")
1297-
y_hdi = results.marginal_effects.sum(dim="date")
1298-
color = "C1"
1299-
label = "Posterior mean marginal effect"
1300-
title = "Marginal effects plot"
1301-
ylabel = r"Marginal effect, $\frac{d\mathbb{E}[Y]}{dX}$"
1302+
# Identify additional dimensions
1303+
ignored_dims = {"chain", "draw", "date", "sweep"}
1304+
base_data = results.marginal_effects if marginal else results.y
1305+
additional_dims = [d for d in base_data.dims if d not in ignored_dims]
1306+
1307+
# Build all coordinate combinations
1308+
if additional_dims:
1309+
additional_coords = [results.coords[d].values for d in additional_dims]
1310+
dim_combinations = list(itertools.product(*additional_coords))
13021311
else:
1303-
if percentage:
1304-
actual = self.idata.posterior_predictive["y"]
1305-
y = results.y.mean(dim=["chain", "draw"]).sum(dim="date") / actual.mean(
1306-
dim=["chain", "draw"]
1307-
).sum(dim="date")
1308-
y_hdi = results.y.sum(dim="date") / actual.sum(dim="date")
1309-
else:
1310-
y = results.y.mean(dim=["chain", "draw"]).sum(dim="date")
1311-
y_hdi = results.y.sum(dim="date")
1312-
color = "C0"
1313-
label = "Posterior mean"
1314-
title = "Sensitivity analysis plot"
1315-
ylabel = "Total uplift (sum over dates)"
1316-
1317-
ax.plot(x, y, label=label, color=color)
1318-
1319-
az.plot_hdi(
1320-
x,
1321-
y_hdi,
1322-
hdi_prob=hdi_prob,
1323-
color=color,
1324-
fill_kwargs={"alpha": 0.5, "label": f"{hdi_prob * 100:.0f}% HDI"},
1325-
plot_kwargs={"color": color, "alpha": 0.5},
1326-
smooth=False,
1327-
ax=ax,
1328-
)
1312+
dim_combinations = [()]
1313+
1314+
multi_panel = len(dim_combinations) > 1
1315+
1316+
# If user provided ax but multiple panels needed, raise (consistent with other methods)
1317+
if multi_panel and ax is not None:
1318+
raise ValueError(
1319+
"Cannot use 'ax' when there are extra dimensions. "
1320+
"Let the function create its own subplots."
1321+
)
13291322

1330-
ax.set(title=title)
1331-
if results.sweep_type == "absolute":
1332-
ax.set_xlabel(f"Absolute value of: {results.var_names}")
1323+
# Prepare figure/axes
1324+
if multi_panel:
1325+
fig, axes = self._init_subplots(n_subplots=len(dim_combinations), ncols=1)
1326+
if sharey:
1327+
# Align y limits later - collect mins/maxs
1328+
y_mins, y_maxs = [], []
13331329
else:
1334-
ax.set_xlabel(
1335-
f"{results.sweep_type.capitalize()} change of: {results.var_names}"
1330+
if ax is None:
1331+
fig, axes_arr = plt.subplots(figsize=(10, 6))
1332+
ax = axes_arr # type: ignore
1333+
fig = ax.get_figure() # type: ignore
1334+
axes = np.array([[ax]]) # type: ignore
1335+
1336+
sweep_values = results.coords["sweep"].values
1337+
1338+
# Helper: select subset (only dims present)
1339+
def _select(data: xr.DataArray, indexers: dict) -> xr.DataArray:
1340+
valid = {k: v for k, v in indexers.items() if k in data.dims}
1341+
return data.sel(**valid)
1342+
1343+
for row_idx, combo in enumerate(dim_combinations):
1344+
current_ax = axes[row_idx][0] if multi_panel else ax # type: ignore
1345+
indexers = (
1346+
dict(zip(additional_dims, combo, strict=False))
1347+
if additional_dims
1348+
else {}
13361349
)
1337-
ax.set_ylabel(ylabel)
1338-
plt.legend()
1339-
1340-
# Set y-axis limits based on the sign of y values
1341-
y_values = y.values if hasattr(y, "values") else np.array(y)
1342-
if np.all(y_values < 0):
1343-
ax.set_ylim(top=0)
1344-
elif np.all(y_values > 0):
1345-
ax.set_ylim(bottom=0)
1346-
1347-
ax.yaxis.set_major_formatter(
1348-
plt.FuncFormatter(lambda x, _: f"{x:.1%}" if percentage else f"{x:,.1f}")
1349-
)
13501350

1351-
# Add reference lines
1352-
if results.sweep_type == "multiplicative":
1353-
ax.axvline(x=1, color="k", linestyle="--", alpha=0.5)
1354-
if not marginal:
1355-
ax.axhline(y=0, color="k", linestyle="--", alpha=0.5)
1356-
elif results.sweep_type == "additive":
1357-
ax.axvline(x=0, color="k", linestyle="--", alpha=0.5)
1351+
if marginal:
1352+
eff = _select(results.marginal_effects, indexers)
1353+
# mean over chain/draw, sum over date (and any leftover dims not indexed)
1354+
leftover = [d for d in eff.dims if d in ("date",) and d != "sweep"]
1355+
y_mean = eff.mean(dim=["chain", "draw"]).sum(dim=leftover)
1356+
y_hdi_data = eff.sum(dim=leftover)
1357+
color = "C1"
1358+
label = "Posterior mean marginal effect"
1359+
title = "Marginal effects"
1360+
ylabel = r"Marginal effect, $\frac{d\mathbb{E}[Y]}{dX}$"
1361+
else:
1362+
y_da = _select(results.y, indexers)
1363+
leftover = [d for d in y_da.dims if d in ("date",) and d != "sweep"]
1364+
if percentage:
1365+
actual = self.idata.posterior_predictive["y"] # type: ignore
1366+
actual_sel = _select(actual, indexers)
1367+
actual_mean = actual_sel.mean(dim=["chain", "draw"]).sum(
1368+
dim=leftover
1369+
)
1370+
actual_sum = actual_sel.sum(dim=leftover)
1371+
y_mean = (
1372+
y_da.mean(dim=["chain", "draw"]).sum(dim=leftover) / actual_mean
1373+
)
1374+
y_hdi_data = y_da.sum(dim=leftover) / actual_sum
1375+
else:
1376+
y_mean = y_da.mean(dim=["chain", "draw"]).sum(dim=leftover)
1377+
y_hdi_data = y_da.sum(dim=leftover)
1378+
color = "C0"
1379+
label = "Posterior mean uplift"
1380+
title = "Sensitivity analysis"
1381+
ylabel = "Total uplift (sum over dates)"
1382+
1383+
# Ensure ordering: y_mean dimension 'sweep'
1384+
if "sweep" not in y_mean.dims:
1385+
raise ValueError("Expected 'sweep' dim after aggregation.")
1386+
1387+
current_ax.plot(sweep_values, y_mean, label=label, color=color) # type: ignore
1388+
1389+
# Plot HDI
1390+
az.plot_hdi(
1391+
sweep_values,
1392+
y_hdi_data,
1393+
hdi_prob=hdi_prob,
1394+
color=color,
1395+
fill_kwargs={"alpha": 0.4, "label": f"{hdi_prob * 100:.0f}% HDI"},
1396+
plot_kwargs={"color": color, "alpha": 0.5},
1397+
smooth=False,
1398+
ax=current_ax,
1399+
)
13581400

1359-
return ax
1401+
# Titles / labels
1402+
if additional_dims:
1403+
subplot_title = self._build_subplot_title(
1404+
additional_dims, combo, fallback_title=title
1405+
)
1406+
else:
1407+
subplot_title = title
1408+
current_ax.set_title(subplot_title) # type: ignore
1409+
if results.sweep_type == "absolute":
1410+
current_ax.set_xlabel(f"Absolute value of: {results.var_names}") # type: ignore
1411+
else:
1412+
current_ax.set_xlabel( # type: ignore
1413+
f"{results.sweep_type.capitalize()} change of: {results.var_names}"
1414+
)
1415+
current_ax.set_ylabel(ylabel) # type: ignore
1416+
1417+
# Baseline reference lines
1418+
if results.sweep_type == "multiplicative":
1419+
current_ax.axvline(x=1, color="k", linestyle="--", alpha=0.5) # type: ignore
1420+
if not marginal:
1421+
current_ax.axhline(y=0, color="k", linestyle="--", alpha=0.5) # type: ignore
1422+
elif results.sweep_type == "additive":
1423+
current_ax.axvline(x=0, color="k", linestyle="--", alpha=0.5) # type: ignore
1424+
1425+
# Format y
1426+
if percentage:
1427+
current_ax.yaxis.set_major_formatter( # type: ignore
1428+
plt.FuncFormatter(lambda v, _: f"{v:.1%}") # type: ignore
1429+
)
1430+
else:
1431+
current_ax.yaxis.set_major_formatter( # type: ignore
1432+
plt.FuncFormatter(lambda v, _: f"{v:,.1f}") # type: ignore
1433+
)
1434+
1435+
# Adjust y-lims sign aware
1436+
y_vals = y_mean.values
1437+
if np.all(y_vals < 0):
1438+
current_ax.set_ylim(top=0) # type: ignore
1439+
elif np.all(y_vals > 0):
1440+
current_ax.set_ylim(bottom=0) # type: ignore
1441+
1442+
if multi_panel and sharey:
1443+
y_mins.append(current_ax.get_ylim()[0]) # type: ignore
1444+
y_maxs.append(current_ax.get_ylim()[1]) # type: ignore
1445+
1446+
current_ax.legend(loc="best") # type: ignore
1447+
1448+
# Share y limits if requested
1449+
if multi_panel and sharey:
1450+
global_min, global_max = min(y_mins), max(y_maxs)
1451+
for row_idx in range(len(dim_combinations)):
1452+
axes[row_idx][0].set_ylim(global_min, global_max)
1453+
1454+
if multi_panel:
1455+
fig.tight_layout()
1456+
return fig, axes
1457+
else:
1458+
return ax # single axis for backwards compatibility

0 commit comments

Comments
 (0)