Merged
Commits
30 commits
a889a22
model spec
juanitorduz May 2, 2024
0f18e5c
changes init
juanitorduz May 2, 2024
7aa5b34
improvements
juanitorduz May 2, 2024
621e4bf
try another way of adding links
juanitorduz May 2, 2024
23269d5
make color coherent with the color palette
juanitorduz May 2, 2024
73367ba
add link to classes
juanitorduz May 2, 2024
02b55e8
Merge branch 'main' into mmm_nb_improvements
juanitorduz May 2, 2024
41aea18
add new spends plot
juanitorduz May 2, 2024
1811a7a
Merge branch 'main' into mmm_nb_improvements
juanitorduz May 2, 2024
5540b0e
Merge branch 'main' into mmm_nb_improvements
juanitorduz May 2, 2024
3b47eeb
add feedback part 1
juanitorduz May 2, 2024
26df3cf
Merge branch 'main' into mmm_nb_improvements
juanitorduz May 3, 2024
f3cb12b
Merge branch 'main' into mmm_nb_improvements
juanitorduz May 3, 2024
57c9bc0
Merge branch 'main' into mmm_nb_improvements
juanitorduz May 6, 2024
792ea92
add errors plot
juanitorduz May 7, 2024
14680ed
Merge branch 'mmm_nb_improvements' of https://github.com/pymc-labs/py…
juanitorduz May 7, 2024
029da1f
typo
juanitorduz May 7, 2024
2111c25
Update pymc_marketing/mmm/base.py
juanitorduz May 7, 2024
d573082
modularize code
juanitorduz May 7, 2024
afab68e
Merge branch 'mmm_nb_improvements' of https://github.com/pymc-labs/py…
juanitorduz May 7, 2024
78b8627
clean code
juanitorduz May 7, 2024
0ade5f8
add some initial tests
juanitorduz May 7, 2024
771dec8
fix tests
juanitorduz May 7, 2024
a32d7d7
git test base class
juanitorduz May 7, 2024
e822a8b
improve broadcasting
juanitorduz May 7, 2024
2de1778
add more tests
juanitorduz May 7, 2024
62acfa8
add errors formula
juanitorduz May 7, 2024
b33eeab
fix test
juanitorduz May 7, 2024
66b24ac
make dims consistent
juanitorduz May 8, 2024
2aef35f
Merge branch 'main' into mmm_nb_improvements
juanitorduz May 8, 2024
1 change: 1 addition & 0 deletions docs/source/notebooks/general/other_nuts_samplers.ipynb
@@ -5,6 +5,7 @@
"id": "51e3591e",
"metadata": {},
"source": [
"(other_nuts_samplers)=\n",
"# Other NUTS Samplers\n",
"\n",
"In this notebook we show how to fit a CLV model with other NUTS samplers. These alternative samplers can be significantly faster and also sample on the GPU.\n",
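For context, the notebook anchored by the hunk above demonstrates swapping PyMC's default NUTS implementation for an external backend. A minimal sketch of that idea with plain PyMC — assuming PyMC 5.x with the `numpyro` package installed; the model and data below are illustrative only, not taken from the notebook:

import numpy as np
import pymc as pm

rng = np.random.default_rng(42)
y_obs = rng.normal(loc=1.0, scale=0.5, size=100)

with pm.Model():
    mu = pm.Normal("mu", mu=0.0, sigma=1.0)
    sigma = pm.HalfNormal("sigma", sigma=1.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)

    # Delegate NUTS to the JAX-based numpyro backend instead of the default sampler.
    idata = pm.sample(draws=500, chains=2, nuts_sampler="numpyro")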
2,734 changes: 1,521 additions & 1,213 deletions docs/source/notebooks/mmm/mmm_example.ipynb

Large diffs are not rendered by default.

220 changes: 172 additions & 48 deletions pymc_marketing/mmm/base.py
@@ -40,6 +40,7 @@
from pymc_marketing.mmm.budget_optimizer import budget_allocator
from pymc_marketing.mmm.transformers import michaelis_menten
from pymc_marketing.mmm.utils import (
apply_sklearn_transformer_across_dim,
estimate_menten_parameters,
estimate_sigmoid_parameters,
find_sigmoid_inflection_point,
@@ -337,71 +338,188 @@ def plot_prior_predictive(
def plot_posterior_predictive(
self, original_scale: bool = False, ax: plt.Axes = None, **plt_kwargs: Any
) -> plt.Figure:
posterior_predictive_data: Dataset = self.posterior_predictive
likelihood_hdi_94: DataArray = az.hdi(
ary=posterior_predictive_data, hdi_prob=0.94
)[self.output_var]
likelihood_hdi_50: DataArray = az.hdi(
ary=posterior_predictive_data, hdi_prob=0.50
)[self.output_var]
"""Plot posterior distribution from the model fit.

if original_scale:
likelihood_hdi_94 = self.get_target_transformer().inverse_transform(
Xt=likelihood_hdi_94
)
likelihood_hdi_50 = self.get_target_transformer().inverse_transform(
Xt=likelihood_hdi_50
Parameters
----------
original_scale : bool, optional
Whether to plot in the original scale.
ax : plt.Axes, optional
Matplotlib axis object.
**plt_kwargs
Keyword arguments passed to `plt.subplots`.

Returns
-------
plt.Figure
"""
try:
posterior_predictive_data: Dataset = self.posterior_predictive

except Exception as e:
raise RuntimeError(
"Make sure the model has bin fitted and the posterior predictive has been sampled!"
) from e

target_to_plot = np.asarray(
self.y
if original_scale
else transform_1d_array(self.get_target_transformer().transform, self.y)
)

if len(target_to_plot) != len(posterior_predictive_data.date):
raise ValueError(
"The length of the target variable doesn't match the length of the date column. "
"If you are predicting out-of-sample, please overwrite `self.y` with the "
"corresponding (non-transformed) target variable."
)

if ax is None:
fig, ax = plt.subplots(**plt_kwargs)
else:
fig = ax.figure

if self.X is not None and self.y is not None:
ax.fill_between(
x=posterior_predictive_data.date,
y1=likelihood_hdi_94[:, 0],
y2=likelihood_hdi_94[:, 1],
color="C0",
alpha=0.2,
label="$94\%$ HDI", # noqa: W605
)
for hdi_prob, alpha in zip((0.94, 0.50), (0.2, 0.4), strict=True):
likelihood_hdi: DataArray = az.hdi(
ary=posterior_predictive_data, hdi_prob=hdi_prob
)[self.output_var]

if original_scale:
likelihood_hdi = self.get_target_transformer().inverse_transform(
Xt=likelihood_hdi
)

ax.fill_between(
x=posterior_predictive_data.date,
y1=likelihood_hdi_50[:, 0],
y2=likelihood_hdi_50[:, 1],
y1=likelihood_hdi[:, 0],
y2=likelihood_hdi[:, 1],
color="C0",
alpha=0.3,
label="$50\%$ HDI", # noqa: W605
alpha=alpha,
label=f"${100 * hdi_prob}\%$ HDI", # noqa: W605
)

target_to_plot = np.asarray(
self.y
if original_scale
else transform_1d_array(self.get_target_transformer().transform, self.y)
)
ax.plot(
np.asarray(posterior_predictive_data.date),
target_to_plot,
color="black",
label="Observed",
)
ax.legend()
ax.set(
title="Posterior Predictive Check",
xlabel="date",
ylabel=self.output_var,
)

if len(target_to_plot) != len(posterior_predictive_data.date):
raise ValueError(
"The length of the target variable doesn't match the length of the date column. "
"If you are predicting out-of-sample, please overwrite `self.y` with the "
"corresponding (non-transformed) target variable."
)
return fig

ax.plot(
np.asarray(posterior_predictive_data.date),
target_to_plot,
color="black",
def get_errors(self, original_scale: bool = False) -> DataArray:
"""Get model errors posterior distribution.

errors = true values - predicted

Parameters
----------
original_scale : bool, optional
Whether to return the errors in the original scale.

Returns
-------
DataArray
"""
try:
posterior_predictive_data: Dataset = self.posterior_predictive

except Exception as e:
raise RuntimeError(
"Make sure the model has bin fitted and the posterior predictive has been sampled!"
) from e

target_array = np.asarray(
transform_1d_array(self.get_target_transformer().transform, self.y)
)

if len(target_array) != len(posterior_predictive_data.date):
raise ValueError(
"The length of the target variable doesn't match the length of the date column. "
"If you are computing out-of-sample errors, please overwrite `self.y` with the "
"corresponding (non-transformed) target variable."
)
ax.set(
title="Posterior Predictive Check",
xlabel="date",
ylabel=self.output_var,

target = (
pd.Series(target_array, index=self.posterior_predictive.date)
.rename_axis("date")
.to_xarray()
)

errors = (
(target - posterior_predictive_data)[self.output_var]
.rename("errors")
.transpose(..., "date")
)

if original_scale:
return apply_sklearn_transformer_across_dim(
data=errors,
func=self.get_target_transformer().inverse_transform,
dim_name="date",
)

return errors

def plot_errors(
self, original_scale: bool = False, ax: plt.Axes = None, **plt_kwargs: Any
) -> plt.Figure:
"""Plot model errors by taking the difference between true values and predicted.

errors = true values - predicted

Parameters
----------
original_scale : bool, optional
Whether to plot in the original scale.
ax : plt.Axes, optional
Matplotlib axis object.
**plt_kwargs
Keyword arguments passed to `plt.subplots`.

Returns
-------
plt.Figure
"""
errors = self.get_errors(original_scale=original_scale)

if ax is None:
fig, ax = plt.subplots(**plt_kwargs)
else:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
fig = ax.figure

for hdi_prob, alpha in zip((0.94, 0.50), (0.2, 0.4), strict=True):
errors_hdi = az.hdi(ary=errors, hdi_prob=hdi_prob)

ax.fill_between(
x=self.posterior_predictive.date,
y1=errors_hdi["errors"].sel(hdi="lower"),
y2=errors_hdi["errors"].sel(hdi="higher"),
color="C3",
alpha=alpha,
label=f"${100 * hdi_prob}\%$ HDI", # noqa: W605
)

ax.plot(
self.posterior_predictive.date,
errors.mean(dim=("chain", "draw")).to_numpy(),
color="C3",
label="Errors Mean",
)

ax.axhline(y=0.0, linestyle="--", color="black", label="zero")
ax.legend()
ax.set(
title="Errors Posterior Distribution",
xlabel="date",
ylabel="true - predictions",
)
return fig

def _format_model_contributions(self, var_contribution: str) -> DataArray:
@@ -1411,14 +1529,20 @@ def plot_waterfall_components_decomposition(
cumulative_contribution = 0

for index, row in dataframe.iterrows():
color = "lightblue" if row["contribution"] >= 0 else "salmon"
color = "C0" if row["contribution"] >= 0 else "C3"

bar_start = (
cumulative_contribution + row["contribution"]
if row["contribution"] < 0
else cumulative_contribution
)
ax.barh(row["component"], row["contribution"], left=bar_start, color=color)
ax.barh(
row["component"],
row["contribution"],
left=bar_start,
color=color,
alpha=0.5,
)

if row["contribution"] > 0:
cumulative_contribution += row["contribution"]
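The `get_errors` and `plot_errors` methods added to `base.py` are meant to be called after the model has been fitted and the posterior predictive has been sampled. A minimal usage sketch, where `mmm`, `X`, and `y` are placeholder names for an already fitted DelayedSaturatedMMM and its training data (the `sample_posterior_predictive` call mirrors the test fixture further down in this diff):

# `mmm` is assumed to be a DelayedSaturatedMMM already fit on (X, y).
mmm.sample_posterior_predictive(X, extend_idata=True, combined=True)

# Posterior of errors = true values - predicted, with dims (chain, draw, date).
errors = mmm.get_errors(original_scale=True)
print(errors.mean(dim=("chain", "draw")))  # posterior mean error per date

# 94% and 50% HDI bands plus the mean of the error distribution over time.
fig = mmm.plot_errors(original_scale=True, figsize=(12, 6))
fig.savefig("mmm_errors.png")

# Posterior predictive check against the observed target, on the original scale.
fig = mmm.plot_posterior_predictive(original_scale=True, figsize=(12, 6))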
16 changes: 16 additions & 0 deletions tests/mmm/test_base.py
@@ -296,3 +296,19 @@ def test_calling_prior_before_sample_prior_predictive_raises_error(
),
):
test_mmm.prior


def test_plot_posterior_predictive_no_fitted(test_mmm) -> None:
with pytest.raises(
RuntimeError,
match="Make sure the model has bin fitted and the posterior predictive has been sampled!",
):
test_mmm.plot_posterior_predictive()


def test_get_errors_raises_not_fitted(test_mmm) -> None:
with pytest.raises(
RuntimeError,
match="Make sure the model has bin fitted and the posterior predictive has been sampled!",
):
test_mmm.get_errors()
74 changes: 74 additions & 0 deletions tests/mmm/test_delayed_saturated_mmm.py
@@ -127,6 +127,15 @@ def mmm_fitted(
return mmm


@pytest.fixture(scope="module")
def mmm_fitted_with_posterior_predictive(
mmm_fitted: DelayedSaturatedMMM,
toy_X: pd.DataFrame,
) -> DelayedSaturatedMMM:
_ = mmm_fitted.sample_posterior_predictive(toy_X, extend_idata=True, combined=True)
return mmm_fitted


@pytest.fixture(scope="module")
def mmm_fitted_with_fourier_features(
mmm_with_fourier_features: DelayedSaturatedMMM,
@@ -415,6 +424,71 @@ def test_channel_contributions_forward_pass_recovers_contribution(
y=mmm_fitted.y.max(),
)

@pytest.mark.parametrize(
argnames="original_scale",
argvalues=[False, True],
ids=["scaled", "original-scale"],
)
def test_get_errors(
self,
mmm_fitted_with_posterior_predictive: DelayedSaturatedMMM,
original_scale: bool,
) -> None:
errors = mmm_fitted_with_posterior_predictive.get_errors(
original_scale=original_scale
)
n_chains = 2
n_draws = 3
assert isinstance(errors, xr.DataArray)
assert errors.name == "errors"
assert errors.shape == (
n_chains,
n_draws,
mmm_fitted_with_posterior_predictive.y.shape[0],
)

def test_get_errors_raises_not_fitted(self) -> None:
my_mmm = DelayedSaturatedMMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
)
with pytest.raises(
RuntimeError,
match="Make sure the model has bin fitted and the posterior predictive has been sampled!",
):
my_mmm.get_errors()

def test_posterior_predictive_raises_not_fitted(self) -> None:
my_mmm = DelayedSaturatedMMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
)
with pytest.raises(
RuntimeError,
match="Make sure the model has bin fitted and the posterior predictive has been sampled!",
):
my_mmm.plot_posterior_predictive()

def test_get_errors_bad_y_length(
self,
mmm_fitted_with_posterior_predictive: DelayedSaturatedMMM,
):
mmm_fitted_with_posterior_predictive.y = np.array([1, 2])
with pytest.raises(ValueError):
mmm_fitted_with_posterior_predictive.get_errors()

def test_plot_posterior_predictive_bad_y_length(
self,
mmm_fitted_with_posterior_predictive: DelayedSaturatedMMM,
):
mmm_fitted_with_posterior_predictive.y = np.array([1, 2])
with pytest.raises(ValueError):
mmm_fitted_with_posterior_predictive.plot_posterior_predictive()

def test_channel_contributions_forward_pass_is_consistent(
self, mmm_fitted: DelayedSaturatedMMM
) -> None:
3 changes: 3 additions & 0 deletions tests/mmm/test_plotting.py
@@ -104,6 +104,9 @@ class ToyMMM(BaseDelayedSaturatedMMM, MaxAbsScaleTarget):
("plot_posterior_predictive", {}),
("plot_posterior_predictive", {"original_scale": True}),
("plot_posterior_predictive", {"ax": plt.subplots()[1]}),
("plot_errors", {}),
("plot_errors", {"original_scale": True}),
("plot_errors", {"ax": plt.subplots()[1]}),
("plot_components_contributions", {}),
("plot_channel_parameter", {"param_name": "alpha"}),
("plot_waterfall_components_decomposition", {"original_scale": True}),
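The three `plot_errors` entries added above extend the existing parametrized smoke tests in `test_plotting.py`. The harness that consumes these tuples is not shown in this diff; a hypothetical version of such a harness (the fixture name `fitted_mmm` and the pared-down parameter list are illustrative only) would look roughly like:

import matplotlib.pyplot as plt
import pytest

@pytest.mark.parametrize(
    "func_plot_name, kwargs_plot",
    [
        ("plot_posterior_predictive", {}),
        ("plot_errors", {"original_scale": True}),
    ],
)
def test_plots_smoke(fitted_mmm, func_plot_name, kwargs_plot) -> None:
    # Look up the plotting method by name, call it with its kwargs,
    # and check that a matplotlib Figure comes back.
    func = getattr(fitted_mmm, func_plot_name)
    fig = func(**kwargs_plot)
    assert isinstance(fig, plt.Figure)
    plt.close("all")  # avoid leaking figures between parametrized cases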