Skip to content

Commit e353729

Browse files
authored
Fix grid shape logic in saturation_curves and add tests (#1889)
* Fix grid shape logic in saturation_curves and add tests Corrects the grid shape calculation in MMMPlotSuite.saturation_curves to handle cases with and without additional dimensions. Add tests to validate axes shape for both single-dimension and multi-dimension channel_data scenarios. * Deleting comment
1 parent b67196e commit e353729

File tree

2 files changed

+129
-6
lines changed

2 files changed

+129
-6
lines changed

pymc_marketing/mmm/plot.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -531,12 +531,14 @@ def saturation_curves(
531531
# — 1. figure out grid shape based on scatter data dimensions —
532532
cdims = self.idata.constant_data.channel_data.dims
533533
additional_dims = [d for d in cdims if d not in ("date", "channel")]
534-
additional_coords = (
535-
[self.idata.constant_data.coords[d].values for d in additional_dims]
536-
if additional_dims
537-
else [()]
538-
)
539-
combos = list(itertools.product(*additional_coords))
534+
if additional_dims:
535+
additional_coords = [
536+
self.idata.constant_data.coords[d].values for d in additional_dims
537+
]
538+
combos = list(itertools.product(*additional_coords))
539+
else:
540+
# No extra dims: single combination
541+
combos = [()]
540542
channels = self.idata.constant_data.coords["channel"].values
541543
n_rows, n_cols = len(channels), len(combos)
542544

tests/mmm/test_plot.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,3 +495,124 @@ def test_saturation_curves_scatter_deprecation_warning(mock_suite_with_constant_
495495
assert isinstance(fig, Figure)
496496
assert isinstance(axes, np.ndarray)
497497
assert all(isinstance(ax, Axes) for ax in axes.flat)
498+
499+
500+
@pytest.fixture(scope="module")
501+
def mock_idata_with_constant_data_single_dim() -> az.InferenceData:
502+
"""Mock InferenceData where channel_data has only ('date','channel') dims."""
503+
seed = sum(map(ord, "Saturation single-dim tests"))
504+
rng = np.random.default_rng(seed)
505+
normal = rng.normal
506+
507+
dates = pd.date_range("2025-01-01", periods=12, freq="W-MON")
508+
channels = ["channel_1", "channel_2", "channel_3"]
509+
510+
posterior = xr.Dataset(
511+
{
512+
"channel_contribution": xr.DataArray(
513+
normal(size=(2, 10, 12, 3)),
514+
dims=("chain", "draw", "date", "channel"),
515+
coords={
516+
"chain": np.arange(2),
517+
"draw": np.arange(10),
518+
"date": dates,
519+
"channel": channels,
520+
},
521+
),
522+
"channel_contribution_original_scale": xr.DataArray(
523+
normal(size=(2, 10, 12, 3)) * 100.0,
524+
dims=("chain", "draw", "date", "channel"),
525+
coords={
526+
"chain": np.arange(2),
527+
"draw": np.arange(10),
528+
"date": dates,
529+
"channel": channels,
530+
},
531+
),
532+
}
533+
)
534+
535+
constant_data = xr.Dataset(
536+
{
537+
"channel_data": xr.DataArray(
538+
rng.uniform(0, 10, size=(12, 3)),
539+
dims=("date", "channel"),
540+
coords={"date": dates, "channel": channels},
541+
),
542+
"channel_scale": xr.DataArray(
543+
[100.0, 150.0, 200.0], dims=("channel",), coords={"channel": channels}
544+
),
545+
"target_scale": xr.DataArray(
546+
[1000.0], dims="target", coords={"target": ["y"]}
547+
),
548+
}
549+
)
550+
551+
return az.InferenceData(posterior=posterior, constant_data=constant_data)
552+
553+
554+
@pytest.fixture(scope="module")
555+
def mock_suite_with_constant_data_single_dim(mock_idata_with_constant_data_single_dim):
556+
return MMMPlotSuite(idata=mock_idata_with_constant_data_single_dim)
557+
558+
559+
@pytest.fixture(scope="module")
560+
def mock_saturation_curve_single_dim() -> xr.DataArray:
561+
"""Saturation curve with dims ('chain','draw','channel','x')."""
562+
seed = sum(map(ord, "Saturation curve single-dim"))
563+
rng = np.random.default_rng(seed)
564+
x_values = np.linspace(0, 1, 50)
565+
channels = ["channel_1", "channel_2", "channel_3"]
566+
567+
# shape: (chains=2, draws=10, channel=3, x=50)
568+
curve_array = np.empty((2, 10, len(channels), len(x_values)))
569+
for ci in range(2):
570+
for di in range(10):
571+
for c in range(len(channels)):
572+
curve_array[ci, di, c, :] = x_values / (1 + x_values) + rng.normal(
573+
0, 0.02, size=x_values.shape
574+
)
575+
576+
return xr.DataArray(
577+
curve_array,
578+
dims=("chain", "draw", "channel", "x"),
579+
coords={
580+
"chain": np.arange(2),
581+
"draw": np.arange(10),
582+
"channel": channels,
583+
"x": x_values,
584+
},
585+
name="saturation_curve",
586+
)
587+
588+
589+
def test_saturation_curves_single_dim_axes_shape(
590+
mock_suite_with_constant_data_single_dim, mock_saturation_curve_single_dim
591+
):
592+
"""When there are no extra dims, columns should default to 1 (no ncols=0)."""
593+
fig, axes = mock_suite_with_constant_data_single_dim.saturation_curves(
594+
curve=mock_saturation_curve_single_dim, n_samples=3
595+
)
596+
597+
assert isinstance(fig, Figure)
598+
assert isinstance(axes, np.ndarray)
599+
# Expect (n_channels, 1)
600+
assert axes.shape[1] == 1
601+
assert axes.shape[0] == mock_saturation_curve_single_dim.sizes["channel"]
602+
603+
604+
def test_saturation_curves_multi_dim_axes_shape(
605+
mock_suite_with_constant_data, mock_saturation_curve
606+
):
607+
"""With an extra dim (e.g., 'country'), expect (n_channels, n_countries)."""
608+
fig, axes = mock_suite_with_constant_data.saturation_curves(
609+
curve=mock_saturation_curve, n_samples=2
610+
)
611+
612+
assert isinstance(fig, Figure)
613+
assert isinstance(axes, np.ndarray)
614+
n_channels = mock_saturation_curve.sizes["channel"]
615+
n_countries = mock_suite_with_constant_data.idata.constant_data.channel_data.sizes[
616+
"country"
617+
]
618+
assert axes.shape == (n_channels, n_countries)

0 commit comments

Comments
 (0)