Skip to content

Commit 4e0487a

Browse files
authored
support case with additional dimension (#1857)
1 parent 1baa7ca commit 4e0487a

File tree

2 files changed

+55
-21
lines changed

2 files changed

+55
-21
lines changed

pymc_marketing/mmm/components/base.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -368,14 +368,7 @@ def _create_distributions(
368368

369369
dims = dims or self.combined_dims
370370
if idx is not None:
371-
n_idx_dims = len(idx)
372-
dummy_dims = tuple(f"DUMMY_{i}" for i in range(n_idx_dims))
373-
if len(dummy_dims) > 1:
374-
raise NotImplementedError(
375-
"The indexing with multiple dimensions is not supported yet."
376-
)
377-
378-
dims = (*dummy_dims, *dims)
371+
dims = ("N", *dims)
379372

380373
dim_handler = create_dim_handler(dims)
381374

@@ -387,12 +380,11 @@ def create_variable(parameter_name: str, variable_name: str) -> TensorVariable:
387380
var = dist.create_variable(variable_name)
388381

389382
dist_dims = dist.dims
390-
if idx is not None:
383+
if idx is not None and any(dim in idx for dim in dist_dims):
391384
var = index_variable(var, dist.dims, idx)
392385

393-
dist_dims = tuple(
394-
[(dim if dim not in idx else "DUMMY_0") for dim in dist.dims]
395-
)
386+
dist_dims = [dim for dim in dist_dims if dim not in idx]
387+
dist_dims = ("N", *dist_dims)
396388

397389
return dim_handler(var, dist_dims)
398390

tests/mmm/components/test_base.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,22 @@ def create_variable(self, name: str):
483483
@pytest.mark.parametrize(
484484
"var, dims, idx, expected",
485485
[
486+
(
487+
np.arange(2 * 3 * 4).reshape(2, 3, 4),
488+
("geo", "product", "channel"),
489+
{
490+
"geo": [0, 0, 1, 1],
491+
"product": [0, 2, 1, 0],
492+
},
493+
np.array(
494+
[
495+
[0, 1, 2, 3],
496+
[8, 9, 10, 11],
497+
[16, 17, 18, 19],
498+
[12, 13, 14, 15],
499+
]
500+
),
501+
),
486502
(
487503
np.array([[1, 2, 3], [4, 5, 6]]),
488504
("geo", "channel"),
@@ -548,7 +564,7 @@ def test_apply_idx(new_transformation_class) -> None:
548564
)
549565

550566

551-
def test_apply_index_too_many(new_transformation_class) -> None:
567+
def test_apply_idx_more_dims(new_transformation_class) -> None:
552568
instance = new_transformation_class(
553569
priors={
554570
"a": Prior(
@@ -557,20 +573,46 @@ def test_apply_index_too_many(new_transformation_class) -> None:
557573
),
558574
"b": Prior(
559575
"HalfNormal",
560-
dims="channel",
576+
dims=("product", "channel"),
561577
),
562578
}
563579
)
564580

581+
X = np.array(
582+
[
583+
[0, 0, 0],
584+
[1, 1, 1],
585+
[2, 2, 2],
586+
[0, 0, 0],
587+
[1, 1, 1],
588+
[2, 2, 2],
589+
]
590+
)
591+
565592
coords = {
566593
"geo": ["A", "B"],
567594
"product": ["X", "Y", "Z"],
568595
"channel": ["TV", "Radio", "Online"],
569596
}
570-
with pm.Model(coords=coords):
571-
idx = {
572-
"geo": [0, 0, 0, 1, 1, 1],
573-
"product": [0, 1, 2, 0, 1, 2],
574-
}
575-
with pytest.raises(NotImplementedError, match="The indexing"):
576-
instance.apply(None, idx=idx, dims="channel")
597+
with pm.Model(coords=coords) as model:
598+
geo_idx = [0, 0, 0, 1, 1, 1]
599+
product_idx = [0, 2, 1, 0, 1, 0]
600+
Y = instance.apply(
601+
X,
602+
idx={
603+
"geo": geo_idx,
604+
"product": product_idx,
605+
},
606+
dims="channel",
607+
)
608+
609+
expected = instance.function(
610+
X,
611+
a=model["new_a"][geo_idx, product_idx, None],
612+
b=model["new_b"][product_idx],
613+
)
614+
615+
np.testing.assert_allclose(
616+
Y.eval(),
617+
expected.eval(),
618+
)

0 commit comments

Comments
 (0)