Skip to content

Commit 33d486a

Browse files
committed
POC implement MultiDimensionalMMM with dimmed variables
1 parent 216136b commit 33d486a

File tree

8 files changed

+606
-261
lines changed

8 files changed

+606
-261
lines changed

pymc_marketing/mmm/components/adstock.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,14 @@ def function(self, x, alpha):
5656
from __future__ import annotations
5757

5858
import numpy as np
59-
import pytensor.tensor as pt
6059
import xarray as xr
6160
from pydantic import Field, validate_call
6261
from pymc_extras.deserialize import deserialize, register_deserialization
6362
from pymc_extras.prior import Prior
63+
from pytensor.tensor import TensorVariable
64+
from pytensor.xtensor import as_xtensor
6465

66+
# from pymc.dims.distributions import Weibull # TODO
6567
from pymc_marketing.mmm.components.base import (
6668
SupportedPrior,
6769
Transformation,
@@ -198,10 +200,15 @@ class BinomialAdstock(AdstockTransformation):
198200

199201
lookup_name = "binomial"
200202

201-
def function(self, x, alpha):
203+
def function(self, x, alpha, *, dim: str):
202204
"""Binomial adstock function."""
203205
return binomial_adstock(
204-
x, alpha=alpha, l_max=self.l_max, normalize=self.normalize, mode=self.mode
206+
x,
207+
alpha=alpha,
208+
l_max=self.l_max,
209+
normalize=self.normalize,
210+
mode=self.mode,
211+
dim=dim,
205212
)
206213

207214
default_priors = {"alpha": Prior("Beta", alpha=1, beta=3)}
@@ -231,10 +238,15 @@ class GeometricAdstock(AdstockTransformation):
231238

232239
lookup_name = "geometric"
233240

234-
def function(self, x, alpha):
241+
def function(self, x, alpha, *, dim: str):
235242
"""Geometric adstock function."""
236243
return geometric_adstock(
237-
x, alpha=alpha, l_max=self.l_max, normalize=self.normalize, mode=self.mode
244+
x,
245+
alpha=alpha,
246+
l_max=self.l_max,
247+
normalize=self.normalize,
248+
mode=self.mode,
249+
dim=dim,
238250
)
239251

240252
default_priors = {"alpha": Prior("Beta", alpha=1, beta=3)}
@@ -264,7 +276,7 @@ class DelayedAdstock(AdstockTransformation):
264276

265277
lookup_name = "delayed"
266278

267-
def function(self, x, alpha, theta):
279+
def function(self, x, alpha, theta, *, dim: str):
268280
"""Delayed adstock function."""
269281
return delayed_adstock(
270282
x,
@@ -273,6 +285,7 @@ def function(self, x, alpha, theta):
273285
l_max=self.l_max,
274286
normalize=self.normalize,
275287
mode=self.mode,
288+
dim=dim,
276289
)
277290

278291
default_priors = {
@@ -305,7 +318,7 @@ class WeibullPDFAdstock(AdstockTransformation):
305318

306319
lookup_name = "weibull_pdf"
307320

308-
def function(self, x, lam, k):
321+
def function(self, x, lam, k, *, dim: str):
309322
"""Weibull adstock function."""
310323
return weibull_adstock(
311324
x=x,
@@ -315,6 +328,7 @@ def function(self, x, lam, k):
315328
mode=self.mode,
316329
type=WeibullType.PDF,
317330
normalize=self.normalize,
331+
dim=dim,
318332
)
319333

320334
default_priors = {
@@ -347,7 +361,7 @@ class WeibullCDFAdstock(AdstockTransformation):
347361

348362
lookup_name = "weibull_cdf"
349363

350-
def function(self, x, lam, k):
364+
def function(self, x, lam, k, *, dim: str):
351365
"""Weibull adstock function."""
352366
return weibull_adstock(
353367
x=x,
@@ -357,6 +371,7 @@ def function(self, x, lam, k):
357371
mode=self.mode,
358372
type=WeibullType.CDF,
359373
normalize=self.normalize,
374+
dim=dim,
360375
)
361376

362377
default_priors = {
@@ -370,9 +385,10 @@ class NoAdstock(AdstockTransformation):
370385

371386
lookup_name: str = "no_adstock"
372387

373-
def function(self, x):
388+
def function(self, x, *, dim: str | None = None) -> TensorVariable:
374389
"""No adstock function."""
375-
return pt.as_tensor_variable(x)
390+
x = as_xtensor(x)
391+
return x
376392

377393
default_priors = {}
378394

pymc_marketing/mmm/components/base.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,18 @@
2929
import numpy as np
3030
import numpy.typing as npt
3131
import pymc as pm
32+
import pymc.dims as pmd
3233
import xarray as xr
3334
from matplotlib.axes import Axes
3435
from matplotlib.figure import Figure
3536
from pydantic import InstanceOf
3637
from pymc.distributions.shape_utils import Dims
37-
from pymc_extras.prior import Prior, VariableFactory, create_dim_handler
38+
from pymc_extras.prior import Prior, VariableFactory
3839
from pytensor import tensor as pt
3940
from pytensor.tensor.variable import TensorVariable
41+
from pytensor.xtensor import as_xtensor
4042

43+
from pymc_marketing.mmm.dims import XPrior
4144
from pymc_marketing.model_config import parse_model_config
4245
from pymc_marketing.plot import (
4346
SelToString,
@@ -220,13 +223,13 @@ def priors(self) -> dict[str, SupportedPrior]:
220223
return self.function_priors
221224

222225
@function_priors.setter # type: ignore
223-
def function_priors(self, priors: dict[str, Any | Prior] | None) -> None:
226+
def function_priors(self, priors: dict[str, Any | XPrior] | None) -> None:
224227
priors = priors or {}
225228

226229
non_distributions = [
227230
key
228231
for key, value in priors.items()
229-
if not isinstance(value, Prior) and not isinstance(value, dict)
232+
if not isinstance(value, XPrior) and not isinstance(value, dict)
230233
]
231234

232235
priors = parse_model_config(priors, non_distributions=non_distributions)
@@ -313,9 +316,10 @@ def _has_defaults_for_all_arguments(self) -> None:
313316
function_signature = signature(self.function)
314317

315318
# Remove the first one as assumed to be the data
319+
# And the last one, as assumed to be dim
316320
parameters_that_need_priors = set(
317321
list(function_signature.parameters.keys())[1:]
318-
)
322+
) - {"dim"}
319323
parameters_with_priors = set(self.default_priors.keys())
320324

321325
missing_priors = parameters_that_need_priors - parameters_with_priors
@@ -379,8 +383,6 @@ def _create_distributions(
379383
if idx is not None:
380384
dims = ("N", *dims)
381385

382-
dim_handler = create_dim_handler(dims)
383-
384386
def create_variable(parameter_name: str, variable_name: str) -> TensorVariable:
385387
dist = self.function_priors[parameter_name]
386388
if not hasattr(dist, "create_variable"):
@@ -392,10 +394,7 @@ def create_variable(parameter_name: str, variable_name: str) -> TensorVariable:
392394
if idx is not None and any(dim in idx for dim in dist_dims):
393395
var = index_variable(var, dist.dims, idx)
394396

395-
dist_dims = [dim for dim in dist_dims if dim not in idx]
396-
dist_dims = ("N", *dist_dims)
397-
398-
return dim_handler(var, dist_dims)
397+
return var
399398

400399
return {
401400
parameter_name: create_variable(parameter_name, variable_name)
@@ -521,16 +520,19 @@ def _sample_curve(
521520
}
522521
)
523522

523+
x = as_xtensor(x, dims=(x_dim,))
524+
524525
with pm.Model(coords=coords):
525-
pm.Deterministic(
526+
pmd.Deterministic(
526527
var_name,
527-
self.apply(x, dims=output_core_dims),
528+
self.apply(x, dims=output_core_dims, core_dim=x_dim),
528529
dims=(x_dim, *output_core_dims),
529530
)
530531

531532
return pm.sample_posterior_predictive(
532533
parameters,
533534
var_names=[var_name],
535+
progressbar=False,
534536
).posterior_predictive[var_name]
535537

536538
def plot_curve_samples(
@@ -616,7 +618,9 @@ def plot_curve_hdi(
616618
def apply(
617619
self,
618620
x: pt.TensorLike,
621+
*,
619622
dims: Dims | None = None,
623+
core_dim: str,
620624
idx: dict[str, pt.TensorLike] | None = None,
621625
) -> TensorVariable:
622626
"""Call within a model context.
@@ -646,13 +650,18 @@ def apply(
646650
647651
transformation = ...
648652
649-
coords = {"channel": ["TV", "Radio", "Digital"]}
653+
coords = {
654+
"channel": ["TV", "Radio", "Digital"],
655+
"date": range(10),
656+
}
650657
with pm.Model(coords=coords):
651-
transformed_data = transformation.apply(data, dims="channel")
658+
transformed_data = transformation.apply(
659+
data, dims="channel", core_dim="date"
660+
)
652661
653662
"""
654663
kwargs = self._create_distributions(dims=dims, idx=idx)
655-
return self.function(x, **kwargs)
664+
return self.function(x, dim=core_dim, **kwargs)
656665

657666

658667
def _serialize_value(value: Any) -> Any:

0 commit comments

Comments
 (0)