47
47
)
48
48
from pymc_marketing .mmm .events import EventEffect
49
49
from pymc_marketing .mmm .fourier import YearlyFourier
50
+ from pymc_marketing .mmm .hsgp import HSGPBase
50
51
from pymc_marketing .mmm .lift_test import (
51
52
add_lift_measurements_to_likelihood_from_saturation ,
52
53
scale_lift_measurements ,
@@ -115,7 +116,7 @@ class MMM(ModelBuilder):
115
116
"""
116
117
117
118
_model_type : str = "MMMM (Multi-Dimensional Marketing Mix Model)"
118
- version : str = "0.0.1 "
119
+ version : str = "0.0.2 "
119
120
120
121
@validate_call
121
122
def __init__ (
@@ -137,8 +138,13 @@ def __init__(
137
138
Field (strict = True , description = "Whether to use a time-varying intercept" ),
138
139
] = False ,
139
140
time_varying_media : Annotated [
140
- bool ,
141
- Field (strict = True , description = "Whether to use time-varying media effects" ),
141
+ bool | InstanceOf [HSGPBase ],
142
+ Field (
143
+ description = (
144
+ "Whether to use time-varying media effects, or pass an HSGP instance "
145
+ "(e.g., SoftPlusHSGP) specifying dims and priors."
146
+ ),
147
+ ),
142
148
] = False ,
143
149
dims : tuple [str , ...] | None = Field (
144
150
None , description = "Additional dimensions for the model."
@@ -757,7 +763,7 @@ def _generate_and_preprocess_model_data(
757
763
for dim in self .xarray_dataset .coords .dims
758
764
}
759
765
760
- if self .time_varying_intercept | self .time_varying_media :
766
+ if bool ( self .time_varying_intercept ) or bool ( self .time_varying_media ) :
761
767
self ._time_index = np .arange (0 , X [self .date_column ].unique ().shape [0 ])
762
768
self ._time_index_mid = X [self .date_column ].unique ().shape [0 ] // 2
763
769
self ._time_resolution = (
@@ -1036,7 +1042,7 @@ def build_model(
1036
1042
for mu_effect in self .mu_effects :
1037
1043
mu_effect .create_data (self )
1038
1044
1039
- if self .time_varying_intercept | self .time_varying_media :
1045
+ if bool ( self .time_varying_intercept ) or bool ( self .time_varying_media ) :
1040
1046
time_index = pm .Data (
1041
1047
name = "time_index" ,
1042
1048
value = self ._time_index ,
@@ -1066,7 +1072,7 @@ def build_model(
1066
1072
)
1067
1073
1068
1074
# Add media logic
1069
- if self .time_varying_media :
1075
+ if isinstance ( self . time_varying_media , bool ) and self .time_varying_media :
1070
1076
baseline_channel_contribution = pm .Deterministic (
1071
1077
name = "baseline_channel_contribution" ,
1072
1078
var = self .forward_pass (
@@ -1079,13 +1085,44 @@ def build_model(
1079
1085
X = time_index ,
1080
1086
dims = ("date" , * self .dims ),
1081
1087
** self .model_config ["media_tvp_config" ],
1082
- ).create_variable ("media_latent_process " )
1088
+ ).create_variable ("media_temporal_latent_multiplier " )
1083
1089
1084
1090
channel_contribution = pm .Deterministic (
1085
1091
name = "channel_contribution" ,
1086
1092
var = baseline_channel_contribution * media_latent_process [..., None ],
1087
1093
dims = ("date" , * self .dims , "channel" ),
1088
1094
)
1095
+ elif isinstance (self .time_varying_media , HSGPBase ):
1096
+ baseline_channel_contribution = self .forward_pass (
1097
+ x = channel_data_ , dims = (* self .dims , "channel" )
1098
+ )
1099
+ baseline_channel_contribution .name = "baseline_channel_contribution"
1100
+ baseline_channel_contribution .dims = (
1101
+ "date" ,
1102
+ * self .dims ,
1103
+ "channel" ,
1104
+ )
1105
+
1106
+ # Register internal time index and build latent process
1107
+ self .time_varying_media .register_data (time_index )
1108
+ media_latent_process = self .time_varying_media .create_variable (
1109
+ "media_temporal_latent_multiplier"
1110
+ )
1111
+
1112
+ # Determine broadcasting over channel axis
1113
+ media_dims = pm .modelcontext (None ).named_vars_to_dims [
1114
+ media_latent_process .name
1115
+ ]
1116
+ if "channel" in media_dims :
1117
+ media_broadcast = media_latent_process
1118
+ else :
1119
+ media_broadcast = media_latent_process [..., None ]
1120
+
1121
+ channel_contribution = pm .Deterministic (
1122
+ name = "channel_contribution" ,
1123
+ var = baseline_channel_contribution * media_broadcast ,
1124
+ dims = ("date" , * self .dims , "channel" ),
1125
+ )
1089
1126
else :
1090
1127
channel_contribution = pm .Deterministic (
1091
1128
name = "channel_contribution" ,
@@ -1681,7 +1718,7 @@ def add_lift_test_measurements(
1681
1718
# This is coupled with the name of the
1682
1719
# latent process Deterministic
1683
1720
time_varying_var_name = (
1684
- "media_latent_process " if self .time_varying_media else None
1721
+ "media_temporal_latent_multiplier " if self .time_varying_media else None
1685
1722
)
1686
1723
add_lift_measurements_to_likelihood_from_saturation (
1687
1724
df_lift_test = df_lift_test_scaled ,
0 commit comments