Skip to content
Merged
265 changes: 182 additions & 83 deletions pymc_marketing/mmm/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,108 +1252,207 @@ def plot_sensitivity_analysis(
ax: plt.Axes | None = None,
marginal: bool = False,
percentage: bool = False,
) -> plt.Axes:
sharey: bool = True,
) -> tuple[Figure, NDArray[Axes]] | plt.Axes:
"""
Plot the counterfactual uplift or marginal effects curve.
Plot counterfactual uplift or marginal effects curves.

Handles additional (non sweep/date/chain/draw) dimensions by creating one subplot
per combination of those dimensions - consistent with other plot_* methods.

Parameters
----------
results : xr.Dataset
The dataset containing the results of the sweep.
hdi_prob : float, optional
The probability for computing the highest density interval (HDI). Default is 0.94.
ax : Optional[plt.Axes], optional
An optional matplotlib Axes on which to plot. If None, a new Axes is created.
marginal : bool, optional
If True, plot marginal effects. If False (default), plot uplift.
percentage : bool, optional
If True, plot the results on the y-axis as percentages, instead of absolute
values. Default is False.
hdi_prob : float, default 0.94
HDI probability mass.
ax : plt.Axes, optional
Only used when there are no extra dimensions (single panel case).
marginal : bool, default False
Plot marginal effects instead of uplift.
percentage : bool, default False
Express uplift as a percentage of actual (not supported for marginal).
sharey : bool, default True
Share y-axis across subplots (only relevant for multi-panel case).

Returns
-------
plt.Axes
The Axes object with the plot.
(fig, axes) if multi-panel, else a single Axes (backwards compatible single-dim case).
"""
if ax is None:
_, ax = plt.subplots(figsize=(10, 6))

if percentage and marginal:
raise ValueError("Not implemented marginal effects in percentage scale.")

# Check if sensitivity analysis results exist in idata
if not hasattr(self.idata, "sensitivity_analysis"):
raise ValueError(
"No sensitivity analysis results found in 'self.idata'. "
"Please run the sensitivity analysis first using 'mmm.sensitivity.run_sweep()' method."
"Run 'mmm.sensitivity.run_sweep()' first."
)

results: xr.Dataset = self.idata.sensitivity_analysis # type: ignore

# Required variable presence checks
required_var = "marginal_effects" if marginal else "y"
if required_var not in results:
raise ValueError(
f"Expected '{required_var}' in sensitivity_analysis results, found: {list(results.data_vars)}"
)
if "sweep" not in results.dims:
raise ValueError(
"Sensitivity analysis results must contain 'sweep' dimension."
)

# grab sensitivity analysis results from idata
results = self.idata.sensitivity_analysis

x = results.sweep.values
if marginal:
y = results.marginal_effects.mean(dim=["chain", "draw"]).sum(dim="date")
y_hdi = results.marginal_effects.sum(dim="date")
color = "C1"
label = "Posterior mean marginal effect"
title = "Marginal effects plot"
ylabel = r"Marginal effect, $\frac{d\mathbb{E}[Y]}{dX}$"
# Identify additional dimensions
ignored_dims = {"chain", "draw", "date", "sweep"}
base_data = results.marginal_effects if marginal else results.y
additional_dims = [d for d in base_data.dims if d not in ignored_dims]

# Build all coordinate combinations
if additional_dims:
additional_coords = [results.coords[d].values for d in additional_dims]
dim_combinations = list(itertools.product(*additional_coords))
else:
if percentage:
actual = self.idata.posterior_predictive["y"]
y = results.y.mean(dim=["chain", "draw"]).sum(dim="date") / actual.mean(
dim=["chain", "draw"]
).sum(dim="date")
y_hdi = results.y.sum(dim="date") / actual.sum(dim="date")
else:
y = results.y.mean(dim=["chain", "draw"]).sum(dim="date")
y_hdi = results.y.sum(dim="date")
color = "C0"
label = "Posterior mean"
title = "Sensitivity analysis plot"
ylabel = "Total uplift (sum over dates)"

ax.plot(x, y, label=label, color=color)

az.plot_hdi(
x,
y_hdi,
hdi_prob=hdi_prob,
color=color,
fill_kwargs={"alpha": 0.5, "label": f"{hdi_prob * 100:.0f}% HDI"},
plot_kwargs={"color": color, "alpha": 0.5},
smooth=False,
ax=ax,
)
dim_combinations = [()]

multi_panel = len(dim_combinations) > 1

# If user provided ax but multiple panels needed, raise (consistent with other methods)
if multi_panel and ax is not None:
raise ValueError(
"Cannot use 'ax' when there are extra dimensions. "
"Let the function create its own subplots."
)

