
Commit 1da30fb

Bunch of improvements and fixes
1 parent c134f92 commit 1da30fb

File tree

2 files changed: +331 -220 lines


examples/generalized_linear_models/GLM-simpsons-paradox.ipynb

Lines changed: 281 additions & 200 deletions
Large diffs are not rendered by default.

examples/generalized_linear_models/GLM-simpsons-paradox.myst.md

Lines changed: 50 additions & 20 deletions
@@ -21,7 +21,7 @@ kernelspec:
 
 +++
 
-[Simpson's Paradox](https://en.wikipedia.org/wiki/Simpson%27s_paradox) describes a situation where there might be a negative relationship between two variables within a group, but when data from multiple groups are combined, that relationship may disappear or even reverse sign. The gif below (from the [Simpson's Paradox](https://en.wikipedia.org/wiki/Simpson%27s_paradox) Wikipedia page) demonstrates this very nicely.
+[Simpson's Paradox](https://en.wikipedia.org/wiki/Simpson%27s_paradox) describes a situation where there might be a negative relationship between two variables within a group, but when data from multiple groups are combined, that relationship may disappear or even reverse sign. The gif below (from the Simpson's Paradox [Wikipedia](https://en.wikipedia.org/wiki/Simpson%27s_paradox) page) demonstrates this very nicely.
 
 ![](https://upload.wikimedia.org/wikipedia/commons/f/fb/Simpsons_paradox_-_animation.gif)
 
@@ -41,6 +41,7 @@ import xarray as xr
 ```{code-cell} ipython3
 %config InlineBackend.figure_format = 'retina'
 az.style.use("arviz-darkgrid")
+plt.rcParams["figure.figsize"] = [12, 6]
 rng = np.random.default_rng(1234)
 ```
 
@@ -49,8 +50,6 @@ rng = np.random.default_rng(1234)
 This data generation was influenced by this [stackexchange](https://stats.stackexchange.com/questions/479201/understanding-simpsons-paradox-with-random-effects) question.
 
 ```{code-cell} ipython3
-:tags: [hide-input]
-
 def generate():
     group_list = ["one", "two", "three", "four", "five"]
     trials_per_group = 20
@@ -68,9 +67,8 @@ def generate():
     y = rng.normal(intercept + (x - mx) * slope, 1)
     data = pd.DataFrame({"group": group, "group_idx": subject, "x": x, "y": y})
     return data, group_list
-```
 
-```{code-cell} ipython3
+
 data, group_list = generate()
 ```
 
@@ -137,7 +135,7 @@ with pm.Model() as model1:
 pm.model_to_graphviz(model1)
 ```
 
-### Do inference
+### Conduct inference
 
 ```{code-cell} ipython3
 with model1:
@@ -150,18 +148,20 @@ az.plot_trace(idata1, var_names=["~μ"]);
 
 ### Visualisation
 
-First we'll define a handy predict function which will do out of sample predictions for us. This will be handy when it comes to visualising the model fit.
+First we'll define a handy predict function which will do out-of-sample predictions for us. This will be useful when it comes to visualising the model fits.
 
 ```{code-cell} ipython3
 def predict(model: pm.Model, idata: az.InferenceData, predict_at: dict) -> az.InferenceData:
-    """Do posterior predictive inference"""
+    """Do posterior predictive inference at a set of out of sample points specified by `predict_at`."""
     with model:
         pm.set_data(predict_at)
         idata.extend(pm.sample_posterior_predictive(idata, var_names=["y", "μ"], random_seed=rng))
     return idata
 ```
 
 ```{code-cell} ipython3
+:tags: [hide-output]
+
 idata1 = predict(
     model=model1,
     idata=idata1,
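The `predict` helper above relies on PyMC's data-container mechanism: predictors registered with `pm.Data` can be swapped via `pm.set_data` before posterior predictive sampling. The notebook's own models are set up this way; the following is only a minimal, self-contained sketch of that pattern on a toy straight-line model (none of these variable names or values come from the notebook):

```python
import numpy as np
import pymc as pm

# Toy data (illustrative only; not the notebook's generated dataset)
x_obs = np.array([0.0, 1.0, 2.0, 3.0])
y_obs = np.array([1.1, 1.9, 3.2, 3.9])

with pm.Model() as toy_model:
    # Register the predictor as a data container so it can be swapped later
    x = pm.Data("x", x_obs)
    β0 = pm.Normal("β0", 0, 5)
    β1 = pm.Normal("β1", 0, 5)
    σ = pm.Gamma("σ", 2, 2)
    # shape=x.shape lets the likelihood follow the (new) size of x
    pm.Normal("y", mu=β0 + β1 * x, sigma=σ, observed=y_obs, shape=x.shape)
    idata = pm.sample(random_seed=42)

with toy_model:
    # Swap in out-of-sample x values, then resample the posterior predictive
    pm.set_data({"x": np.linspace(-1, 5, 20)})
    idata.extend(pm.sample_posterior_predictive(idata, var_names=["y"], random_seed=42))
```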
@@ -234,9 +234,14 @@ The plot on the right shows our posterior beliefs in **parameter space**.
 
 +++
 
-One of the clear things about this analysis is that we have credible evidence that $x$ and $y$ are _positively_ correlated. We can see this from the posterior over the slope (see right hand panel in the figure above).
+One of the clear things about this analysis is that we have credible evidence that $x$ and $y$ are _positively_ correlated. We can see this from the posterior over the slope (see right hand panel in the figure above), which we isolate in the plot below.
 
-+++
+```{code-cell} ipython3
+:tags: [hide-input]
+
+ax = az.plot_posterior(idata1.posterior["β1"], ref_val=0)
+ax.set(title="Model 1 strongly suggests a positive slope", xlabel=r"$\beta_1$");
+```
 
 ## Model 2: Unpooled regression
 
@@ -260,12 +265,16 @@
 \begin{aligned}
 \vec{\beta_0}, \vec{\beta_1} &\sim \text{Normal}(0, 5) \\
 \sigma &\sim \text{Gamma}(2, 2) \\
-\mu_i &= \beta_0[g_i] + \beta_1[g_i] x_i \\
+\mu_i &= \vec{\beta_0}[g_i] + \vec{\beta_1}[g_i] x_i \\
 y_i &\sim \text{Normal}(\mu_i, \sigma)
 \end{aligned}
 $$
 
-Where $g_i$ is the group index for observation $i$. So the parameters $\beta_0$ and $\beta_1$ are now length $g$ vectors, not scalars. And the $[g_i]$ acts as an index to look up the group for the $i^{\th}$ observation.
+Where $g_i$ is the group index for observation $i$. So the parameters $\beta_0$ and $\beta_1$ are now vectors of length $g$ (the number of groups), not scalars. And the $[g_i]$ acts as an index to look up the group for the $i^\text{th}$ observation.
+
++++
+
+### Build model
 
 ```{code-cell} ipython3
 coords = {"group": group_list}
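To make the $\vec{\beta_0}[g_i]$ lookup concrete, here is a tiny NumPy-only sketch of the same advanced-indexing trick (toy numbers, nothing here comes from the notebook):

```python
import numpy as np

# Toy example: 3 groups, 6 observations (illustrative values only)
beta0 = np.array([1.0, -2.0, 0.5])    # one intercept per group
beta1 = np.array([-0.3, -0.1, -0.6])  # one slope per group
g = np.array([0, 0, 1, 1, 2, 2])      # group index g_i for each observation
x = np.array([0.1, 0.4, 1.2, 1.5, 2.3, 2.8])

# beta0[g] and beta1[g] broadcast the per-group parameters out to one value per
# observation, which is the same lookup the model writes as β0[g] and β1[g]
mu = beta0[g] + beta1[g] * x
print(mu.round(2))
```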
@@ -284,12 +293,14 @@ with pm.Model(coords=coords) as model2:
     pm.Normal("y", mu=μ, sigma=sigma, observed=data.y, dims="obs_id")
 ```
 
-By plotting the DAG for this model it is clear to see that we now have individual intercept, slope, and variance parameters for each of the groups.
+By plotting the DAG for this model we can clearly see that we now have individual intercept and slope parameters for each of the groups.
 
 ```{code-cell} ipython3
 pm.model_to_graphviz(model2)
 ```
 
+### Conduct inference
+
 ```{code-cell} ipython3
 with model2:
     idata2 = pm.sample(random_seed=rng)
@@ -303,17 +314,21 @@ az.plot_trace(idata2, var_names=["~μ"]);
 # Generate values of xi and g for posterior prediction
 n_points = 10
 n_groups = len(data.group.unique())
+# Generate xi values for each group and concatenate them
 xi = np.concatenate(
     [
         np.linspace(group[1].x.min(), group[1].x.max(), n_points)
         for group in data.groupby("group_idx")
     ]
 )
+# Generate the group indices array g and cast it to integers
 g = np.concatenate([[i] * n_points for i in range(n_groups)]).astype(int)
 predict_at = {"x": xi, "g": g}
 ```
 
 ```{code-cell} ipython3
+:tags: [hide-output]
+
 idata2 = predict(
     model=model2,
     idata=idata2,
@@ -382,7 +397,12 @@ plot(idata2);
 
 In contrast to the plain regression model (Model 1), when we model on the group level we can see that the evidence now points toward _negative_ relationships between $x$ and $y$.
 
-+++
+```{code-cell} ipython3
+ax = az.plot_forest(idata2.posterior["β1"], combined=True, figsize=(12, 4))
+ax[0].set(
+    title="Model 2 suggests a negative slope for each group", xlabel=r"$\beta_1$", ylabel="Group"
+);
+```
 
 ## Model 3: Partial pooling (hierarchical) model
 
@@ -399,7 +419,7 @@ p_{0\sigma}, p_{1\sigma} &\sim \text{Gamma}(2, 2) \\
 \vec{\beta_0} &\sim \text{Normal}(p_{0\mu}, p_{0\sigma}) \\
 \vec{\beta_1} &\sim \text{Normal}(p_{1\mu}, p_{1\sigma}) \\
 \sigma &\sim \text{Gamma}(2, 2) \\
-\mu_i &= \beta_0[g_i] + \beta_1[g_i] \cdot x_i \\
+\mu_i &= \vec{\beta_0}[g_i] + \vec{\beta_1}[g_i] \cdot x_i \\
 y_i &\sim \text{Normal}(\mu_i, \sigma)
 \end{aligned}
 $$
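As a rough, centred-parameterisation sketch of how these equations map onto PyMC (illustrative only; the notebook's actual `model3` cell below differs, e.g. it has a `non_centered` option and uses data containers for prediction; `data` and `group_list` are assumed from the earlier `generate()` step, and names like `sd_intercept` are made up here):

```python
import pymc as pm

# Assumes `data` and `group_list` from the notebook's generate() step exist.
coords = {"group": group_list}

with pm.Model(coords=coords) as sketch_model3:
    # Population-level (hyper) priors
    pop_intercept = pm.Normal("pop_intercept", mu=0, sigma=5)
    pop_slope = pm.Normal("pop_slope", mu=0, sigma=5)
    sd_intercept = pm.Gamma("sd_intercept", alpha=2, beta=2)  # hypothetical name
    sd_slope = pm.Gamma("sd_slope", alpha=2, beta=2)          # hypothetical name

    # Group-level intercepts and slopes drawn from the population distributions
    β0 = pm.Normal("β0", mu=pop_intercept, sigma=sd_intercept, dims="group")
    β1 = pm.Normal("β1", mu=pop_slope, sigma=sd_slope, dims="group")

    # Likelihood, using the β[g_i] group lookup from the equations above
    sigma = pm.Gamma("sigma", alpha=2, beta=2)
    g = data.group_idx.values
    μ = β0[g] + β1[g] * data.x.values
    pm.Normal("y", mu=μ, sigma=sigma, observed=data.y.values)
```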
@@ -418,6 +438,10 @@ The hierarchical model we are considering contains a simplification in that the
 In one sense this move from Model 2 to Model 3 can be seen as adding parameters, and therefore increasing model complexity. However, in another sense, adding this knowledge about the nested structure of the data actually provides a constraint over parameter space.
 :::
 
++++
+
+### Build model
+
 ```{code-cell} ipython3
 non_centered = False
 
@@ -458,6 +482,10 @@ pm.model_to_graphviz(model3)
 
 The nodes `pop_intercept` and `pop_slope` represent the population-level intercept and slope parameters. While the 5 $\beta_0$ and $\beta_1$ nodes represent intercepts and slopes for each of the 5 observed groups (respectively), the `pop_intercept` and `pop_slope` represent what we can infer about the population-level intercept and slope. Equivalently, we could say they represent our beliefs about an as yet unobserved group.
 
++++
+
+### Conduct inference
+
 ```{code-cell} ipython3
 with model3:
     idata3 = pm.sample(tune=4000, target_accept=0.99, random_seed=rng)
@@ -474,15 +502,19 @@ az.plot_trace(idata3, var_names=["pop_intercept", "pop_slope", "β0", "β1", "si
 ### Visualise
 
 ```{code-cell} ipython3
+:tags: [hide-output]
+
 # Generate values of xi and g for posterior prediction
 n_points = 10
 n_groups = len(data.group.unique())
+# Generate xi values for each group and concatenate them
 xi = np.concatenate(
     [
         np.linspace(group[1].x.min(), group[1].x.max(), n_points)
         for group in data.groupby("group_idx")
     ]
 )
+# Generate the group indices array g and cast it to integers
 g = np.concatenate([[i] * n_points for i in range(n_groups)]).astype(int)
 predict_at = {"x": xi, "g": g}
 ```
@@ -504,13 +536,11 @@ sns.kdeplot(
     y=az.extract(idata3, var_names="pop_intercept"),
     thresh=0.1,
     levels=5,
+    color="k",
     ax=ax[2],
 )
 
-ax[2].set(
-    xlim=[-2, 1],
-    ylim=[-5, 5],
-)
+ax[2].set(xlim=[-2, 1], ylim=[-5, 5]);
 ```
 
 The panel on the right shows the group-level posterior of the slope and intercept parameters as a contour plot. We can also just plot the marginal distribution below to see how much belief we have in the slope being less than zero.
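The cell that draws that marginal is not included in this diff; a minimal sketch of one way to do it (assuming `idata3` and the `pop_slope` variable from `model3`, with ArviZ imported as `az` as in the notebook) might look like:

```python
import arviz as az  # already imported in the notebook

# Hypothetical sketch: marginal posterior of the population-level slope,
# with a reference line at zero (assumes idata3 from model3 exists)
ax = az.plot_posterior(idata3.posterior["pop_slope"], ref_val=0)
ax.set(title="Population-level slope", xlabel=r"$\beta_1$");
```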
@@ -540,7 +570,7 @@ If you are interested in learning more, there are a number of other [PyMC exampl
 * Updated by [Benjamin T. Vincent](https://github.com/drbenvincent) in April 2022
 * Updated by [Benjamin T. Vincent](https://github.com/drbenvincent) in February 2023 to run on PyMC v5
 * Updated to use `az.extract` by [Benjamin T. Vincent](https://github.com/drbenvincent) in February 2023 ([pymc-examples#522](https://github.com/pymc-devs/pymc-examples/pull/522))
-* Updated by [Benjamin T. Vincent](https://github.com/drbenvincent) in September 2024
+* Updated by [Benjamin T. Vincent](https://github.com/drbenvincent) in September 2024 ([pymc-examples#697](https://github.com/pymc-devs/pymc-examples/pull/697))
 
 +++
