Skip to content

Commit 8944da3

Browse files
committed
final tweaks before review
1 parent b944a08 commit 8944da3

File tree

2 files changed

+259
-277
lines changed

2 files changed

+259
-277
lines changed

examples/generalized_linear_models/GLM-simpsons-paradox.ipynb

Lines changed: 238 additions & 246 deletions
Large diffs are not rendered by default.

examples/generalized_linear_models/GLM-simpsons-paradox.myst.md

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ def predict(model: pm.Model, idata: az.InferenceData, predict_at: dict) -> az.In
168168
return idata
169169
```
170170

171+
And now let's use that `predict` function to do out of sample predictions which we will use for visualisation.
172+
171173
```{code-cell} ipython3
172174
:tags: [hide-output]
173175
@@ -178,6 +180,8 @@ idata1 = predict(
178180
)
179181
```
180182

183+
Finally, we can now visualise the model fit to data, and our posterior in parameter space.
184+
181185
```{code-cell} ipython3
182186
:tags: [hide-input]
183187
@@ -413,16 +417,16 @@ def plot(idata):
413417
plot(idata2);
414418
```
415419

416-
In contrast to plain regression model (Model 1), when we model on the group level we can see that now the evidence points toward _negative_ relationships between $x$ and $y$.
420+
In contrast to Model 1, when we consider groups we can see that now the evidence points toward _negative_ relationships between $x$ and $y$.
417421

418422
```{code-cell} ipython3
419423
ax = az.plot_forest(idata2.posterior["β1"], combined=True, figsize=(12, 4))
420424
ax[0].set(
421-
title="Model 2 suggests a negative slopes for each group", xlabel=r"$\beta_1$", ylabel="Group"
425+
title="Model 2 suggests negative slopes for each group", xlabel=r"$\beta_1$", ylabel="Group"
422426
);
423427
```
424428

425-
## Model 3: Partial pooling (hierarchical) model
429+
## Model 3: Partial pooling model
426430

427431
Model 3 assumes the same causal DAG as model 2 (see above). However, we can go further and incorporate more knowledge about the structure of our data. Rather than treating each group as entirely independent, we can use our knowledge that these groups are drawn from a population-level distribution.
428432

@@ -442,11 +446,7 @@ y_i &\sim \text{Normal}(\mu_i, \sigma)
442446
\end{aligned}
443447
$$
444448

445-
where $\beta_0$ and $\beta_1$ are the population-level parameters, and $\gamma_0$ and $\gamma_1$ are the group offset parameters.
446-
447-
+++
448-
449-
This model could also be called a partial pooling model.
449+
where $\vec{\beta_0}$ and $\vec{\beta_1}$ are the group-level parameters. These group level parameters can be though of as being sampled from population level intercept distribution $\text{Normal}(p_{0\mu}, p_{0\sigma})$ and population level slope distribution $\text{Normal}(p_{1\mu}, p_{1\sigma})$.
450450

451451
+++
452452

@@ -508,32 +508,13 @@ with pm.Model(coords=coords) as model3:
508508
pm.Normal("y", mu=μ, sigma=sigma, observed=data.y, dims="obs_id")
509509
```
510510

511-
```{code-cell} ipython3
512-
with pm.Model(coords=coords) as model3:
513-
# Define priors
514-
intercept_mu = pm.Normal("intercept_mu", 0, 1)
515-
slope_mu = pm.Normal("slope_mu", 0, 1)
516-
intercept_sigma = pm.Gamma("intercept_sigma", 2, 2)
517-
slope_sigma = pm.Gamma("slope_sigma", 2, 2)
518-
sigma = pm.Gamma("sigma", 2, 2)
519-
β0 = pm.Normal("β0", intercept_mu, intercept_sigma, dims="group")
520-
β1 = pm.Normal("β1", slope_mu, slope_sigma, dims="group")
521-
# Data
522-
x = pm.Data("x", data.x, dims="obs_id")
523-
g = pm.Data("g", data.group_idx, dims="obs_id")
524-
# Linear model
525-
μ = pm.Deterministic("μ", β0[g] + β1[g] * x, dims="obs_id")
526-
# Define likelihood
527-
pm.Normal("y", mu=μ, sigma=sigma, observed=data.y, dims="obs_id")
528-
```
529-
530-
Plotting the DAG now makes it clear that the group-level intercept and slope parameters are drawn from a population level distributions. That is, we have hyper-priors for the slopes and intercept parameters. This particular model does not have a hyper-prior for the measurement error - this is just left as one parameter per group, as in the previous model.
511+
Plotting the DAG now makes it clear that the group-level intercept and slope parameters are drawn from population level distributions. That is, we have hyper-priors for the slopes and intercept parameters. This particular model does not have a hyper-prior for the measurement error - this is just left as one parameter per group, as in the previous model.
531512

532513
```{code-cell} ipython3
533514
pm.model_to_graphviz(model3)
534515
```
535516

536-
The nodes `pop_intercept` and `pop_slope` represent the population-level intercept and slope parameters. While the 5 $\beta_0$ and $\beta_1$ nodes represent intercepts and slopes for each of the 5 observed groups (respectively), the `pop_intercept` and `pop_slope` represent what we can infer about the population-level intercept and slope. Equivalently, we could say they represent our beliefs about an as yet unobserved group.
517+
The nodes `pop_intercept` and `pop_slope` represent the population-level intercept and slope parameters. While the $\beta_0$ and $\beta_1$ nodes represent intercepts and slopes for each of the 5 observed groups (respectively), the `pop_intercept` and `pop_slope` represent what we can infer about the population-level intercept and slope. Equivalently, we could say they represent our beliefs about an as yet unobserved group.
537518

538519
+++
539520

@@ -601,8 +582,17 @@ The panel on the right shows the posterior group level posterior of the slope an
601582
```{code-cell} ipython3
602583
:tags: [hide-input]
603584
604-
az.plot_posterior(idata3.posterior["pop_slope"], ref_val=0)
605-
plt.title("Population level slope parameter");
585+
fig, ax = plt.subplots(1, 2)
586+
587+
az.plot_forest(idata2.posterior["β1"], combined=True, ax=ax[0])
588+
ax[0].set(
589+
title="Model 3 suggests negative slopes for each group", xlabel=r"$\beta_1$", ylabel="Group"
590+
)
591+
592+
az.plot_posterior(idata3.posterior["pop_slope"], ref_val=0, ax=ax[1])
593+
ax[1].set(
594+
title="Population level slope parameter", xlabel=r"$\text{Normal}(p_{1\mu}, p_{1\sigma})$"
595+
);
606596
```
607597

608598
## Summary

0 commit comments

Comments
 (0)