Skip to content

Multidimensional support in plot_sensitivity_analysis #1886

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 182 additions & 83 deletions pymc_marketing/mmm/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,108 +1252,207 @@ def plot_sensitivity_analysis(
ax: plt.Axes | None = None,
marginal: bool = False,
percentage: bool = False,
) -> plt.Axes:
sharey: bool = True,
) -> tuple[Figure, NDArray[Axes]] | plt.Axes:
"""
Plot the counterfactual uplift or marginal effects curve.
Plot counterfactual uplift or marginal effects curves.

Handles additional (non sweep/date/chain/draw) dimensions by creating one subplot
per combination of those dimensions - consistent with other plot_* methods.

Parameters
----------
results : xr.Dataset
The dataset containing the results of the sweep.
hdi_prob : float, optional
The probability for computing the highest density interval (HDI). Default is 0.94.
ax : Optional[plt.Axes], optional
An optional matplotlib Axes on which to plot. If None, a new Axes is created.
marginal : bool, optional
If True, plot marginal effects. If False (default), plot uplift.
percentage : bool, optional
If True, plot the results on the y-axis as percentages, instead of absolute
values. Default is False.
hdi_prob : float, default 0.94
HDI probability mass.
ax : plt.Axes, optional
Only used when there are no extra dimensions (single panel case).
marginal : bool, default False
Plot marginal effects instead of uplift.
percentage : bool, default False
Express uplift as a percentage of actual (not supported for marginal).
sharey : bool, default True
Share y-axis across subplots (only relevant for multi-panel case).

Returns
-------
plt.Axes
The Axes object with the plot.
(fig, axes) if multi-panel, else a single Axes (backwards compatible single-dim case).
"""
if ax is None:
_, ax = plt.subplots(figsize=(10, 6))

if percentage and marginal:
raise ValueError("Not implemented marginal effects in percentage scale.")

# Check if sensitivity analysis results exist in idata
if not hasattr(self.idata, "sensitivity_analysis"):
raise ValueError(
"No sensitivity analysis results found in 'self.idata'. "
"Please run the sensitivity analysis first using 'mmm.sensitivity.run_sweep()' method."
"Run 'mmm.sensitivity.run_sweep()' first."
)

results: xr.Dataset = self.idata.sensitivity_analysis # type: ignore

# Required variable presence checks
required_var = "marginal_effects" if marginal else "y"
if required_var not in results:
raise ValueError(
f"Expected '{required_var}' in sensitivity_analysis results, found: {list(results.data_vars)}"
)
if "sweep" not in results.dims:
raise ValueError(
"Sensitivity analysis results must contain 'sweep' dimension."
)

# grab sensitivity analysis results from idata
results = self.idata.sensitivity_analysis

x = results.sweep.values
if marginal:
y = results.marginal_effects.mean(dim=["chain", "draw"]).sum(dim="date")
y_hdi = results.marginal_effects.sum(dim="date")
color = "C1"
label = "Posterior mean marginal effect"
title = "Marginal effects plot"
ylabel = r"Marginal effect, $\frac{d\mathbb{E}[Y]}{dX}$"
# Identify additional dimensions
ignored_dims = {"chain", "draw", "date", "sweep"}
base_data = results.marginal_effects if marginal else results.y
additional_dims = [d for d in base_data.dims if d not in ignored_dims]

# Build all coordinate combinations
if additional_dims:
additional_coords = [results.coords[d].values for d in additional_dims]
dim_combinations = list(itertools.product(*additional_coords))
else:
if percentage:
actual = self.idata.posterior_predictive["y"]
y = results.y.mean(dim=["chain", "draw"]).sum(dim="date") / actual.mean(
dim=["chain", "draw"]
).sum(dim="date")
y_hdi = results.y.sum(dim="date") / actual.sum(dim="date")
else:
y = results.y.mean(dim=["chain", "draw"]).sum(dim="date")
y_hdi = results.y.sum(dim="date")
color = "C0"
label = "Posterior mean"
title = "Sensitivity analysis plot"
ylabel = "Total uplift (sum over dates)"

ax.plot(x, y, label=label, color=color)

az.plot_hdi(
x,
y_hdi,
hdi_prob=hdi_prob,
color=color,
fill_kwargs={"alpha": 0.5, "label": f"{hdi_prob * 100:.0f}% HDI"},
plot_kwargs={"color": color, "alpha": 0.5},
smooth=False,
ax=ax,
)
dim_combinations = [()]

multi_panel = len(dim_combinations) > 1

# If user provided ax but multiple panels needed, raise (consistent with other methods)
if multi_panel and ax is not None:
raise ValueError(
"Cannot use 'ax' when there are extra dimensions. "
"Let the function create its own subplots."
)

