Skip to content

Commit de5679f

Browse files
authored
Allow plot MMM components in the original scale (#870)
* add original scale implementation * add plot nb * change location * undo * make mypy happy * test plot * add test * update plot readme * fix test * improve variable description
1 parent 9129a9e commit de5679f

File tree

5 files changed

+1446
-1288
lines changed

5 files changed

+1446
-1288
lines changed
42 KB
Loading

docs/source/notebooks/mmm/mmm_example.ipynb

Lines changed: 1257 additions & 1288 deletions
Large diffs are not rendered by default.

pymc_marketing/mmm/delayed_saturated_mmm.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,168 @@ def plot_channel_parameter(self, param_name: str, **plt_kwargs: Any) -> plt.Figu
10721072
)
10731073
return fig
10741074

1075+
def get_ts_contribution_posterior(
1076+
self, var_contribution: str, original_scale: bool = False
1077+
) -> DataArray:
1078+
"""Get the posterior distribution of the time series contributions of a given variable.
1079+
1080+
Parameters
1081+
----------
1082+
var_contribution : str
1083+
The variable for which to get the contributions. It must be a valid variable
1084+
in the `fit_result` attribute.
1085+
original_scale : bool, optional
1086+
Whether to plot in the original scale.
1087+
1088+
Returns
1089+
-------
1090+
DataArray
1091+
The posterior distribution of the time series contributions.
1092+
"""
1093+
contributions = self._format_model_contributions(
1094+
var_contribution=var_contribution
1095+
)
1096+
1097+
if original_scale:
1098+
return apply_sklearn_transformer_across_dim(
1099+
data=contributions,
1100+
func=self.get_target_transformer().inverse_transform,
1101+
dim_name="date",
1102+
)
1103+
1104+
return contributions
1105+
1106+
def plot_components_contributions(
1107+
self, original_scale: bool = False, **plt_kwargs: Any
1108+
) -> plt.Figure:
1109+
"""Plot the target variable and the posterior predictive model components in
1110+
the scaled space.
1111+
1112+
Parameters
1113+
----------
1114+
original_scale : bool, optional
1115+
Whether to plot in the original scale.
1116+
1117+
**plt_kwargs
1118+
Additional keyword arguments to pass to `plt.subplots`.
1119+
1120+
Returns
1121+
-------
1122+
plt.Figure
1123+
"""
1124+
channel_contributions = self.get_ts_contribution_posterior(
1125+
var_contribution="channel_contributions", original_scale=original_scale
1126+
)
1127+
1128+
means = [channel_contributions.mean(["chain", "draw"])]
1129+
contribution_vars = [
1130+
az.hdi(channel_contributions, hdi_prob=0.94).channel_contributions
1131+
]
1132+
1133+
for arg, var_contribution in zip(
1134+
["control_columns", "yearly_seasonality"],
1135+
["control_contributions", "fourier_contributions"],
1136+
strict=True,
1137+
):
1138+
if getattr(self, arg, None):
1139+
contributions = self.get_ts_contribution_posterior(
1140+
var_contribution=var_contribution, original_scale=original_scale
1141+
)
1142+
1143+
means.append(contributions.mean(["chain", "draw"]))
1144+
contribution_vars.append(
1145+
az.hdi(contributions, hdi_prob=0.94)[var_contribution]
1146+
)
1147+
1148+
fig, ax = plt.subplots(**plt_kwargs)
1149+
1150+
for i, (mean, hdi, var_contribution) in enumerate(
1151+
zip(
1152+
means,
1153+
contribution_vars,
1154+
[
1155+
"channel_contribution",
1156+
"control_contribution",
1157+
"fourier_contribution",
1158+
],
1159+
strict=False,
1160+
)
1161+
):
1162+
if self.X is not None:
1163+
ax.fill_between(
1164+
x=self.X[self.date_column],
1165+
y1=hdi.isel(hdi=0),
1166+
y2=hdi.isel(hdi=1),
1167+
color=f"C{i}",
1168+
alpha=0.25,
1169+
label=f"$94\\%$ HDI ({var_contribution})",
1170+
)
1171+
ax.plot(
1172+
np.asarray(self.X[self.date_column]),
1173+
np.asarray(mean),
1174+
color=f"C{i}",
1175+
)
1176+
if self.X is not None:
1177+
intercept = az.extract(
1178+
self.fit_result, var_names=["intercept"], combined=False
1179+
)
1180+
1181+
if original_scale:
1182+
intercept = apply_sklearn_transformer_across_dim(
1183+
data=intercept,
1184+
func=self.get_target_transformer().inverse_transform,
1185+
dim_name="chain",
1186+
)
1187+
1188+
if intercept.ndim == 2:
1189+
# Intercept has a stationary prior
1190+
intercept_hdi = np.repeat(
1191+
a=az.hdi(intercept).intercept.data[None, ...],
1192+
repeats=self.X[self.date_column].shape[0],
1193+
axis=0,
1194+
)
1195+
elif intercept.ndim == 3:
1196+
# Intercept has a time-varying prior
1197+
intercept_hdi = az.hdi(intercept).intercept.data
1198+
1199+
ax.plot(
1200+
np.asarray(self.X[self.date_column]),
1201+
np.full(len(self.X[self.date_column]), intercept.mean().data),
1202+
color=f"C{i + 1}",
1203+
)
1204+
ax.fill_between(
1205+
x=self.X[self.date_column],
1206+
y1=intercept_hdi[:, 0],
1207+
y2=intercept_hdi[:, 1],
1208+
color=f"C{i + 1}",
1209+
alpha=0.25,
1210+
label="$94\\%$ HDI (intercept)",
1211+
)
1212+
1213+
y_to_plot = (
1214+
self.get_target_transformer().inverse_transform(
1215+
np.asarray(self.preprocessed_data["y"]).reshape(-1, 1)
1216+
)
1217+
if original_scale
1218+
else np.asarray(self.preprocessed_data["y"])
1219+
)
1220+
1221+
ylabel = self.output_var if original_scale else f"{self.output_var} scaled"
1222+
1223+
ax.plot(
1224+
np.asarray(self.X[self.date_column]),
1225+
y_to_plot,
1226+
label=ylabel,
1227+
color="black",
1228+
)
1229+
ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.1), ncol=3)
1230+
ax.set(
1231+
title="Posterior Predictive Model Components",
1232+
xlabel="date",
1233+
ylabel=ylabel,
1234+
)
1235+
return fig
1236+
10751237
def plot_channel_contributions_grid(
10761238
self,
10771239
start: float,

tests/mmm/test_delayed_saturated_mmm.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,31 @@ def test_channel_contributions_forward_pass_recovers_contribution(
463463
y=mmm_fitted.y.max(),
464464
)
465465

466+
@pytest.mark.parametrize(
467+
argnames="original_scale",
468+
argvalues=[False, True],
469+
ids=["scaled", "original-scale"],
470+
)
471+
@pytest.mark.parametrize(
472+
argnames="var_contribution",
473+
argvalues=["channel_contributions", "control_contributions"],
474+
ids=["channel_contribution", "control_contribution"],
475+
)
476+
def test_get_ts_contribution_posterior(
477+
self,
478+
mmm_fitted_with_posterior_predictive: MMM,
479+
var_contribution: str,
480+
original_scale: bool,
481+
):
482+
ts_posterior = (
483+
mmm_fitted_with_posterior_predictive.get_ts_contribution_posterior(
484+
var_contribution=var_contribution, original_scale=original_scale
485+
)
486+
)
487+
assert ts_posterior.dims == ("chain", "draw", "date")
488+
assert ts_posterior.chain.size == 1
489+
assert ts_posterior.draw.size == 500
490+
466491
@pytest.mark.parametrize(
467492
argnames="original_scale",
468493
argvalues=[False, True],

tests/mmm/test_plotting.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,8 @@ def mock_fitted_mmm(mock_mmm, toy_X, toy_y):
219219
("plot_direct_contribution_curves", {"same_axes": True}),
220220
("plot_direct_contribution_curves", {"channels": ["channel_2"]}),
221221
("plot_channel_parameter", {"param_name": "adstock_alpha"}),
222+
("plot_components_contributions", {}),
223+
("plot_components_contributions", {"original_scale": True}),
222224
],
223225
)
224226
def test_delayed_saturated_mmm_plots(

0 commit comments

Comments
 (0)