Skip to content

Commit b069eae

Browse files
committed
moved adding mean plot to separate private function
1 parent 6656989 commit b069eae

File tree

3 files changed

+50
-26
lines changed

3 files changed

+50
-26
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,6 @@ dmypy.json
142142

143143
# PyCharm .idea files
144144
.idea/
145+
146+
# shell scripts
147+
*.sh

docs/source/notebooks/mmm/mmm_plotting_options.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
},
1919
{
2020
"cell_type": "code",
21-
"execution_count": 2,
21+
"execution_count": 5,
2222
"metadata": {},
2323
"outputs": [
2424
{
@@ -61,7 +61,7 @@
6161
},
6262
{
6363
"cell_type": "code",
64-
"execution_count": 4,
64+
"execution_count": 6,
6565
"metadata": {},
6666
"outputs": [],
6767
"source": [
@@ -78,7 +78,7 @@
7878
},
7979
{
8080
"cell_type": "code",
81-
"execution_count": 5,
81+
"execution_count": 7,
8282
"metadata": {},
8383
"outputs": [
8484
{

pymc_marketing/mmm/base.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -382,13 +382,9 @@ def plot_posterior_predictive(
382382
plt.Figure
383383
384384
"""
385-
try:
386-
posterior_predictive_data: Dataset = self.posterior_predictive
387-
388-
except Exception as e:
389-
raise RuntimeError(
390-
"Make sure the model has bin fitted and the posterior predictive has been sampled!"
391-
) from e
385+
posterior_predictive_data: Dataset = self._get_posterior_predictive_data(
386+
original_scale=original_scale
387+
)
392388

393389
target_to_plot = np.asarray(
394390
self.y
@@ -408,13 +404,6 @@ def plot_posterior_predictive(
408404
else:
409405
fig = ax.figure
410406

411-
if original_scale:
412-
posterior_predictive_data = apply_sklearn_transformer_across_dim(
413-
data=posterior_predictive_data,
414-
func=self.get_target_transformer().inverse_transform,
415-
dim_name="date",
416-
)
417-
418407
for hdi_prob, alpha in zip((0.94, 0.50), (0.2, 0.4), strict=True):
419408
likelihood_hdi: DataArray = az.hdi(
420409
ary=posterior_predictive_data, hdi_prob=hdi_prob
@@ -430,15 +419,8 @@ def plot_posterior_predictive(
430419
)
431420

432421
if add_mean:
433-
mean_prediction = posterior_predictive_data[self.output_var].mean(
434-
dim=["chain", "draw"]
435-
)
436-
437-
ax.plot(
438-
np.asarray(posterior_predictive_data.date),
439-
mean_prediction,
440-
color="C0",
441-
label="Mean Prediction",
422+
ax = self._add_mean_to_plot(
423+
ax=ax, original_scale=original_scale, color="red"
442424
)
443425

444426
ax.plot(
@@ -456,6 +438,45 @@ def plot_posterior_predictive(
456438

457439
return fig
458440

441+
def _get_posterior_predictive_data(self, original_scale: bool = False) -> Dataset:
442+
"""Get the posterior predictive data."""
443+
try:
444+
posterior_predictive_data: Dataset = self.posterior_predictive
445+
446+
except Exception as e:
447+
raise RuntimeError(
448+
"Make sure the model has bin fitted and the posterior predictive has been sampled!"
449+
) from e
450+
451+
if original_scale:
452+
posterior_predictive_data = apply_sklearn_transformer_across_dim(
453+
data=posterior_predictive_data,
454+
func=self.get_target_transformer().inverse_transform,
455+
dim_name="date",
456+
)
457+
return posterior_predictive_data
458+
459+
def _add_mean_to_plot(
460+
self, ax, original_scale: bool = False, color="blue", linestyle="-", **kwargs
461+
) -> plt.Axes:
462+
"""Add mean prediction to existing plot."""
463+
posterior_predictive_data: Dataset = self._get_posterior_predictive_data(
464+
original_scale=original_scale
465+
)
466+
467+
mean_prediction = posterior_predictive_data[self.output_var].mean(
468+
dim=["chain", "draw"]
469+
)
470+
471+
ax.plot(
472+
np.asarray(posterior_predictive_data.date),
473+
mean_prediction,
474+
color=color,
475+
linestyle=linestyle,
476+
label="Mean Prediction",
477+
)
478+
return ax
479+
459480
def get_errors(self, original_scale: bool = False) -> DataArray:
460481
"""Get model errors posterior distribution.
461482

0 commit comments

Comments
 (0)