From bf98cb951915c2a2d978ea8c7d942279c0fa3699 Mon Sep 17 00:00:00 2001 From: PabloRoque Date: Wed, 13 Aug 2025 12:45:18 +0200 Subject: [PATCH 1/3] Multidimensional support in plot_sensitivity_analysis --- pymc_marketing/mmm/plot.py | 265 +++++++++++++++++++++++++------------ 1 file changed, 182 insertions(+), 83 deletions(-) diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index 57da3ad8..09181779 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -1250,108 +1250,207 @@ def plot_sensitivity_analysis( ax: plt.Axes | None = None, marginal: bool = False, percentage: bool = False, - ) -> plt.Axes: + sharey: bool = True, + ) -> tuple[Figure, NDArray[Axes]] | plt.Axes: """ - Plot the counterfactual uplift or marginal effects curve. + Plot counterfactual uplift or marginal effects curves. + + Handles additional (non sweep/date/chain/draw) dimensions by creating one subplot + per combination of those dimensions - consistent with other plot_* methods. Parameters ---------- - results : xr.Dataset - The dataset containing the results of the sweep. - hdi_prob : float, optional - The probability for computing the highest density interval (HDI). Default is 0.94. - ax : Optional[plt.Axes], optional - An optional matplotlib Axes on which to plot. If None, a new Axes is created. - marginal : bool, optional - If True, plot marginal effects. If False (default), plot uplift. - percentage : bool, optional - If True, plot the results on the y-axis as percentages, instead of absolute - values. Default is False. + hdi_prob : float, default 0.94 + HDI probability mass. + ax : plt.Axes, optional + Only used when there are no extra dimensions (single panel case). + marginal : bool, default False + Plot marginal effects instead of uplift. + percentage : bool, default False + Express uplift as a percentage of actual (not supported for marginal). + sharey : bool, default True + Share y-axis across subplots (only relevant for multi-panel case). Returns ------- - plt.Axes - The Axes object with the plot. + (fig, axes) if multi-panel, else a single Axes (backwards compatible single-dim case). """ - if ax is None: - _, ax = plt.subplots(figsize=(10, 6)) - if percentage and marginal: raise ValueError("Not implemented marginal effects in percentage scale.") - # Check if sensitivity analysis results exist in idata if not hasattr(self.idata, "sensitivity_analysis"): raise ValueError( "No sensitivity analysis results found in 'self.idata'. " - "Please run the sensitivity analysis first using 'mmm.sensitivity.run_sweep()' method." + "Run 'mmm.sensitivity.run_sweep()' first." + ) + + results: xr.Dataset = self.idata.sensitivity_analysis # type: ignore + + # Required variable presence checks + required_var = "marginal_effects" if marginal else "y" + if required_var not in results: + raise ValueError( + f"Expected '{required_var}' in sensitivity_analysis results, found: {list(results.data_vars)}" + ) + if "sweep" not in results.dims: + raise ValueError( + "Sensitivity analysis results must contain 'sweep' dimension." ) - # grab sensitivity analysis results from idata - results = self.idata.sensitivity_analysis - - x = results.sweep.values - if marginal: - y = results.marginal_effects.mean(dim=["chain", "draw"]).sum(dim="date") - y_hdi = results.marginal_effects.sum(dim="date") - color = "C1" - label = "Posterior mean marginal effect" - title = "Marginal effects plot" - ylabel = r"Marginal effect, $\frac{d\mathbb{E}[Y]}{dX}$" + # Identify additional dimensions + ignored_dims = {"chain", "draw", "date", "sweep"} + base_data = results.marginal_effects if marginal else results.y + additional_dims = [d for d in base_data.dims if d not in ignored_dims] + + # Build all coordinate combinations + if additional_dims: + additional_coords = [results.coords[d].values for d in additional_dims] + dim_combinations = list(itertools.product(*additional_coords)) else: - if percentage: - actual = self.idata.posterior_predictive["y"] - y = results.y.mean(dim=["chain", "draw"]).sum(dim="date") / actual.mean( - dim=["chain", "draw"] - ).sum(dim="date") - y_hdi = results.y.sum(dim="date") / actual.sum(dim="date") - else: - y = results.y.mean(dim=["chain", "draw"]).sum(dim="date") - y_hdi = results.y.sum(dim="date") - color = "C0" - label = "Posterior mean" - title = "Sensitivity analysis plot" - ylabel = "Total uplift (sum over dates)" - - ax.plot(x, y, label=label, color=color) - - az.plot_hdi( - x, - y_hdi, - hdi_prob=hdi_prob, - color=color, - fill_kwargs={"alpha": 0.5, "label": f"{hdi_prob * 100:.0f}% HDI"}, - plot_kwargs={"color": color, "alpha": 0.5}, - smooth=False, - ax=ax, - ) + dim_combinations = [()] + + multi_panel = len(dim_combinations) > 1 + + # If user provided ax but multiple panels needed, raise (consistent with other methods) + if multi_panel and ax is not None: + raise ValueError( + "Cannot use 'ax' when there are extra dimensions. " + "Let the function create its own subplots." + ) - ax.set(title=title) - if results.sweep_type == "absolute": - ax.set_xlabel(f"Absolute value of: {results.var_names}") + # Prepare figure/axes + if multi_panel: + fig, axes = self._init_subplots(n_subplots=len(dim_combinations), ncols=1) + if sharey: + # Align y limits later - collect mins/maxs + y_mins, y_maxs = [], [] else: - ax.set_xlabel( - f"{results.sweep_type.capitalize()} change of: {results.var_names}" + if ax is None: + fig, axes_arr = plt.subplots(figsize=(10, 6)) + ax = axes_arr # type: ignore + fig = ax.get_figure() # type: ignore + axes = np.array([[ax]]) # type: ignore + + sweep_values = results.coords["sweep"].values + + # Helper: select subset (only dims present) + def _select(data: xr.DataArray, indexers: dict) -> xr.DataArray: + valid = {k: v for k, v in indexers.items() if k in data.dims} + return data.sel(**valid) + + for row_idx, combo in enumerate(dim_combinations): + current_ax = axes[row_idx][0] if multi_panel else ax # type: ignore + indexers = ( + dict(zip(additional_dims, combo, strict=False)) + if additional_dims + else {} ) - ax.set_ylabel(ylabel) - plt.legend() - - # Set y-axis limits based on the sign of y values - y_values = y.values if hasattr(y, "values") else np.array(y) - if np.all(y_values < 0): - ax.set_ylim(top=0) - elif np.all(y_values > 0): - ax.set_ylim(bottom=0) - - ax.yaxis.set_major_formatter( - plt.FuncFormatter(lambda x, _: f"{x:.1%}" if percentage else f"{x:,.1f}") - ) - # Add reference lines - if results.sweep_type == "multiplicative": - ax.axvline(x=1, color="k", linestyle="--", alpha=0.5) - if not marginal: - ax.axhline(y=0, color="k", linestyle="--", alpha=0.5) - elif results.sweep_type == "additive": - ax.axvline(x=0, color="k", linestyle="--", alpha=0.5) + if marginal: + eff = _select(results.marginal_effects, indexers) + # mean over chain/draw, sum over date (and any leftover dims not indexed) + leftover = [d for d in eff.dims if d in ("date",) and d != "sweep"] + y_mean = eff.mean(dim=["chain", "draw"]).sum(dim=leftover) + y_hdi_data = eff.sum(dim=leftover) + color = "C1" + label = "Posterior mean marginal effect" + title = "Marginal effects" + ylabel = r"Marginal effect, $\frac{d\mathbb{E}[Y]}{dX}$" + else: + y_da = _select(results.y, indexers) + leftover = [d for d in y_da.dims if d in ("date",) and d != "sweep"] + if percentage: + actual = self.idata.posterior_predictive["y"] # type: ignore + actual_sel = _select(actual, indexers) + actual_mean = actual_sel.mean(dim=["chain", "draw"]).sum( + dim=leftover + ) + actual_sum = actual_sel.sum(dim=leftover) + y_mean = ( + y_da.mean(dim=["chain", "draw"]).sum(dim=leftover) / actual_mean + ) + y_hdi_data = y_da.sum(dim=leftover) / actual_sum + else: + y_mean = y_da.mean(dim=["chain", "draw"]).sum(dim=leftover) + y_hdi_data = y_da.sum(dim=leftover) + color = "C0" + label = "Posterior mean uplift" + title = "Sensitivity analysis" + ylabel = "Total uplift (sum over dates)" + + # Ensure ordering: y_mean dimension 'sweep' + if "sweep" not in y_mean.dims: + raise ValueError("Expected 'sweep' dim after aggregation.") + + current_ax.plot(sweep_values, y_mean, label=label, color=color) # type: ignore + + # Plot HDI + az.plot_hdi( + sweep_values, + y_hdi_data, + hdi_prob=hdi_prob, + color=color, + fill_kwargs={"alpha": 0.4, "label": f"{hdi_prob * 100:.0f}% HDI"}, + plot_kwargs={"color": color, "alpha": 0.5}, + smooth=False, + ax=current_ax, + ) - return ax + # Titles / labels + if additional_dims: + subplot_title = self._build_subplot_title( + additional_dims, combo, fallback_title=title + ) + else: + subplot_title = title + current_ax.set_title(subplot_title) # type: ignore + if results.sweep_type == "absolute": + current_ax.set_xlabel(f"Absolute value of: {results.var_names}") # type: ignore + else: + current_ax.set_xlabel( # type: ignore + f"{results.sweep_type.capitalize()} change of: {results.var_names}" + ) + current_ax.set_ylabel(ylabel) # type: ignore + + # Baseline reference lines + if results.sweep_type == "multiplicative": + current_ax.axvline(x=1, color="k", linestyle="--", alpha=0.5) # type: ignore + if not marginal: + current_ax.axhline(y=0, color="k", linestyle="--", alpha=0.5) # type: ignore + elif results.sweep_type == "additive": + current_ax.axvline(x=0, color="k", linestyle="--", alpha=0.5) # type: ignore + + # Format y + if percentage: + current_ax.yaxis.set_major_formatter( # type: ignore + plt.FuncFormatter(lambda v, _: f"{v:.1%}") # type: ignore + ) + else: + current_ax.yaxis.set_major_formatter( # type: ignore + plt.FuncFormatter(lambda v, _: f"{v:,.1f}") # type: ignore + ) + + # Adjust y-lims sign aware + y_vals = y_mean.values + if np.all(y_vals < 0): + current_ax.set_ylim(top=0) # type: ignore + elif np.all(y_vals > 0): + current_ax.set_ylim(bottom=0) # type: ignore + + if multi_panel and sharey: + y_mins.append(current_ax.get_ylim()[0]) # type: ignore + y_maxs.append(current_ax.get_ylim()[1]) # type: ignore + + current_ax.legend(loc="best") # type: ignore + + # Share y limits if requested + if multi_panel and sharey: + global_min, global_max = min(y_mins), max(y_maxs) + for row_idx in range(len(dim_combinations)): + axes[row_idx][0].set_ylim(global_min, global_max) + + if multi_panel: + fig.tight_layout() + return fig, axes + else: + return ax # single axis for backwards compatibility From e2c8d8e1486bac230561ba88b2894b970a73a058 Mon Sep 17 00:00:00 2001 From: PabloRoque Date: Tue, 19 Aug 2025 08:41:35 +0200 Subject: [PATCH 2/3] Add tests for multidim sensitivity analysis plots --- tests/mmm/test_plot.py | 100 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/tests/mmm/test_plot.py b/tests/mmm/test_plot.py index 28aad026..303856e9 100644 --- a/tests/mmm/test_plot.py +++ b/tests/mmm/test_plot.py @@ -175,12 +175,77 @@ def mock_idata() -> az.InferenceData: ) +@pytest.fixture(scope="module") +def mock_idata_with_sensitivity(mock_idata): + # Copy the mock_idata so we don't mutate the shared fixture + idata = mock_idata.copy() + n_chain, n_draw, n_sweep = 2, 10, 5 + sweep = np.linspace(0.5, 1.5, n_sweep) + # Add a single extra dim for multi-panel test + extra_dim = ["A", "B"] + # y and marginal_effects: dims (chain, draw, sweep, extra) + y = xr.DataArray( + np.random.normal(0, 1, size=(n_chain, n_draw, n_sweep, len(extra_dim))), + dims=("chain", "draw", "sweep", "region"), + coords={ + "chain": np.arange(n_chain), + "draw": np.arange(n_draw), + "sweep": sweep, + "region": extra_dim, + }, + ) + marginal_effects = xr.DataArray( + np.random.normal(0, 1, size=(n_chain, n_draw, n_sweep, len(extra_dim))), + dims=("chain", "draw", "sweep", "region"), + coords={ + "chain": np.arange(n_chain), + "draw": np.arange(n_draw), + "sweep": sweep, + "region": extra_dim, + }, + ) + # Add sweep_type and var_names as attrs/coords + sensitivity_analysis = xr.Dataset( + {"y": y, "marginal_effects": marginal_effects}, + coords={"sweep": sweep, "region": extra_dim}, + attrs={"sweep_type": "multiplicative", "var_names": "test_var"}, + ) + # Attach to idata + idata.sensitivity_analysis = sensitivity_analysis + # Add posterior_predictive for percentage test + idata.posterior_predictive = xr.Dataset( + { + "y": xr.DataArray( + np.abs( + np.random.normal( + 10, 2, size=(n_chain, n_draw, n_sweep, len(extra_dim)) + ) + ), + dims=("chain", "draw", "sweep", "region"), + coords={ + "chain": np.arange(n_chain), + "draw": np.arange(n_draw), + "sweep": sweep, + "region": extra_dim, + }, + ) + } + ) + return idata + + @pytest.fixture(scope="module") def mock_suite(mock_idata): """Fixture to create a mock MMMPlotSuite with a mocked posterior.""" return MMMPlotSuite(idata=mock_idata) +@pytest.fixture(scope="module") +def mock_suite_with_sensitivity(mock_idata_with_sensitivity): + """Fixture to create a mock MMMPlotSuite with sensitivity analysis.""" + return MMMPlotSuite(idata=mock_idata_with_sensitivity) + + def test_contributions_over_time_expand_dims(mock_suite: MMMPlotSuite): fig, ax = mock_suite.contributions_over_time( var=[ @@ -616,3 +681,38 @@ def test_saturation_curves_multi_dim_axes_shape( "country" ] assert axes.shape == (n_channels, n_countries) + + +def test_plot_sensitivity_analysis_basic(mock_suite_with_sensitivity): + # Should return (fig, axes) for multi-panel + fig, axes = mock_suite_with_sensitivity.plot_sensitivity_analysis() + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + +def test_plot_sensitivity_analysis_marginal(mock_suite_with_sensitivity): + fig, axes = mock_suite_with_sensitivity.plot_sensitivity_analysis(marginal=True) + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + + +def test_plot_sensitivity_analysis_percentage(mock_suite_with_sensitivity): + fig, axes = mock_suite_with_sensitivity.plot_sensitivity_analysis(percentage=True) + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + + +def test_plot_sensitivity_analysis_error_on_both_modes(mock_suite_with_sensitivity): + with pytest.raises( + ValueError, match="Not implemented marginal effects in percentage scale." + ): + mock_suite_with_sensitivity.plot_sensitivity_analysis( + marginal=True, percentage=True + ) + + +def test_plot_sensitivity_analysis_error_on_missing_results(mock_idata): + suite = MMMPlotSuite(idata=mock_idata) + with pytest.raises(ValueError, match="No sensitivity analysis results found"): + suite.plot_sensitivity_analysis() From 9eb23665c3bfb97dabff58abf0c464a49628abff Mon Sep 17 00:00:00 2001 From: PabloRoque Date: Tue, 19 Aug 2025 09:02:05 +0200 Subject: [PATCH 3/3] Minor string change --- tests/mmm/test_sensitivity_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mmm/test_sensitivity_analysis.py b/tests/mmm/test_sensitivity_analysis.py index 8695104d..62537ff3 100644 --- a/tests/mmm/test_sensitivity_analysis.py +++ b/tests/mmm/test_sensitivity_analysis.py @@ -292,7 +292,7 @@ def test_plot_sensitivity_analysis_basic(sensitivity_analysis_with_results): assert isinstance(ax, Axes) # Check basic plot properties - assert ax.get_title() == "Sensitivity analysis plot" + assert ax.get_title() == "Sensitivity analysis" assert "Multiplicative change of" in ax.get_xlabel() assert ax.get_ylabel() == "Total uplift (sum over dates)"