
Commit 85108b7

Benjamin T. Vincent (drbenvincent) authored

update rolling regression notebook to v4 (pymc-devs#352)

* create truncated regression example
* delete truncated regression example from main branch
* create truncated regression example
* delete truncated regression example from main branch
* create truncated regression example
* delete truncated regression example from main branch
* fix incorrect statement about pm.NormalMixture
* update to v4
* remove 2 stale imports
* remove version print + suppress matplotlib warnings
* fixed?
* how about now?

Co-authored-by: Benjamin T. Vincent <[email protected]>

1 parent 1202cb5 commit 85108b7

File tree

2 files changed: +174 −134 lines changed

examples/generalized_linear_models/GLM-rolling-regression.ipynb

Lines changed: 130 additions & 105 deletions
Large diffs are not rendered by default.

myst_nbs/generalized_linear_models/GLM-rolling-regression.myst.md

Lines changed: 44 additions & 29 deletions
@@ -6,16 +6,18 @@ jupytext:
     format_version: 0.13
     jupytext_version: 1.13.7
 kernelspec:
-  display_name: Python 3
+  display_name: Python 3 (ipykernel)
   language: python
   name: python3
 ---

+(GLM-rolling-regression)=
 # Rolling Regression

-:::{post} Sept 15, 2021
-:tags: generalized linear model, pymc3.Exponential, pymc3.GaussianRandomWalk, pymc3.HalfCauchy, pymc3.HalfNormal, pymc3.Model, pymc3.Normal, regression
+:::{post} June, 2022
+:tags: generalized linear model, regression
 :category: intermediate
+:author: Thomas Wiecki, Benjamin T. Vincent
 :::

 +++
@@ -25,28 +27,25 @@ kernelspec:
 * One common example is the price of gold (GLD) and the price of gold mining operations (GFI).

 ```{code-cell} ipython3
-%matplotlib inline
-
 import os
+import warnings

 import arviz as az
-import bambi as bmb
 import matplotlib.pyplot as plt
 import matplotlib.ticker as mticker
 import numpy as np
 import pandas as pd
-import pymc3 as pm
+import pymc as pm
 import xarray as xr

-from matplotlib import cm
+from matplotlib import MatplotlibDeprecationWarning

-print(f"Running on PyMC3 v{pm.__version__}")
+warnings.filterwarnings(action="ignore", category=MatplotlibDeprecationWarning)
 ```

 ```{code-cell} ipython3
 RANDOM_SEED = 8927
 rng = np.random.default_rng(RANDOM_SEED)
-
 %config InlineBackend.figure_format = 'retina'
 az.style.use("arviz-darkgrid")
 ```
@@ -84,19 +83,19 @@ cb.ax.set_yticklabels(ticklabels);
 A naive approach would be to estimate a linear model and ignore the time domain.

 ```{code-cell} ipython3
-with pm.Model() as model:  # model specifications in PyMC3 are wrapped in a with-statement
+with pm.Model() as model:  # model specifications in PyMC are wrapped in a with-statement
     # Define priors
-    sigma = pm.HalfCauchy("sigma", beta=10, testval=1.0)
+    sigma = pm.HalfCauchy("sigma", beta=10)
     alpha = pm.Normal("alpha", mu=0, sigma=20)
     beta = pm.Normal("beta", mu=0, sigma=20)

+    mu = pm.Deterministic("mu", alpha + beta * prices_zscored.GFI.to_numpy())
+
     # Define likelihood
-    likelihood = pm.Normal(
-        "y", mu=alpha + beta * prices_zscored.GFI, sigma=sigma, observed=prices_zscored.GLD
-    )
+    likelihood = pm.Normal("y", mu=mu, sigma=sigma, observed=prices_zscored.GLD.to_numpy())

     # Inference
-    trace_reg = pm.sample(tune=2000, return_inferencedata=True)
+    trace_reg = pm.sample(tune=2000)
 ```

 The posterior predictive plot shows how bad the fit is.
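As a sanity check on what the naive model in this hunk estimates, its posterior means for `alpha` and `beta` should land close to the closed-form least-squares solution on the z-scored series. A minimal numpy sketch using synthetic stand-ins for the GFI/GLD prices (the data and noise scales here are made up for illustration, not the notebook's data):

```python
import numpy as np

rng = np.random.default_rng(8927)

# Synthetic stand-ins for the notebook's GFI/GLD price series
gfi = rng.normal(size=500)
gld = 0.8 * gfi + rng.normal(scale=0.4, size=500)


def zscore(x):
    # Same standardization the notebook applies to build `prices_zscored`
    return (x - x.mean()) / x.std()


x, y = zscore(gfi), zscore(gld)

# Ordinary least squares: with flat-ish priors, the Bayesian posterior means
# for beta and alpha should be close to these point estimates
beta_hat = np.sum(x * y) / np.sum(x**2)  # on z-scored data this is the correlation
alpha_hat = y.mean() - beta_hat * x.mean()  # ~0 after z-scoring
```

Because both series are standardized, the slope collapses to the sample correlation and the intercept is essentially zero, which is why the naive model's only real freedom is a single global slope.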
@@ -110,16 +109,29 @@ ax = fig.add_subplot(
     title="Posterior predictive regression lines",
 )
 sc = ax.scatter(prices_zscored.GFI, prices_zscored.GLD, c=colors, cmap=mymap, lw=0)
-pm.plot_posterior_predictive_glm(
-    trace_reg,
-    samples=100,
-    label="posterior predictive regression lines",
-    lm=lambda x, sample: sample["alpha"] + sample["beta"] * x,
-    eval=np.linspace(prices_zscored.GFI.min(), prices_zscored.GFI.max(), 100),
+
+xi = xr.DataArray(prices_zscored.GFI.values)
+az.plot_hdi(
+    xi,
+    trace_reg.posterior.mu,
+    color="k",
+    hdi_prob=0.95,
+    ax=ax,
+    fill_kwargs={"alpha": 0.25},
+    smooth=False,
 )
+az.plot_hdi(
+    xi,
+    trace_reg.posterior.mu,
+    color="k",
+    hdi_prob=0.5,
+    ax=ax,
+    fill_kwargs={"alpha": 0.25},
+    smooth=False,
+)
+
 cb = plt.colorbar(sc, ticks=ticks)
-cb.ax.set_yticklabels(ticklabels)
-ax.legend(loc=0);
+cb.ax.set_yticklabels(ticklabels);
 ```

 ## Rolling regression
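The `az.plot_hdi` calls introduced in the hunk above shade highest-density intervals of the posterior for `mu`. As a rough illustration of the quantity being shaded (a simplified sketch of the idea, not ArviZ's implementation), the narrowest interval covering a given probability mass of a sample can be computed directly:

```python
import numpy as np


def hdi(samples, prob=0.95):
    """Narrowest interval containing `prob` of the samples (illustrative sketch)."""
    x = np.sort(np.asarray(samples))
    n = len(x)
    k = int(np.ceil(prob * n))  # number of points the interval must cover
    widths = x[k - 1 :] - x[: n - k + 1]  # width of every candidate interval
    i = int(np.argmin(widths))  # pick the narrowest one
    return x[i], x[i + k - 1]


rng = np.random.default_rng(8927)
draws = rng.normal(size=100_000)
lo, hi = hdi(draws, prob=0.95)  # for a standard normal, close to (-1.96, 1.96)
```

For a symmetric posterior the HDI coincides with the central credible interval; for skewed posteriors it is shifted toward the mode, which is why it is the preferred band for plots like this.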
@@ -148,18 +160,18 @@ Perform the regression given coefficients and data and link to the data via the
 ```{code-cell} ipython3
 with model_randomwalk:
     # Define regression
-    regression = alpha + beta * prices_zscored.GFI
+    regression = alpha + beta * prices_zscored.GFI.values

     # Assume prices are Normally distributed, the mean comes from the regression.
     sd = pm.HalfNormal("sd", sigma=0.1)
-    likelihood = pm.Normal("y", mu=regression, sigma=sd, observed=prices_zscored.GLD)
+    likelihood = pm.Normal("y", mu=regression, sigma=sd, observed=prices_zscored.GLD.to_numpy())
 ```

 Inference. Despite this being quite a complex model, NUTS handles it well.

 ```{code-cell} ipython3
 with model_randomwalk:
-    trace_rw = pm.sample(tune=2000, cores=4, target_accept=0.9, return_inferencedata=True)
+    trace_rw = pm.sample(tune=2000, target_accept=0.9)
 ```

 Increasing the tree-depth does indeed help but it makes sampling very slow. The results look identical with this run, however.
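The model sampled here lets the regression coefficients drift over time as Gaussian random walks. A small numpy sketch of that generative assumption, with made-up step sizes and noise scales (this simulates data such a model could fit; it is not the notebook's PyMC code):

```python
import numpy as np

rng = np.random.default_rng(8927)
T = 500
sigma_alpha, sigma_beta = 0.05, 0.05  # hypothetical random-walk step scales

# Coefficients evolve as Gaussian random walks: theta_t = theta_{t-1} + noise
alpha_t = np.cumsum(rng.normal(scale=sigma_alpha, size=T))
beta_t = 1.0 + np.cumsum(rng.normal(scale=sigma_beta, size=T))

# Each observation uses that time step's own coefficients
x = rng.normal(size=T)
y = alpha_t + beta_t * x + rng.normal(scale=0.1, size=T)
```

With one latent coefficient pair per time step the parameter count scales with the data length, which is why the note above about tree depth and sampling speed matters: NUTS is exploring a posterior with hundreds of correlated dimensions.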
@@ -231,13 +243,16 @@ cb = plt.colorbar(sc, ticks=ticks)
 cb.ax.set_yticklabels(ticklabels);
 ```

-Author: Thomas Wiecki
+## Authors
+
+- Created by [Thomas Wiecki](https://github.com/twiecki/)
+- Updated by [Benjamin T. Vincent](https://github.com/drbenvincent) June 2022

 +++

 ## Watermark

 ```{code-cell} ipython3
 %load_ext watermark
-%watermark -n -u -v -iv -w -p theano
+%watermark -n -u -v -iv -w -p aesara,aeppl,xarray
 ```
