Skip to content

Commit 7b56b6e

Browse files
authored
Fourier transformation and HSGPs for streamlit app (#1898)
* adds page for Fourier modes * adds page for HSGPs * gets rid of TODOs * updates README with new page descriptions * updates plot titles
1 parent 9e6643c commit 7b56b6e

File tree

3 files changed

+300
-0
lines changed

3 files changed

+300
-0
lines changed

streamlit/mmm-explainer/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ In this case, you would just need to install the requirements.txt within the str
1414

1515
- **Bayesian Priors**: Interactive charts that demonstrate Bayesian prior distributions, designed to showcase the power of Bayesian methods in handling uncertainty and incorporating prior knowledge into MMM.
1616

17+
- **Fourier Modes Exploration**: Interactive charts that demonstrate Fourier Modes, used to capture seasonal effects. Users can adjust parameters explore differnt periodic trends.
18+
19+
- **Time-Varying Parameters**: Interactive charts that demonstrate Time-Varying effects. Users can adjust parameters to explore different configurations that help capture hidden latent temporal variations within media effects.
20+
1721
- **Customizable Parameters**: All sections of the app include options to customize parameters, allowing users to experiment with different scenarios and understand their impacts on MMM.
1822

1923
## Getting Started
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# Copyright 2022 - 2025 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Streamlit page for fourier modes."""
15+
16+
import plotly.graph_objects as go
17+
18+
import streamlit as st
19+
from pymc_marketing.mmm import MonthlyFourier, YearlyFourier
20+
from pymc_marketing.prior import Prior
21+
22+
# Constants
23+
PLOT_HEIGHT = 500
24+
PLOT_WIDTH = 1000
25+
26+
# -------------------------- TOP OF PAGE INFORMATION -------------------------
27+
28+
# Set browser / tab config
29+
st.set_page_config(
30+
page_title="MMM App - Fourier Modes",
31+
page_icon="🧊",
32+
)
33+
34+
# Give some context for what the page displays
35+
st.title("Fourier Modes")
36+
37+
st.markdown(
38+
"This page demonstrates Fourier seasonality transformations for use \
39+
in MMM. Fourier seasonality relies on sine and cosine \
40+
functions to capture recurring patterns in the data, making it useful \
41+
for modeling periodic trends."
42+
)
43+
44+
st.markdown("___The Fourier component takes the form:___")
45+
46+
# LaTeX string for Fourier seasonal component
47+
fourier_formula = r"""
48+
f(t) = \sum_{k=1}^{K} \Bigg[ a_k \cos\Big(\frac{2 \pi k t}{T}\Big)
49+
+ b_k \sin\Big(\frac{2 \pi k t}{T}\Big) \Bigg]
50+
"""
51+
st.latex(fourier_formula)
52+
53+
st.markdown("""
54+
**Where:**
55+
56+
- $t$ = time index (e.g., day, week, month)
57+
- $T$ = period of the seasonality (e.g., 12 for monthly, 365 for yearly)
58+
- $K$ = order of the Fourier series (number of sine/cosine pairs)
59+
- $a_k, b_k$ = Fourier coefficients
60+
""")
61+
62+
st.markdown(
63+
"🗒️ **Note:** \n \
64+
- Yearly Fourier: A yearly seasonality with a period ($T$) of **_:red[365.25 days]_** \n \
65+
- Monthly Fourier: A monthly seasonality with a period ($T$) of **_:red[365.25 / 12 days]_**"
66+
)
67+
68+
st.divider()
69+
70+
# User inputs
71+
st.subheader(":orange[User Inputs]")
72+
# Slider for selecting the order
73+
n_order = st.slider(
74+
"Fourier order $K$ (n_order)", min_value=1, max_value=20, value=6, step=1
75+
)
76+
# Slider for selecting the scale param
77+
b = st.slider(
78+
"Laplace scale (__b__)", min_value=0.01, max_value=1.0, value=0.1, step=0.01
79+
)
80+
81+
# Setup
82+
prior = Prior("Laplace", mu=0, b=b, dims="fourier")
83+
84+
# Create tabs for plots
85+
tab1, tab2 = st.tabs(["Yearly", "Monthly"])
86+
87+
# -------------------------- YEARLY SEASONALITY -------------------------
88+
with tab1:
89+
st.subheader(":orange[Yearly Seasonality]")
90+
91+
fourier = YearlyFourier(n_order=n_order, prior=prior)
92+
93+
# Displayed in the APP
94+
parameters = fourier.sample_prior()
95+
curve = fourier.sample_curve(parameters)
96+
# Drop chain if it's always 1
97+
curve = curve.squeeze("chain")
98+
# Compute mean and quantiles across draws
99+
mean_trend = curve.mean("draw")
100+
# Grab the days for the x-axis
101+
days = curve.coords["day"].values
102+
103+
# Build Plotly figure
104+
fig = go.Figure()
105+
106+
# Mean line
107+
fig.add_trace(
108+
go.Scatter(
109+
x=days,
110+
y=mean_trend.values,
111+
mode="lines",
112+
line=dict(color="blue"),
113+
name="Mean trend",
114+
)
115+
)
116+
117+
fig.update_layout(
118+
title="Yearly Fourier Trend",
119+
xaxis_title="Day",
120+
yaxis_title="Trend",
121+
height=PLOT_HEIGHT,
122+
width=PLOT_WIDTH,
123+
)
124+
125+
st.plotly_chart(fig, use_container_width=True)
126+
127+
# -------------------------- MONTHLY SEASONALITY -------------------------
128+
with tab2:
129+
st.subheader(":orange[Monthly Seasonality]")
130+
131+
fourier = MonthlyFourier(n_order=n_order, prior=prior)
132+
133+
# Displayed in the APP
134+
parameters = fourier.sample_prior()
135+
curve = fourier.sample_curve(parameters)
136+
# Drop chain if it's always 1
137+
curve = curve.squeeze("chain")
138+
# Compute mean and quantiles across draws
139+
mean_trend = curve.mean("draw")
140+
# Grab the days for the x-axis
141+
days = curve.coords["day"].values
142+
143+
# Build Plotly figure
144+
fig = go.Figure()
145+
146+
# Mean line
147+
fig.add_trace(
148+
go.Scatter(
149+
x=days,
150+
y=mean_trend.values,
151+
mode="lines",
152+
line=dict(color="blue"),
153+
name="Mean trend",
154+
)
155+
)
156+
157+
fig.update_layout(
158+
title="Monthly Fourier Trend",
159+
xaxis_title="Day",
160+
yaxis_title="Trend",
161+
height=PLOT_HEIGHT,
162+
width=PLOT_WIDTH,
163+
)
164+
165+
st.plotly_chart(fig, use_container_width=True)
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright 2022 - 2025 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Streamlit page for HSGP."""
15+
16+
import numpy as np
17+
import pandas as pd
18+
import plotly.graph_objects as go
19+
20+
import streamlit as st
21+
from pymc_marketing.mmm import HSGP
22+
23+
# Constants
24+
PLOT_HEIGHT = 500
25+
PLOT_WIDTH = 1000
26+
SEED = sum(map(ord, "Out of the box GP"))
27+
RNG = np.random.default_rng(SEED)
28+
29+
# -------------------------- TOP OF PAGE INFORMATION -------------------------
30+
31+
# Set browser / tab config
32+
st.set_page_config(
33+
page_title="MMM App - HSGP",
34+
page_icon="🧊",
35+
)
36+
37+
# Give some context for what the page displays
38+
st.title("Time-Varying Parameters")
39+
# TODO: Update this !
40+
st.markdown(
41+
"In real-world scenarios, the effectiveness of marketing activities is not \
42+
static but varies over time due to factors like competitive actions, \
43+
and market dynamics. To account for this, we introduce a time-dependent \
44+
component into the MMM framework using a Gaussian Process, specifically a \
45+
[Hilbert Space GP](https://www.pymc.io/projects/docs/en/stable/api/gp/generated/pymc.gp.HSGP.html). \
46+
This allows us to capture the hidden latent temporal variation of the \
47+
marketing contributions. \
48+
"
49+
)
50+
51+
st.markdown("""
52+
When `time_media_varying` is set to `True`, we capture a single latent \
53+
process that multiplies all channels. We assume all channels \
54+
share the same time-dependent fluctuations, contrasting with \
55+
implementations where each channel has an independent latent \
56+
process. The modified model can be represented as:
57+
""")
58+
59+
tvp_media_formula = r"""
60+
y_{t} = \alpha + \lambda_{t} \cdot \sum_{m=1}^{M}\beta_{m}f(x_{m, t}) \ +
61+
\sum_{c=1}^{C}\gamma_{c}z_{c, t} + \varepsilon_{t},
62+
"""
63+
st.latex(tvp_media_formula)
64+
65+
st.markdown("""
66+
**Where:**
67+
68+
$\\lambda_{t}$ is the time-varying component modeled as a latent process. This shared time-dependent \
69+
variation $\\lambda_{t}$ allows us to capture the overall temporal effects that influence all \
70+
media channels simultaneously.
71+
""")
72+
73+
# Generate some data for the example
74+
n = 52
75+
X = np.arange(n)
76+
77+
st.divider()
78+
79+
# User inputs
80+
st.subheader(":blue[User Inputs]")
81+
82+
# Sliders for params
83+
ls = st.slider("Lengthscale (ls)", min_value=1, max_value=100, value=25, step=1)
84+
eta = st.slider("Variance (eta)", min_value=0.1, max_value=5.0, value=1.0, step=0.1)
85+
m = st.slider(
86+
"The number of basis vectors (m)", min_value=50, max_value=500, value=200, step=10
87+
)
88+
L = st.slider("Boundary condition (L)", min_value=50, max_value=500, value=150, step=10)
89+
90+
# Fixed parameters
91+
dims = "time"
92+
drop_first = False
93+
94+
# Collect kwargs
95+
kwargs = dict(X=X, ls=ls, eta=eta, dims=dims, m=m, L=L, drop_first=drop_first)
96+
97+
hsgp = HSGP(**kwargs)
98+
99+
dates = pd.date_range("2022-01-01", periods=n, freq="W-MON")
100+
coords = {"time": dates}
101+
102+
103+
def sample_curve(hsgp):
104+
"""Use to sample HSGP."""
105+
return hsgp.sample_prior(coords=coords, random_seed=RNG)["f"]
106+
107+
108+
curve = sample_curve(hsgp).rename("False")
109+
curve = curve.squeeze("chain") # drop chain=1
110+
time = curve.coords["time"].values
111+
112+
# Compute posterior mean and credible interval
113+
mean_vals = curve.mean("draw")
114+
115+
fig = go.Figure()
116+
117+
# Mean line
118+
fig.add_trace(
119+
go.Scatter(
120+
x=time, y=mean_vals.values, mode="lines", line=dict(color="blue"), name="Mean"
121+
)
122+
)
123+
124+
fig.update_layout(
125+
title="Time-Dependent Variation",
126+
xaxis_title="Time",
127+
yaxis_title="Value",
128+
template="plotly_white",
129+
)
130+
131+
st.plotly_chart(fig, use_container_width=True)

0 commit comments

Comments
 (0)