File tree Expand file tree Collapse file tree 2 files changed +9
-4
lines changed Expand file tree Collapse file tree 2 files changed +9
-4
lines changed Original file line number Diff line number Diff line change @@ -490,7 +490,7 @@ def add_events(
490
490
If the event effect dimensions do not contain the prefix and model dimensions.
491
491
492
492
"""
493
- if not set (effect .dims ).issubset ((prefix , self .dims )):
493
+ if not set (effect .dims ).issubset ((prefix , * self .dims )):
494
494
raise ValueError (
495
495
f"Event effect dims { effect .dims } must contain { prefix } and { self .dims } "
496
496
)
Original file line number Diff line number Diff line change @@ -813,12 +813,13 @@ def create(
813
813
prefix : str = "holiday" ,
814
814
sigma_dims : str | None = None ,
815
815
effect_size : Prior | None = None ,
816
+ dims : tuple [str ] | str | None = None ,
816
817
):
817
818
basis = GaussianBasis ()
818
819
return EventEffect (
819
820
basis = basis ,
820
821
effect_size = Prior ("Normal" ),
821
- dims = (prefix ,),
822
+ dims = dims or (prefix ,),
822
823
)
823
824
824
825
return create
@@ -882,10 +883,14 @@ def test_mmm_with_events(
882
883
)
883
884
assert len (mmm .mu_effects ) == 1
884
885
886
+ df_events_with_country = df_events .copy ()
887
+ df_events_with_country ["country" ] = "A"
885
888
mmm .add_events (
886
- df_events ,
889
+ df_events_with_country ,
887
890
prefix = "another_event_type" ,
888
- effect = create_event_effect (prefix = "another_event_type" ),
891
+ effect = create_event_effect (
892
+ prefix = "another_event_type" , dims = ("country" , "another_event_type" )
893
+ ),
889
894
)
890
895
assert len (mmm .mu_effects ) == 2
891
896
You can’t perform that action at this time.
0 commit comments