Skip to content

Commit bca55e4

Browse files
authored
Support HSGP instance for time_varying_media in MMM (#1881)
* Support HSGP instance for time_varying_media in MMM Allow passing an HSGPBase instance (e.g., SoftPlusHSGP) to the time_varying_media argument in MMM, enabling custom latent process dimensions and priors. Update logic to handle both boolean and HSGPBase types, and add tests to verify correct behavior and dimension broadcasting for both single- and multi-dimensional cases. * Changing tests * update model version
1 parent 28b7c22 commit bca55e4

File tree

2 files changed

+129
-10
lines changed

2 files changed

+129
-10
lines changed

pymc_marketing/mmm/multidimensional.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
)
4848
from pymc_marketing.mmm.events import EventEffect
4949
from pymc_marketing.mmm.fourier import YearlyFourier
50+
from pymc_marketing.mmm.hsgp import HSGPBase
5051
from pymc_marketing.mmm.lift_test import (
5152
add_lift_measurements_to_likelihood_from_saturation,
5253
scale_lift_measurements,
@@ -115,7 +116,7 @@ class MMM(ModelBuilder):
115116
"""
116117

117118
_model_type: str = "MMMM (Multi-Dimensional Marketing Mix Model)"
118-
version: str = "0.0.1"
119+
version: str = "0.0.2"
119120

120121
@validate_call
121122
def __init__(
@@ -137,8 +138,13 @@ def __init__(
137138
Field(strict=True, description="Whether to use a time-varying intercept"),
138139
] = False,
139140
time_varying_media: Annotated[
140-
bool,
141-
Field(strict=True, description="Whether to use time-varying media effects"),
141+
bool | InstanceOf[HSGPBase],
142+
Field(
143+
description=(
144+
"Whether to use time-varying media effects, or pass an HSGP instance "
145+
"(e.g., SoftPlusHSGP) specifying dims and priors."
146+
),
147+
),
142148
] = False,
143149
dims: tuple[str, ...] | None = Field(
144150
None, description="Additional dimensions for the model."
@@ -757,7 +763,7 @@ def _generate_and_preprocess_model_data(
757763
for dim in self.xarray_dataset.coords.dims
758764
}
759765

760-
if self.time_varying_intercept | self.time_varying_media:
766+
if bool(self.time_varying_intercept) or bool(self.time_varying_media):
761767
self._time_index = np.arange(0, X[self.date_column].unique().shape[0])
762768
self._time_index_mid = X[self.date_column].unique().shape[0] // 2
763769
self._time_resolution = (
@@ -1036,7 +1042,7 @@ def build_model(
10361042
for mu_effect in self.mu_effects:
10371043
mu_effect.create_data(self)
10381044

1039-
if self.time_varying_intercept | self.time_varying_media:
1045+
if bool(self.time_varying_intercept) or bool(self.time_varying_media):
10401046
time_index = pm.Data(
10411047
name="time_index",
10421048
value=self._time_index,
@@ -1066,7 +1072,7 @@ def build_model(
10661072
)
10671073

10681074
# Add media logic
1069-
if self.time_varying_media:
1075+
if isinstance(self.time_varying_media, bool) and self.time_varying_media:
10701076
baseline_channel_contribution = pm.Deterministic(
10711077
name="baseline_channel_contribution",
10721078
var=self.forward_pass(
@@ -1079,13 +1085,44 @@ def build_model(
10791085
X=time_index,
10801086
dims=("date", *self.dims),
10811087
**self.model_config["media_tvp_config"],
1082-
).create_variable("media_latent_process")
1088+
).create_variable("media_temporal_latent_multiplier")
10831089

10841090
channel_contribution = pm.Deterministic(
10851091
name="channel_contribution",
10861092
var=baseline_channel_contribution * media_latent_process[..., None],
10871093
dims=("date", *self.dims, "channel"),
10881094
)
1095+
elif isinstance(self.time_varying_media, HSGPBase):
1096+
baseline_channel_contribution = self.forward_pass(
1097+
x=channel_data_, dims=(*self.dims, "channel")
1098+
)
1099+
baseline_channel_contribution.name = "baseline_channel_contribution"
1100+
baseline_channel_contribution.dims = (
1101+
"date",
1102+
*self.dims,
1103+
"channel",
1104+
)
1105+
1106+
# Register internal time index and build latent process
1107+
self.time_varying_media.register_data(time_index)
1108+
media_latent_process = self.time_varying_media.create_variable(
1109+
"media_temporal_latent_multiplier"
1110+
)
1111+
1112+
# Determine broadcasting over channel axis
1113+
media_dims = pm.modelcontext(None).named_vars_to_dims[
1114+
media_latent_process.name
1115+
]
1116+
if "channel" in media_dims:
1117+
media_broadcast = media_latent_process
1118+
else:
1119+
media_broadcast = media_latent_process[..., None]
1120+
1121+
channel_contribution = pm.Deterministic(
1122+
name="channel_contribution",
1123+
var=baseline_channel_contribution * media_broadcast,
1124+
dims=("date", *self.dims, "channel"),
1125+
)
10891126
else:
10901127
channel_contribution = pm.Deterministic(
10911128
name="channel_contribution",
@@ -1681,7 +1718,7 @@ def add_lift_test_measurements(
16811718
# This is coupled with the name of the
16821719
# latent process Deterministic
16831720
time_varying_var_name = (
1684-
"media_latent_process" if self.time_varying_media else None
1721+
"media_temporal_latent_multiplier" if self.time_varying_media else None
16851722
)
16861723
add_lift_measurements_to_likelihood_from_saturation(
16871724
df_lift_test=df_lift_test_scaled,

tests/mmm/test_multidimensional.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from pytensor.tensor.basic import TensorVariable
2626
from scipy.optimize import OptimizeResult
2727

28-
from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation
28+
from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation, SoftPlusHSGP
2929
from pymc_marketing.mmm.additive_effect import EventAdditiveEffect, LinearTrendEffect
3030
from pymc_marketing.mmm.events import EventEffect, GaussianBasis
3131
from pymc_marketing.mmm.lift_test import _swap_columns_and_last_index_level
@@ -265,7 +265,7 @@ def unstack(data, name):
265265
if time_varying_intercept:
266266
assert "intercept_latent_process" in var_names
267267
if time_varying_media:
268-
assert "media_latent_process" in var_names
268+
assert "media_temporal_latent_multiplier" in var_names
269269
if yearly_seasonality is not None:
270270
assert "fourier_contribution" in var_names
271271

@@ -501,6 +501,88 @@ def test_sample_posterior_predictive_partial_overlap_with_include_last_observati
501501
)
502502

503503

504+
@pytest.mark.parametrize(
505+
"hsgp_dims",
506+
[
507+
pytest.param(("date",), id="hsgp-dims=date"),
508+
pytest.param(("date", "channel"), id="hsgp-dims=date,channel"),
509+
],
510+
)
511+
def test_time_varying_media_with_custom_hsgp_single_dim(single_dim_data, hsgp_dims):
512+
"""Ensure passing an HSGP instance to time_varying_media works (single-dim)."""
513+
X, y = single_dim_data
514+
515+
# Build HSGP using the new API
516+
hsgp = SoftPlusHSGP.parameterize_from_data(
517+
X=np.arange(X.shape[0]),
518+
dims=hsgp_dims,
519+
)
520+
521+
mmm = MMM(
522+
date_column="date",
523+
target_column="target",
524+
channel_columns=["channel_1", "channel_2", "channel_3"],
525+
adstock=GeometricAdstock(l_max=2),
526+
saturation=LogisticSaturation(),
527+
time_varying_media=hsgp,
528+
)
529+
530+
mmm.build_model(X, y)
531+
532+
# Check latent multiplier exists with the expected dims
533+
var_name = "media_temporal_latent_multiplier"
534+
assert var_name in mmm.model.named_vars
535+
latent_dims = mmm.model.named_vars_to_dims[var_name]
536+
assert latent_dims == hsgp_dims
537+
538+
# Channel contribution should always be date x channel
539+
assert mmm.model.named_vars_to_dims["channel_contribution"] == ("date", "channel")
540+
541+
542+
@pytest.mark.parametrize(
543+
"hsgp_dims",
544+
[
545+
pytest.param(("date", "country"), id="hsgp-dims=date,country"),
546+
pytest.param(
547+
("date", "country", "channel"), id="hsgp-dims=date,country,channel"
548+
),
549+
],
550+
)
551+
def test_time_varying_media_with_custom_hsgp_multi_dim(df, hsgp_dims):
552+
"""Ensure passing an HSGP instance to time_varying_media works (multi-dim)."""
553+
X = df.drop(columns=["y"])
554+
y = df["y"]
555+
556+
hsgp = SoftPlusHSGP.parameterize_from_data(
557+
X=np.arange(X.shape[0]),
558+
dims=hsgp_dims,
559+
)
560+
561+
mmm = MMM(
562+
date_column="date",
563+
channel_columns=["C1", "C2"],
564+
target_column="y",
565+
dims=("country",),
566+
adstock=GeometricAdstock(l_max=2),
567+
saturation=LogisticSaturation(),
568+
time_varying_media=hsgp,
569+
)
570+
571+
mmm.build_model(X, y)
572+
573+
var_name = "media_temporal_latent_multiplier"
574+
assert var_name in mmm.model.named_vars
575+
latent_dims = mmm.model.named_vars_to_dims[var_name]
576+
assert latent_dims == hsgp_dims
577+
578+
# Channel contribution should always be date x country x channel
579+
assert mmm.model.named_vars_to_dims["channel_contribution"] == (
580+
"date",
581+
"country",
582+
"channel",
583+
)
584+
585+
504586
def test_sample_posterior_predictive_no_overlap_with_include_last_observations(
505587
single_dim_data, mock_pymc_sample
506588
):

0 commit comments

Comments
 (0)