Skip to content

Commit fb00872

Browse files
authored
Allow for indexing of apply parameters (#1847)
1 parent e4c7143 commit fb00872

File tree

2 files changed

+153
-5
lines changed

2 files changed

+153
-5
lines changed

pymc_marketing/mmm/components/base.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,27 @@ def __init__(self) -> None:
8888
super().__init__(msg)
8989

9090

91+
def index_variable(var, dims, idx) -> TensorVariable:
    """Select entries of ``var`` along the dims that are named in ``idx``.

    Parameters
    ----------
    var : TensorVariable
        The variable to index.
    dims : tuple[str, ...]
        The names of the dims of ``var``, one per axis, in axis order.
    idx : dict[str, pt.TensorLike]
        Mapping from a dim name to the index to apply along that dim.

    Returns
    -------
    TensorVariable
        ``var`` indexed with ``idx[dim]`` on every dim found in ``idx`` and
        with a full slice on every other dim.

    """
    keep_all = slice(None)

    # Build one indexer per entry of ``dims``; dims without an entry in
    # ``idx`` are kept whole.
    selection = []
    for dim in dims:
        selection.append(idx[dim] if dim in idx else keep_all)

    return var[tuple(selection)]
110+
111+
91112
class Transformation:
92113
"""Base class for adstock and saturation functions.
93114
@@ -338,17 +359,42 @@ def _infer_output_core_dims(self) -> tuple[str, ...]:
338359
return tuple(list({str(dim): None for dims in parameter_dims for dim in dims}))
339360

340361
def _create_distributions(
341-
self, dims: Dims | None = None
362+
self,
363+
dims: Dims | None = None,
364+
idx: dict[str, pt.TensorLike] | None = None,
342365
) -> dict[str, TensorVariable]:
343-
dim_handler = create_dim_handler(dims or self._infer_output_core_dims())
366+
if isinstance(dims, str):
367+
dims = (dims,)
368+
369+
dims = dims or self.combined_dims
370+
if idx is not None:
371+
n_idx_dims = len(idx)
372+
dummy_dims = tuple(f"DUMMY_{i}" for i in range(n_idx_dims))
373+
if len(dummy_dims) > 1:
374+
raise NotImplementedError(
375+
"The indexing with multiple dimensions is not supported yet."
376+
)
377+
378+
dims = (*dummy_dims, *dims)
379+
380+
dim_handler = create_dim_handler(dims)
344381

345382
def create_variable(parameter_name: str, variable_name: str) -> TensorVariable:
346383
dist = self.function_priors[parameter_name]
347384
if not hasattr(dist, "create_variable"):
348385
return dist
349386

350387
var = dist.create_variable(variable_name)
351-
return dim_handler(var, dist.dims)
388+
389+
dist_dims = dist.dims
390+
if idx is not None:
391+
var = index_variable(var, dist.dims, idx)
392+
393+
dist_dims = tuple(
394+
[(dim if dim not in idx else "DUMMY_0") for dim in dist.dims]
395+
)
396+
397+
return dim_handler(var, dist_dims)
352398

353399
return {
354400
parameter_name: create_variable(parameter_name, variable_name)
@@ -566,7 +612,12 @@ def plot_curve_hdi(
566612
hdi_kwargs=hdi_kwargs,
567613
)
568614

569-
def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> TensorVariable:
615+
def apply(
616+
self,
617+
x: pt.TensorLike,
618+
dims: Dims | None = None,
619+
idx: dict[str, pt.TensorLike] | None = None,
620+
) -> TensorVariable:
570621
"""Call within a model context.
571622
572623
Used internally of the MMM to apply the transformation to the data.
@@ -599,7 +650,7 @@ def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> TensorVariable:
599650
transformed_data = transformation.apply(data, dims="channel")
600651
601652
"""
602-
kwargs = self._create_distributions(dims=dims)
653+
kwargs = self._create_distributions(dims=dims, idx=idx)
603654
return self.function(x, **kwargs)
604655

605656

tests/mmm/components/test_base.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
ParameterPriorException,
2727
Transformation,
2828
create_registration_meta,
29+
index_variable,
2930
)
3031
from pymc_marketing.mmm.components.saturation import TanhSaturation
3132

@@ -477,3 +478,99 @@ def create_variable(self, name: str):
477478

478479
curve = saturation.sample_curve(prior, 10)
479480
assert curve.dims == ("chain", "draw", "x", "dim_a")
481+
482+
483+
@pytest.mark.parametrize(
    "var, dims, idx, expected",
    [
        (
            np.array([[1, 2, 3], [4, 5, 6]]),
            ("geo", "channel"),
            {"geo": [0, 0, 1, 1]},
            np.array([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]),
        ),
    ],
)
def test_index_variable(var, dims, idx, expected) -> None:
    """Check that ``index_variable`` picks the rows requested through ``idx``."""
    indexed = index_variable(var, dims=dims, idx=idx)

    # Symbolic results have to be evaluated before they can be compared.
    actual = indexed.eval() if isinstance(indexed, TensorVariable) else indexed

    np.testing.assert_allclose(actual, expected)
507+
508+
509+
def test_apply_idx(new_transformation_class) -> None:
    """Check ``apply`` with ``idx`` against manually indexing the ``geo`` prior."""
    priors = {
        "a": Prior("HalfNormal", dims="geo"),
        "b": Prior("HalfNormal", dims="channel"),
    }
    instance = new_transformation_class(priors=priors)

    # Six rows by three channels: the same three-row pattern, twice.
    X = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]] * 2)

    # Position of each row of ``X`` along the "geo" dim.
    geo_idx = [0, 0, 0, 1, 1, 1]

    coords = {"geo": ["A", "B"], "channel": ["TV", "Radio", "Online"]}
    with pm.Model(coords=coords) as model:
        Y = instance.apply(X, idx={"geo": geo_idx}, dims="channel")

    a_indexed = model["new_a"][geo_idx, None]
    expected = instance.function(X, a=a_indexed, b=model["new_b"])

    np.testing.assert_allclose(Y.eval(), expected.eval())
549+
550+
551+
def test_apply_index_too_many(new_transformation_class) -> None:
    """Check that ``apply`` raises when ``idx`` names more than one dim."""
    priors = {
        "a": Prior("HalfNormal", dims=("geo", "product")),
        "b": Prior("HalfNormal", dims="channel"),
    }
    instance = new_transformation_class(priors=priors)

    coords = {
        "geo": ["A", "B"],
        "product": ["X", "Y", "Z"],
        "channel": ["TV", "Radio", "Online"],
    }
    # Two indexed dims at once is the case expected to be rejected.
    idx = {
        "geo": [0, 0, 0, 1, 1, 1],
        "product": [0, 1, 2, 0, 1, 2],
    }

    with pm.Model(coords=coords):
        with pytest.raises(NotImplementedError, match="The indexing"):
            instance.apply(None, idx=idx, dims="channel")

0 commit comments

Comments
 (0)