Use different Priors based on channel type #1490
-
Hello, first of all thank you very much for everything. I have my pymc_marketing model working well. To do so, I'm using the following code: from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation
def generate_first_model_specification(X_train, media_columns: list[str]):
    """Build the ``model_config`` prior dictionary for the MMM.

    The width of each channel's ``saturation_beta`` prior is that channel's
    share of the total spend in ``X_train``, smoothed by a small epsilon and
    floored at 0.1.

    Parameters
    ----------
    X_train
        pandas DataFrame holding one spend column per entry of
        ``media_columns``.
    media_columns
        Names of the media spend columns. The returned per-channel sigmas
        follow the order of this list.

    Returns
    -------
    dict
        Mapping from model parameter name to its ``Prior``.
    """
    epsilon = 1e-6

    # Sum column-wise so the shares keep the order of ``media_columns``.
    # A groupby/sort on the channel name would yield alphabetical order
    # instead, which mis-assigns sigmas whenever ``media_columns`` is not
    # already sorted.
    # NOTE(review): assumes the model's 'channel' coordinate follows the order
    # of the channel columns it is given -- confirm against the MMM docs.
    channel_spend = X_train[media_columns].sum(axis=0).to_numpy(dtype=float)

    # Epsilon keeps the share defined when every channel has zero spend.
    spend_shares = (channel_spend + epsilon) / (
        channel_spend.sum() + epsilon * len(channel_spend)
    )

    # Floor the prior width so low-spend channels are not pinned near zero.
    prior_sigma = np.clip(spend_shares, a_min=0.1, a_max=None)

    return {
        'intercept': Prior('Normal', mu=0.2, sigma=0.05),
        'saturation_beta': Prior(
            'HalfNormal', sigma=prior_sigma, dims='channel'),
        'saturation_lam': Prior('Gamma', alpha=3, beta=1, dims='channel'),
        'gamma_control': Prior('Laplace', mu=2, b=0.2),
        'gamma_fourier': Prior('Laplace', mu=0, b=1),
        'likelihood': Prior('Normal', sigma=Prior('HalfCauchy', beta=0.5)),
        'adstock_alpha': Prior('Beta', alpha=2, beta=2, dims='channel'),
    }
def build_model(X_train, media_columns: list[str], datetime_column, media_columns, control_columns, fast: bool=False) -> MMM:
model_config = generate_first_model_specification(X_train, media_columns)
return MMM(
model_config=model_config,
sampler_config={
'progressbar': False,
},
time_varying_intercept=not fast,
time_varying_media=False,
date_column=datetime_column,
adstock=GeometricAdstock(l_max=8),
saturation=LogisticSaturation(),
channel_columns=media_columns,
control_columns=control_columns,
yearly_seasonality=8,
) This work perfectly, however with certain complex datasets I'm getting really low accuracy and I need to improve it. To do so I wanted to use different Adstock functions based on the type of channel. Therefore I tried this: from pymc_marketing.mmm import (
DelayedAdstock,
MediaConfig,
MediaConfigList,
MediaTransformation,
GeometricAdstock,
LogisticSaturation,
)
def build_media_configs(media_columns: list[str]) -> MediaConfigList:
possible_offline = {
'tv', 'radio', 'print', 'outdoor', 'ooh', 'exteriores',
'exterior', 'press', 'newspaper', 'magazine', 'direct mail',
'revista', 'revistas', 'correo directo', 'correo', 'mail',
} # Just some offline channels to check
offline_channels = [
channel
for channel in media_columns
if channel.lower() in possible_offline
]
online_channels = [
channel
for channel in media_columns
if channel not in offline_channels
]
media_configs = []
if offline_channels:
media_configs.append(
MediaConfig(
name='offline',
columns=offline_channels,
media_transformation=MediaTransformation(
adstock=DelayedAdstock(
l_max=8, delay=2).set_dims_for_all_priors('offline'),
saturation=LogisticSaturation().set_dims_for_all_priors('offline'),
adstock_first=False,
),
)
)
if online_channels:
media_configs.append(
MediaConfig(
name='online',
columns=online_channels,
media_transformation=MediaTransformation(
adstock=GeometricAdstock(
l_max=8).set_dims_for_all_priors('online'),
saturation=LogisticSaturation().set_dims_for_all_priors('online'),
adstock_first=True,
),
)
)
return media_configs The problem is that I don't know how to include those Can anyone help me with this?Also, if you think that I'm doing something wrong, please let me know. 😄Thank you very much! |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment
-
At the moment, the out-of-the-box |
Beta Was this translation helpful? Give feedback.
At the moment, the out-of-the-box `MMM` class doesn't support this functionality. Open to PRs!