ax.set(title=title)
if results.sweep_type == "absolute":
ax.set_xlabel(f"Absolute value of: {results.var_names}")
# Prepare figure/axes
if multi_panel:
fig, axes = self._init_subplots(n_subplots=len(dim_combinations), ncols=1)
if sharey:
# Align y limits later - collect mins/maxs
y_mins, y_maxs = [], []
else:
ax.set_xlabel(
f"{results.sweep_type.capitalize()} change of: {results.var_names}"
if ax is None:
fig, axes_arr = plt.subplots(figsize=(10, 6))
ax = axes_arr # type: ignore
fig = ax.get_figure() # type: ignore
axes = np.array([[ax]]) # type: ignore

sweep_values = results.coords["sweep"].values

# Helper: select subset (only dims present)
def _select(data: xr.DataArray, indexers: dict) -> xr.DataArray:
valid = {k: v for k, v in indexers.items() if k in data.dims}
return data.sel(**valid)

for row_idx, combo in enumerate(dim_combinations):
current_ax = axes[row_idx][0] if multi_panel else ax # type: ignore
indexers = (
dict(zip(additional_dims, combo, strict=False))
if additional_dims
else {}
)
ax.set_ylabel(ylabel)
plt.legend()

# Set y-axis limits based on the sign of y values
y_values = y.values if hasattr(y, "values") else np.array(y)
if np.all(y_values < 0):
ax.set_ylim(top=0)
elif np.all(y_values > 0):
ax.set_ylim(bottom=0)

ax.yaxis.set_major_formatter(
plt.FuncFormatter(lambda x, _: f"{x:.1%}" if percentage else f"{x:,.1f}")
)

# Add reference lines
if results.sweep_type == "multiplicative":
ax.axvline(x=1, color="k", linestyle="--", alpha=0.5)
if not marginal:
ax.axhline(y=0, color="k", linestyle="--", alpha=0.5)
elif results.sweep_type == "additive":
ax.axvline(x=0, color="k", linestyle="--", alpha=0.5)
if marginal:
eff = _select(results.marginal_effects, indexers)
# mean over chain/draw, sum over date (and any leftover dims not indexed)
leftover = [d for d in eff.dims if d in ("date",) and d != "sweep"]
y_mean = eff.mean(dim=["chain", "draw"]).sum(dim=leftover)
y_hdi_data = eff.sum(dim=leftover)
color = "C1"
label = "Posterior mean marginal effect"
title = "Marginal effects"
ylabel = r"Marginal effect, $\frac{d\mathbb{E}[Y]}{dX}$"
else:
y_da = _select(results.y, indexers)
leftover = [d for d in y_da.dims if d in ("date",) and d != "sweep"]
if percentage:
actual = self.idata.posterior_predictive["y"] # type: ignore
actual_sel = _select(actual, indexers)
actual_mean = actual_sel.mean(dim=["chain", "draw"]).sum(
dim=leftover
)
actual_sum = actual_sel.sum(dim=leftover)
y_mean = (
y_da.mean(dim=["chain", "draw"]).sum(dim=leftover) / actual_mean
)
y_hdi_data = y_da.sum(dim=leftover) / actual_sum
else:
y_mean = y_da.mean(dim=["chain", "draw"]).sum(dim=leftover)
y_hdi_data = y_da.sum(dim=leftover)
color = "C0"
label = "Posterior mean uplift"
title = "Sensitivity analysis"
ylabel = "Total uplift (sum over dates)"

# Ensure ordering: y_mean dimension 'sweep'
if "sweep" not in y_mean.dims:
raise ValueError("Expected 'sweep' dim after aggregation.")

current_ax.plot(sweep_values, y_mean, label=label, color=color) # type: ignore

# Plot HDI
az.plot_hdi(
sweep_values,
y_hdi_data,
hdi_prob=hdi_prob,
color=color,
fill_kwargs={"alpha": 0.4, "label": f"{hdi_prob * 100:.0f}% HDI"},
plot_kwargs={"color": color, "alpha": 0.5},
smooth=False,
ax=current_ax,
)

return ax
# Titles / labels
if additional_dims:
subplot_title = self._build_subplot_title(
additional_dims, combo, fallback_title=title
)
else:
subplot_title = title
current_ax.set_title(subplot_title) # type: ignore
if results.sweep_type == "absolute":
current_ax.set_xlabel(f"Absolute value of: {results.var_names}") # type: ignore
else:
current_ax.set_xlabel( # type: ignore
f"{results.sweep_type.capitalize()} change of: {results.var_names}"
)
current_ax.set_ylabel(ylabel) # type: ignore

# Baseline reference lines
if results.sweep_type == "multiplicative":
current_ax.axvline(x=1, color="k", linestyle="--", alpha=0.5) # type: ignore
if not marginal:
current_ax.axhline(y=0, color="k", linestyle="--", alpha=0.5) # type: ignore
elif results.sweep_type == "additive":
current_ax.axvline(x=0, color="k", linestyle="--", alpha=0.5) # type: ignore

# Format y
if percentage:
current_ax.yaxis.set_major_formatter( # type: ignore
plt.FuncFormatter(lambda v, _: f"{v:.1%}") # type: ignore
)
else:
current_ax.yaxis.set_major_formatter( # type: ignore
plt.FuncFormatter(lambda v, _: f"{v:,.1f}") # type: ignore
)

# Adjust y-lims sign aware
y_vals = y_mean.values
if np.all(y_vals < 0):
current_ax.set_ylim(top=0) # type: ignore
elif np.all(y_vals > 0):
current_ax.set_ylim(bottom=0) # type: ignore

if multi_panel and sharey:
y_mins.append(current_ax.get_ylim()[0]) # type: ignore
y_maxs.append(current_ax.get_ylim()[1]) # type: ignore

current_ax.legend(loc="best") # type: ignore

# Share y limits if requested
if multi_panel and sharey:
global_min, global_max = min(y_mins), max(y_maxs)
for row_idx in range(len(dim_combinations)):
axes[row_idx][0].set_ylim(global_min, global_max)

if multi_panel:
fig.tight_layout()
return fig, axes
else:
return ax # single axis for backwards compatibility
Loading
Loading