Skip to content

Commit bf00dfd

Browse files
authored
MMM: Fix Scaling Intercept (#1845)
* patch intercept scaling * fix condition * test init * undo * fix tests
1 parent 236f39d commit bf00dfd

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

pymc_marketing/mmm/mmm.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1983,13 +1983,35 @@ def sample_posterior_predictive(
19831983
date=slice(self.adstock.l_max, None)
19841984
)
19851985

1986+
intercept_condition = (
1987+
"intercept" in sample_posterior_predictive_kwargs.get("var_names", [])
1988+
and not self.time_varying_intercept
1989+
)
1990+
19861991
if original_scale:
1992+
# We need to expand the intercept to the date dimension
1993+
# because the target transformer is applied to the date dimension
1994+
if intercept_condition:
1995+
posterior_predictive_samples["intercept"] = (
1996+
posterior_predictive_samples["intercept"]
1997+
.expand_dims(
1998+
dim={"date": posterior_predictive_samples["date"]}, axis=0
1999+
)
2000+
.rename("intercept")
2001+
)
2002+
19872003
posterior_predictive_samples = apply_sklearn_transformer_across_dim(
19882004
data=posterior_predictive_samples,
19892005
func=self.get_target_transformer().inverse_transform,
19902006
dim_name="date",
19912007
)
19922008

2009+
# We need to remove the date dimension after the inverse transform
2010+
if intercept_condition:
2011+
posterior_predictive_samples["intercept"] = (
2012+
posterior_predictive_samples["intercept"].isel(date=0)
2013+
)
2014+
19932015
return posterior_predictive_samples
19942016

19952017
def add_lift_test_measurements(

tests/mmm/test_mmm.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,32 +1000,52 @@ def new_date_ranges_to_test():
10001000
"new_dates",
10011001
new_date_ranges_to_test(),
10021002
)
1003-
@pytest.mark.parametrize("combined", [True, False])
1004-
@pytest.mark.parametrize("original_scale", [True, False])
1003+
@pytest.mark.parametrize(
1004+
argnames="combined",
1005+
argvalues=[True, False],
1006+
ids=["combined", "not_combined"],
1007+
)
1008+
@pytest.mark.parametrize(
1009+
argnames="original_scale",
1010+
argvalues=[True, False],
1011+
ids=["original_scale", "scaled"],
1012+
)
1013+
@pytest.mark.parametrize(
1014+
argnames="var_names",
1015+
argvalues=[None, ["mu", "y_sigma", "channel_contribution"], ["mu", "intercept"]],
1016+
ids=["no_var_names", "var_names", "var_names_with_intercept"],
1017+
)
10051018
def test_new_data_sample_posterior_predictive_method(
10061019
generate_data,
10071020
toy_X,
10081021
model_name: str,
10091022
new_dates: pd.DatetimeIndex,
10101023
combined: bool,
10111024
original_scale: bool,
1025+
var_names: list[str] | None,
10121026
request,
10131027
) -> None:
10141028
"""This is the method that is used in all the other methods that generate predictions."""
10151029
mmm = request.getfixturevalue(model_name)
10161030
X = generate_data(new_dates)
10171031

1032+
kwargs = {"var_names": var_names} if var_names is not None else {}
1033+
10181034
posterior_predictive = mmm.sample_posterior_predictive(
10191035
X=X,
10201036
extend_idata=False,
10211037
combined=combined,
10221038
original_scale=original_scale,
1039+
**kwargs,
10231040
)
10241041
pd.testing.assert_index_equal(
10251042
pd.DatetimeIndex(posterior_predictive.coords["date"]),
10261043
new_dates,
10271044
)
10281045

1046+
if var_names is not None:
1047+
assert var_names == list(posterior_predictive.data_vars)
1048+
10291049

10301050
@pytest.mark.parametrize(
10311051
"predictions",

0 commit comments

Comments
 (0)