
Commit aad143f

add random seed, model maths, use Gammas for sigma, simplify plot code
1 parent 05992cb commit aad143f

File tree

2 files changed (+489, -354 lines)


examples/generalized_linear_models/GLM-simpsons-paradox.ipynb

Lines changed: 389 additions & 315 deletions
Large diffs are not rendered by default.

examples/generalized_linear_models/GLM-simpsons-paradox.myst.md

Lines changed: 100 additions & 39 deletions
@@ -13,7 +13,7 @@ kernelspec:
 (GLM-simpsons-paradox)=
 # Simpson's paradox and mixed models
 
-:::{post} March, 2022
+:::{post} September, 2024
 :tags: regression, hierarchical model, linear model, posterior predictive, Simpson's paradox
 :category: beginner
 :author: Benjamin T. Vincent
@@ -109,13 +109,26 @@ First we examine the simplest model - plain linear regression which pools all th
 
 +++
 
+We could describe this model mathematically as:
+
+$$
+\begin{aligned}
+\beta_0, \beta_1 &\sim \text{Normal}(0, 5) \\
+\sigma &\sim \text{Gamma}(2, 2) \\
+\mu_i &= \beta_0 + \beta_1 x_i \\
+y_i &\sim \text{Normal}(\mu_i, \sigma)
+\end{aligned}
+$$
+
++++
+
 ### Build model
 
 ```{code-cell} ipython3
 with pm.Model() as linear_regression:
-    sigma = pm.HalfCauchy("sigma", beta=2)
     β0 = pm.Normal("β0", 0, sigma=5)
     β1 = pm.Normal("β1", 0, sigma=5)
+    sigma = pm.Gamma("sigma", 2, 2)
     x = pm.Data("x", data.x, dims="obs_id")
     μ = pm.Deterministic("μ", β0 + β1 * x, dims="obs_id")
     pm.Normal("y", mu=μ, sigma=sigma, observed=data.y, dims="obs_id")
@@ -129,7 +142,7 @@ pm.model_to_graphviz(linear_regression)
 
 ```{code-cell} ipython3
 with linear_regression:
-    idata = pm.sample()
+    idata = pm.sample(random_seed=rng)
 ```
 
 ```{code-cell} ipython3
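The `rng` object now passed as `random_seed` is defined outside this hunk; presumably it is a NumPy `Generator` created once near the top of the notebook, along the lines of this sketch (the seed value is purely illustrative):

```python
import numpy as np

# Hypothetical seed; the notebook sets its own value outside this diff
rng = np.random.default_rng(seed=42)
```

Passing the same `rng` to every `pm.sample` and `pm.sample_posterior_predictive` call makes the notebook's results reproducible from run to run.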
@@ -145,7 +158,7 @@ xi = np.linspace(data.x.min(), data.x.max(), 20)
 # do posterior predictive inference
 with linear_regression:
     pm.set_data({"x": xi})
-    idata.extend(pm.sample_posterior_predictive(idata, var_names=["y", "μ"]))
+    idata.extend(pm.sample_posterior_predictive(idata, var_names=["y", "μ"], random_seed=rng))
 ```
 
 ```{code-cell} ipython3
@@ -227,23 +240,38 @@ One of the clear things about this analysis is that we have credible evidence th
 
 ## Model 2: Independent slopes and intercepts model
 
-We will use the same data in this analysis, but this time we will use our knowledge that data come from groups. More specifically we will essentially fit independent regressions to data within each group.
+We will use the same data in this analysis, but this time we will use our knowledge that data come from groups. More specifically we will essentially fit independent regressions to data within each group. This could also be described as an unpooled model.
+
++++
+
+We could describe this model mathematically as:
+
+$$
+\begin{aligned}
+\vec{\beta_0}, \vec{\beta_1} &\sim \text{Normal}(0, 5) \\
+\sigma &\sim \text{Gamma}(2, 2) \\
+\mu_i &= \beta_0[g_i] + \beta_1[g_i] x_i \\
+y_i &\sim \text{Normal}(\mu_i, \sigma)
+\end{aligned}
+$$
+
+Where $g_i$ is the group index for observation $i$. So the parameters $\vec{\beta_0}$ and $\vec{\beta_1}$ are now vectors with one entry per group, not scalars, and $[g_i]$ acts as an index to look up the group for the $i^\text{th}$ observation.
 
 ```{code-cell} ipython3
 coords = {"group": group_list}
 
 with pm.Model(coords=coords) as ind_slope_intercept:
     # Define priors
-    sigma = pm.HalfCauchy("sigma", beta=2, dims="group")
     β0 = pm.Normal("β0", 0, sigma=5, dims="group")
     β1 = pm.Normal("β1", 0, sigma=5, dims="group")
+    sigma = pm.Gamma("sigma", 2, 2)
     # Data
     x = pm.Data("x", data.x, dims="obs_id")
     g = pm.Data("g", data.group_idx, dims="obs_id")
     # Linear model
     μ = pm.Deterministic("μ", β0[g] + β1[g] * x, dims="obs_id")
     # Define likelihood
-    pm.Normal("y", mu=μ, sigma=sigma[g], observed=data.y, dims="obs_id")
+    pm.Normal("y", mu=μ, sigma=sigma, observed=data.y, dims="obs_id")
 ```
 
 By plotting the DAG for this model it is clear to see that we now have individual intercept, slope, and variance parameters for each of the groups.
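For readers unfamiliar with the `β0[g]` / $\beta_0[g_i]$ lookup used above, it is plain integer (fancy) indexing; a tiny NumPy sketch with made-up numbers:

```python
import numpy as np

beta0 = np.array([1.0, 5.0, 9.0])  # one intercept per group (hypothetical values)
g = np.array([0, 0, 1, 2, 2])      # group index g_i for each observation
print(beta0[g])                    # [1. 1. 5. 9. 9.] -> one intercept per observation
```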
@@ -254,7 +282,7 @@ pm.model_to_graphviz(ind_slope_intercept)
 
 ```{code-cell} ipython3
 with ind_slope_intercept:
-    idata = pm.sample()
+    idata = pm.sample(random_seed=rng)
 
 az.plot_trace(idata, var_names=["~μ"]);
 ```
@@ -273,7 +301,7 @@ xi, g = np.concatenate(xi), np.concatenate(g)
 # Do the posterior prediction
 with ind_slope_intercept:
     pm.set_data({"x": xi, "g": g.astype(int)})
-    idata.extend(pm.sample_posterior_predictive(idata, var_names=["μ", "y"]))
+    idata.extend(pm.sample_posterior_predictive(idata, var_names=["μ", "y"], random_seed=rng))
 ```
 
 ```{code-cell} ipython3
@@ -369,38 +397,66 @@ We can go beyond Model 2 and incorporate even more knowledge about the structure
 
 In one sense this move from Model 2 to Model 3 can be seen as adding parameters, and therefore increasing model complexity. However, in another sense, adding this knowledge about the nested structure of the data actually provides a constraint over parameter space.
 
-Note: This model was producing divergent samples, so a reparameterisation trick is used. See the blog post [Why hierarchical models are awesome, tricky, and Bayesian](https://twiecki.io/blog/2017/02/08/bayesian-hierchical-non-centered/) by Thomas Wiecki for more information on this.
++++
+
+And we could describe this model mathematically as:
+
+$$
+\begin{aligned}
+p_{0\mu}, p_{1\mu} &\sim \text{Normal}(0, 1) \\
+p_{0\sigma}, p_{1\sigma} &\sim \text{Gamma}(2, 2) \\
+\vec{\beta_0} &\sim \text{Normal}(p_{0\mu}, p_{0\sigma}) \\
+\vec{\beta_1} &\sim \text{Normal}(p_{1\mu}, p_{1\sigma}) \\
+\sigma &\sim \text{Gamma}(2, 2) \\
+\mu_i &= \beta_0[g_i] + \beta_1[g_i] \cdot x_i \\
+y_i &\sim \text{Normal}(\mu_i, \sigma)
+\end{aligned}
+$$
+
+where $p_{0\mu}, p_{0\sigma}, p_{1\mu}, p_{1\sigma}$ are the population-level parameters, and $\vec{\beta_0}$ and $\vec{\beta_1}$ are the group-level intercepts and slopes drawn from them.
+
++++
+
+:::{admonition} **Independence assumptions**
+:class: note
+
+The hierarchical model we are considering contains a simplification in that the population level slope and intercept are assumed to be independent. It is possible to relax this assumption and model any correlation between these parameters by using a multivariate normal distribution.
+:::
+
++++
+
+This model could also be called a partial pooling model.
 
 ```{code-cell} ipython3
-non_centered = True
+non_centered = False
 
 with pm.Model(coords=coords) as hierarchical:
-    # Hyperpriors
-    intercept_mu = pm.Normal("intercept_mu", 0, sigma=1)
-    intercept_sigma = pm.HalfNormal("intercept_sigma", sigma=2)
-    slope_mu = pm.Normal("slope_mu", 0, sigma=1)
-    slope_sigma = pm.HalfNormal("slope_sigma", sigma=2)
-    sigma_hyperprior = pm.HalfNormal("sigma_hyperprior", sigma=0.5)
-
     # Define priors
-    sigma = pm.HalfNormal("sigma", sigma=sigma_hyperprior, dims="group")
-
+    intercept_mu = pm.Normal("intercept_mu", 0, 1)
+    slope_mu = pm.Normal("slope_mu", 0, 1)
+    intercept_sigma = pm.Gamma("intercept_sigma", 2, 2)
+    slope_sigma = pm.Gamma("slope_sigma", 2, 2)
+    sigma = pm.Gamma("sigma", 2, 2)
     if non_centered:
-        β0_offset = pm.Normal("β0_offset", 0, sigma=1, dims="group")
-        β0 = pm.Deterministic("β0", intercept_mu + β0_offset * intercept_sigma, dims="group")
-        β1_offset = pm.Normal("β1_offset", 0, sigma=1, dims="group")
-        β1 = pm.Deterministic("β1", slope_mu + β1_offset * slope_sigma, dims="group")
+        gamma_0 = pm.Normal("gamma_0", 0, 1, dims="group")
+        β0 = pm.Deterministic("β0", intercept_mu + gamma_0 * intercept_sigma, dims="group")
+        gamma_1_offset = pm.Normal("gamma_1_offset", 0, 1, dims="group")
+        β1 = pm.Deterministic("β1", slope_mu + gamma_1_offset * slope_sigma, dims="group")
     else:
-        β0 = pm.Normal("β0", intercept_mu, sigma=intercept_sigma, dims="group")
-        β1 = pm.Normal("β1", slope_mu, sigma=slope_sigma, dims="group")
+        β0 = pm.Normal("β0", intercept_mu, intercept_sigma, dims="group")
+        β1 = pm.Normal("β1", slope_mu, slope_sigma, dims="group")
+
+    # Sample from population level slope and intercepts for convenience
+    pm.Normal("pop_intercept", intercept_mu, intercept_sigma)
+    pm.Normal("pop_slope", slope_mu, slope_sigma)
 
     # Data
     x = pm.Data("x", data.x, dims="obs_id")
     g = pm.Data("g", data.group_idx, dims="obs_id")
     # Linear model
     μ = pm.Deterministic("μ", β0[g] + β1[g] * x, dims="obs_id")
     # Define likelihood
-    pm.Normal("y", mu=μ, sigma=sigma[g], observed=data.y, dims="obs_id")
+    pm.Normal("y", mu=μ, sigma=sigma, observed=data.y, dims="obs_id")
 ```
 
 Plotting the DAG now makes it clear that the group-level intercept and slope parameters are drawn from a population level distributions. That is, we have hyper-priors for the slopes and intercept parameters. This particular model does not have a hyper-prior for the measurement error - this is just left as one parameter per group, as in the previous model.
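Regarding the independence-assumptions admonition added above: one way to model a correlated population-level intercept and slope (not done in this commit) is an LKJ prior on their covariance. A rough sketch with hypothetical `coords`; in the notebook the group coordinate would come from `group_list`:

```python
import pymc as pm

# Hypothetical coordinates; the notebook builds "group" from group_list
coords = {"group": ["one", "two", "three"], "effect": ["intercept", "slope"]}

with pm.Model(coords=coords) as correlated_hierarchical:
    # Cholesky factor of the 2x2 covariance between intercept and slope
    chol, corr, stds = pm.LKJCholeskyCov(
        "chol", n=2, eta=2.0, sd_dist=pm.Exponential.dist(1.0), compute_corr=True
    )
    pop_effects = pm.Normal("pop_effects", 0, 1, dims="effect")
    # Each group gets a correlated (intercept, slope) pair
    betas = pm.MvNormal("betas", mu=pop_effects, chol=chol, dims=("group", "effect"))
```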
@@ -411,9 +467,17 @@ pm.model_to_graphviz(hierarchical)
 
 ```{code-cell} ipython3
 with hierarchical:
-    idata = pm.sample(tune=2000, target_accept=0.99)
+    idata = pm.sample(tune=4000, target_accept=0.99, random_seed=rng)
+```
 
-az.plot_trace(idata, var_names=["~μ"]);
+:::{admonition} **Divergences**
+:class: note
+
+Note that despite the longer tune period and increased `target_accept`, this model can still generate a small number of divergent samples. If the reader is interested, they can explore a "reparameterisation trick" by setting the flag `non_centered=True`. See the blog post [Why hierarchical models are awesome, tricky, and Bayesian](https://twiecki.io/blog/2017/02/08/bayesian-hierchical-non-centered/) by Thomas Wiecki for more information on this.
+:::
+
+```{code-cell} ipython3
+az.plot_trace(idata, var_names=["pop_intercept", "pop_slope", "β0", "β1", "sigma"]);
 ```
 
 ### Visualise
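Since the divergences admonition says a few divergent samples may remain, it can be handy to count them after sampling. A small helper sketch (the function name is illustrative, not from the notebook):

```python
import arviz as az

def count_divergences(idata: az.InferenceData) -> int:
    """Total number of divergent transitions across all chains and draws."""
    return int(idata.sample_stats["diverging"].sum())

# e.g. count_divergences(idata) right after the pm.sample call above
```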
@@ -430,7 +494,7 @@ xi, g = np.concatenate(xi), np.concatenate(g)
 # Do the posterior prediction
 with hierarchical:
     pm.set_data({"x": xi, "g": g.astype(int)})
-    idata.extend(pm.sample_posterior_predictive(idata, var_names=["μ", "y"]))
+    idata.extend(pm.sample_posterior_predictive(idata, var_names=["μ", "y"], random_seed=rng))
 ```
 
 ```{code-cell} ipython3
@@ -498,15 +562,12 @@ ax[1].set(xlabel="x", ylabel="y", title="Posterior Predictive")
 
 # parameter space ---------------------------------------------------
 # plot posterior for population level slope and intercept
-slope = rng.normal(
-    az.extract(idata, var_names="slope_mu"),
-    az.extract(idata, var_names="slope_sigma"),
-)
-intercept = rng.normal(
-    az.extract(idata, var_names="intercept_mu"),
-    az.extract(idata, var_names="intercept_sigma"),
+ax[2].scatter(
+    az.extract(idata, var_names="pop_slope"),
+    az.extract(idata, var_names="pop_intercept"),
+    color="k",
+    alpha=0.05,
 )
-ax[2].scatter(slope, intercept, color="k", alpha=0.05)
 # plot posterior for group level slope and intercept
 for i, _ in enumerate(group_list):
     ax[2].scatter(
@@ -526,7 +587,7 @@ The panel on the right shows the posterior group level posterior of the slope an
 ```{code-cell} ipython3
 :tags: [hide-input]
 
-az.plot_posterior(slope, ref_val=0)
+az.plot_posterior(idata.posterior["pop_slope"], ref_val=0)
 plt.title("Population level slope parameter");
 ```
 