Skip to content

Commit 84fc81a

Browse files
Updating MMM Budget allocation examples, functionalities and dependencies (#1849)
* initial commit * Update multidimensional MMM example and utils Refactored the multidimensional MMM budget allocation notebook to improve date handling, add xarray usage, and update sample response distribution parameters. Added a new multidimensional_model.nc file, updated utility functions, and adjusted related tests for compatibility with new data structures and outputs. Also modified .gitignore to allow tracking of .nc files. * deprecation warning * Solving tests * Pushing test with error. * Solving test discrepancies between pytensor and sample posterior predictive * Adding new tests * Add multidimensional_model.nc with Git LFS tracking * Solving issues with tests * Changes * remove repetitive front matter. * We don't need to explain about the model builder in this one. People will get that they can load a pre-trained notebook and other parts of the docs can explain how it works. Goal is to get to optimziation as quickly as possible. * remove discussion of other saturation functions. It doesn't add much to the optimization conversation and other parts of the docs should compare the saturation functions. We're assuming the reader already has read the intro notebook so they at least now that saturation functions have diminishing returns, etc. * minor typos. * Adding missing tests suites --------- Co-authored-by: daniel-saunders-phil <[email protected]>
1 parent 7132631 commit 84fc81a

17 files changed

+7056
-2640
lines changed

.gitattributes

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
docs/source/notebooks/mmm/multidimensional_model.nc filter=lfs diff=lfs merge=lfs -text
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
model:
2+
class: pymc_marketing.mmm.multidimensional.MMM
3+
kwargs:
4+
date_column: "date"
5+
target_column: "y"
6+
channel_columns: ["x1", "x2"]
7+
control_columns: ["event_1", "event_2"]
8+
dims: ["geo"]
9+
yearly_seasonality: 2
10+
scaling:
11+
class: pymc_marketing.mmm.scaling.Scaling
12+
kwargs:
13+
channel:
14+
method: "max"
15+
dims: []
16+
target:
17+
method: "max"
18+
dims: []
19+
20+
# --- media transformations ---------------------------------------
21+
adstock:
22+
class: pymc_marketing.mmm.GeometricAdstock
23+
kwargs:
24+
priors:
25+
alpha:
26+
distribution: "Beta"
27+
alpha: 2
28+
beta: 3
29+
dims: "channel"
30+
l_max: 8
31+
32+
saturation:
33+
class: pymc_marketing.mmm.LogisticSaturation
34+
kwargs:
35+
priors:
36+
beta:
37+
distribution: "Gamma"
38+
mu: [0.35, 0.35]
39+
sigma: [0.1, 0.1]
40+
dims: "channel"
41+
lam:
42+
distribution: "Gamma"
43+
mu: 3
44+
sigma: 2
45+
dims: "channel"
46+
47+
48+
# --- model (hierarchical) priors ---------------------------------
49+
model_config:
50+
intercept:
51+
distribution: Normal
52+
mu: 0.5
53+
sigma: 0.5
54+
dims: geo
55+
56+
gamma_control:
57+
distribution: Normal
58+
mu: 0
59+
sigma: 0.5
60+
dims: control
61+
62+
gamma_fourier:
63+
distribution: Laplace
64+
mu: 0
65+
b:
66+
distribution: HalfNormal
67+
sigma: 0.2
68+
dims: [geo, fourier_mode]
69+
70+
likelihood:
71+
distribution: TruncatedNormal
72+
lower: 0
73+
sigma:
74+
distribution: HalfNormal
75+
sigma:
76+
distribution: HalfNormal
77+
sigma: 1.5
78+
dims: [date, geo]
79+
80+
# ----------------------------------------------------------------------
81+
# Effects with complex priors
82+
effects:
83+
- class: pymc_marketing.mmm.additive_effect.LinearTrendEffect
84+
kwargs:
85+
trend:
86+
class: pymc_marketing.mmm.LinearTrend
87+
kwargs:
88+
n_changepoints: 5
89+
include_intercept: false
90+
dims: ["geo"] # Keep as array format
91+
priors:
92+
delta:
93+
distribution: "Laplace"
94+
mu: 0
95+
b:
96+
distribution: "HalfNormal"
97+
sigma: 0.2
98+
dims: ["changepoint", "geo"]
99+
prefix: "trend"
100+
101+
# ----------------------------------------------------------------------
102+
# (optional) sampler options you plan to forward to pm.sample():
103+
sampler_config:
104+
tune: 1000
105+
draws: 200
106+
chains: 8
107+
random_seed: 42
108+
target_accept: 0.90
109+
nuts_sampler: "nutpie"
110+
111+
# ----------------------------------------------------------------------
112+
# (optional) idata from a previous sample
113+
idata_path: "multidimensional_model.nc"
114+
115+
original_scale_vars:
116+
- channel_contribution
117+
- intercept_contribution
118+
- y
119+
120+
# ----------------------------------------------------------------------
121+
# (optional) Data paths
122+
# data:
123+
# X_path: "data/X.csv"
124+
# y_path: "data/y.csv"

0 commit comments

Comments
 (0)