ax.set(title=title)
if results.sweep_type == "absolute":
ax.set_xlabel(f"Absolute value of: {results.var_names}")
# Prepare figure/axes
if multi_panel:
fig, axes = self._init_subplots(n_subplots=len(dim_combinations), ncols=1)
if sharey:
# Align y limits later - collect mins/maxs
y_mins, y_maxs = [], []
else:
ax.set_xlabel(
f"{results.sweep_type.capitalize()} change of: {results.var_names}"
if ax is None:
fig, axes_arr = plt.subplots(figsize=(10, 6))
ax = axes_arr # type: ignore
fig = ax.get_figure() # type: ignore
axes = np.array([[ax]]) # type: ignore

sweep_values = results.coords["sweep"].values

# Helper: select subset (only dims present)
def _select(data: xr.DataArray, indexers: dict) -> xr.DataArray:
valid = {k: v for k, v in indexers.items() if k in data.dims}
return data.sel(**valid)

for row_idx, combo in enumerate(dim_combinations):
current_ax = axes[row_idx][0] if multi_panel else ax # type: ignore
indexers = (
dict(zip(additional_dims, combo, strict=False))
if additional_dims
else {}
)
ax.set_ylabel(ylabel)
plt.legend()

# Set y-axis limits based on the sign of y values
y_values = y.values if hasattr(y, "values") else np.array(y)
if np.all(y_values < 0):
ax.set_ylim(top=0)
elif np.all(y_values > 0):
ax.set_ylim(bottom=0)

ax.yaxis.set_major_formatter(
plt.FuncFormatter(lambda x, _: f"{x:.1%}" if percentage else f"{x:,.1f}")
)

# Add reference lines
if results.sweep_type == "multiplicative":
ax.axvline(x=1, color="k", linestyle="--", alpha=0.5)
if not marginal:
ax.axhline(y=0, color="k", linestyle="--", alpha=0.5)
elif results.sweep_type == "additive":
ax.axvline(x=0, color="k", linestyle="--", alpha=0.5)
if marginal:
eff = _select(results.marginal_effects, indexers)
# mean over chain/draw, sum over date (and any leftover dims not indexed)
leftover = [d for d in eff.dims if d in ("date",) and d != "sweep"]
y_mean = eff.mean(dim=["chain", "draw"]).sum(dim=leftover)
y_hdi_data = eff.sum(dim=leftover)
color = "C1"
label = "Posterior mean marginal effect"
title = "Marginal effects"
ylabel = r"Marginal effect, $\frac{d\mathbb{E}[Y]}{dX}$"
else:
y_da = _select(results.y, indexers)
leftover = [d for d in y_da.dims if d in ("date",) and d != "sweep"]
if percentage:
actual = self.idata.posterior_predictive["y"] # type: ignore
actual_sel = _select(actual, indexers)
actual_mean = actual_sel.mean(dim=["chain", "draw"]).sum(
dim=leftover
)
actual_sum = actual_sel.sum(dim=leftover)
y_mean = (
y_da.mean(dim=["chain", "draw"]).sum(dim=leftover) / actual_mean
)
y_hdi_data = y_da.sum(dim=leftover) / actual_sum
else:
y_mean = y_da.mean(dim=["chain", "draw"]).sum(dim=leftover)
y_hdi_data = y_da.sum(dim=leftover)
color = "C0"
label = "Posterior mean uplift"
title = "Sensitivity analysis"
ylabel = "Total uplift (sum over dates)"

# Ensure ordering: y_mean dimension 'sweep'
if "sweep" not in y_mean.dims:
raise ValueError("Expected 'sweep' dim after aggregation.")

current_ax.plot(sweep_values, y_mean, label=label, color=color) # type: ignore

# Plot HDI
az.plot_hdi(
sweep_values,
y_hdi_data,
hdi_prob=hdi_prob,
color=color,
fill_kwargs={"alpha": 0.4, "label": f"{hdi_prob * 100:.0f}% HDI"},
plot_kwargs={"color": color, "alpha": 0.5},
smooth=False,
ax=current_ax,
)

return ax
# Titles / labels
if additional_dims:
subplot_title = self._build_subplot_title(
additional_dims, combo, fallback_title=title
)
else:
subplot_title = title
current_ax.set_title(subplot_title) # type: ignore
if results.sweep_type == "absolute":
current_ax.set_xlabel(f"Absolute value of: {results.var_names}") # type: ignore
else:
current_ax.set_xlabel( # type: ignore
f"{results.sweep_type.capitalize()} change of: {results.var_names}"
)
current_ax.set_ylabel(ylabel) # type: ignore

# Baseline reference lines
if results.sweep_type == "multiplicative":
current_ax.axvline(x=1, color="k", linestyle="--", alpha=0.5) # type: ignore
if not marginal:
current_ax.axhline(y=0, color="k", linestyle="--", alpha=0.5) # type: ignore
elif results.sweep_type == "additive":
current_ax.axvline(x=0, color="k", linestyle="--", alpha=0.5) # type: ignore

# Format y
if percentage:
current_ax.yaxis.set_major_formatter( # type: ignore
plt.FuncFormatter(lambda v, _: f"{v:.1%}") # type: ignore
)
else:
current_ax.yaxis.set_major_formatter( # type: ignore
plt.FuncFormatter(lambda v, _: f"{v:,.1f}") # type: ignore
)

# Adjust y-lims sign aware
y_vals = y_mean.values
if np.all(y_vals < 0):
current_ax.set_ylim(top=0) # type: ignore
elif np.all(y_vals > 0):
current_ax.set_ylim(bottom=0) # type: ignore

if multi_panel and sharey:
y_mins.append(current_ax.get_ylim()[0]) # type: ignore
y_maxs.append(current_ax.get_ylim()[1]) # type: ignore

current_ax.legend(loc="best") # type: ignore

# Share y limits if requested
if multi_panel and sharey:
global_min, global_max = min(y_mins), max(y_maxs)
for row_idx in range(len(dim_combinations)):
axes[row_idx][0].set_ylim(global_min, global_max)

if multi_panel:
fig.tight_layout()
return fig, axes
else:
return ax # single axis for backwards compatibility
Loading