Skip to content

Commit 270b0c7

Browse files
Hotfix - Fix dimension assertion for adding multidimensional event effects to multidimensional MMM (#1969)
* Fix dimension validation for multidimensional events. * Add a test for adding multidimensional events to multidimensional.MMM
1 parent af9799f commit 270b0c7

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

pymc_marketing/mmm/multidimensional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ def add_events(
490490
If the event effect dimensions do not contain the prefix and model dimensions.
491491
492492
"""
493-
if not set(effect.dims).issubset((prefix, self.dims)):
493+
if not set(effect.dims).issubset((prefix, *self.dims)):
494494
raise ValueError(
495495
f"Event effect dims {effect.dims} must contain {prefix} and {self.dims}"
496496
)

tests/mmm/test_multidimensional.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -813,12 +813,13 @@ def create(
813813
prefix: str = "holiday",
814814
sigma_dims: str | None = None,
815815
effect_size: Prior | None = None,
816+
dims: tuple[str] | str | None = None,
816817
):
817818
basis = GaussianBasis()
818819
return EventEffect(
819820
basis=basis,
820821
effect_size=Prior("Normal"),
821-
dims=(prefix,),
822+
dims=dims or (prefix,),
822823
)
823824

824825
return create
@@ -882,10 +883,14 @@ def test_mmm_with_events(
882883
)
883884
assert len(mmm.mu_effects) == 1
884885

886+
df_events_with_country = df_events.copy()
887+
df_events_with_country["country"] = "A"
885888
mmm.add_events(
886-
df_events,
889+
df_events_with_country,
887890
prefix="another_event_type",
888-
effect=create_event_effect(prefix="another_event_type"),
891+
effect=create_event_effect(
892+
prefix="another_event_type", dims=("country", "another_event_type")
893+
),
889894
)
890895
assert len(mmm.mu_effects) == 2
891896

0 commit comments

Comments
 (0)