Skip to content

Commit ed4782b

Browse files
committed
significant simplification of plot code
1 parent e87c605 commit ed4782b

File tree

2 files changed

+224
-496
lines changed

2 files changed

+224
-496
lines changed

examples/generalized_linear_models/GLM-simpsons-paradox.ipynb

Lines changed: 170 additions & 301 deletions
Large diffs are not rendered by default.

examples/generalized_linear_models/GLM-simpsons-paradox.myst.md

Lines changed: 54 additions & 195 deletions
Original file line numberDiff line numberDiff line change
@@ -172,62 +172,40 @@ idata1 = predict(
172172
```{code-cell} ipython3
173173
:tags: [hide-input]
174174
175-
def plot(idata):
176-
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
177-
178-
# conditional mean plot ---------------------------------------------
179-
# data
180-
ax[0].scatter(data.x, data.y, color="k")
181-
# conditional mean credible intervals
182-
post = az.extract(idata)
183-
xi = xr.DataArray(np.linspace(np.min(data.x), np.max(data.x), 20), dims=["x_plot"])
184-
y = post.β0 + post.β1 * xi
185-
region = y.quantile([0.025, 0.15, 0.5, 0.85, 0.975], dim="sample")
186-
ax[0].fill_between(
187-
xi,
188-
region.sel(quantile=0.025),
189-
region.sel(quantile=0.975),
190-
alpha=0.2,
191-
color="k",
192-
edgecolor="w",
193-
)
194-
ax[0].fill_between(
195-
xi,
196-
region.sel(quantile=0.15),
197-
region.sel(quantile=0.85),
198-
alpha=0.2,
199-
color="k",
200-
edgecolor="w",
201-
)
202-
# conditional mean
203-
ax[0].plot(xi, region.sel(quantile=0.5), "k", linewidth=2)
204-
# formatting
205-
ax[0].set(xlabel="x", ylabel="y", title="Conditional mean")
206-
207-
# posterior prediction ----------------------------------------------
208-
# data
209-
ax[1].scatter(data.x, data.y, color="k")
210-
# posterior mean and HDI's
211-
212-
ax[1].plot(xi, idata.posterior_predictive.y.mean(["chain", "draw"]), "k")
175+
def plot_band(xi, var: xr.DataArray, ax, color: str):
176+
ax.plot(xi, var.mean(["chain", "draw"]), color=color)
213177
214178
az.plot_hdi(
215179
xi,
216-
idata.posterior_predictive.y,
180+
var,
217181
hdi_prob=0.6,
218-
color="k",
182+
color=color,
219183
fill_kwargs={"alpha": 0.2, "linewidth": 0},
220-
ax=ax[1],
184+
ax=ax,
221185
)
222186
az.plot_hdi(
223187
xi,
224-
idata.posterior_predictive.y,
188+
var,
225189
hdi_prob=0.95,
226-
color="k",
190+
color=color,
227191
fill_kwargs={"alpha": 0.2, "linewidth": 0},
228-
ax=ax[1],
192+
ax=ax,
229193
)
230-
# formatting
194+
195+
196+
def plot(idata: az.InferenceData):
197+
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
198+
199+
xi = xr.DataArray(np.linspace(np.min(data.x), np.max(data.x), 20), dims=["x_plot"])
200+
201+
# conditional mean plot ---------------------------------------------
202+
ax[0].scatter(data.x, data.y, color="k")
203+
plot_band(xi, idata.posterior_predictive.μ, ax=ax[0], color="k")
204+
ax[0].set(xlabel="x", ylabel="y", title="Conditional mean")
205+
206+
# posterior prediction ----------------------------------------------
207+
ax[1].scatter(data.x, data.y, color="k")
208+
plot_band(xi, idata.posterior_predictive.y, ax=ax[1], color="k")
231209
ax[1].set(xlabel="x", ylabel="y", title="Posterior predictive distribution")
232210
233211
# parameter space ---------------------------------------------------
@@ -346,78 +324,40 @@ idata2 = predict(
346324
```{code-cell} ipython3
347325
:tags: [hide-input]
348326
349-
def get_ppy_for_group(idata, group_list, group):
350-
"""Get posterior predictive outcomes for observations from a given group"""
351-
return idata.posterior_predictive.y.data[:, :, group_list == group]
352-
353-
354327
def plot(idata):
355328
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
356329
357-
# conditional mean plot ---------------------------------------------
358-
for i, groupname in enumerate(group_list):
359-
# data
360-
ax[0].scatter(data.x[data.group_idx == i], data.y[data.group_idx == i], color=f"C{i}")
361-
# conditional mean credible intervals
362-
post = az.extract(idata)
330+
for i in range(len(group_list)):
331+
363332
_xi = xr.DataArray(
364333
np.linspace(
365334
np.min(data.x[data.group_idx == i]),
366335
np.max(data.x[data.group_idx == i]),
367-
20,
336+
10,
368337
),
369338
dims=["x_plot"],
370339
)
371-
y = post.β0.sel(group=groupname) + post.β1.sel(group=groupname) * _xi
372-
region = y.quantile([0.025, 0.15, 0.5, 0.85, 0.975], dim="sample")
373-
ax[0].fill_between(
374-
_xi,
375-
region.sel(quantile=0.025),
376-
region.sel(quantile=0.975),
377-
alpha=0.2,
378-
color=f"C{i}",
379-
edgecolor="w",
380-
)
381-
ax[0].fill_between(
340+
341+
# conditional mean plot ---------------------------------------------
342+
ax[0].scatter(data.x[data.group_idx == i], data.y[data.group_idx == i], color=f"C{i}")
343+
plot_band(
382344
_xi,
383-
region.sel(quantile=0.15),
384-
region.sel(quantile=0.85),
385-
alpha=0.2,
345+
idata.posterior_predictive.μ.isel(obs_id=(g == i)),
346+
ax=ax[0],
386347
color=f"C{i}",
387-
edgecolor="w",
388348
)
389-
# conditional mean
390-
ax[0].plot(_xi, region.sel(quantile=0.5), color=f"C{i}", linewidth=2)
391-
# formatting
392-
ax[0].set(xlabel="x", ylabel="y", title="Conditional mean")
393349
394-
# posterior prediction ----------------------------------------------
395-
for i, groupname in enumerate(group_list):
396-
# data
350+
# posterior prediction ----------------------------------------------
397351
ax[1].scatter(data.x[data.group_idx == i], data.y[data.group_idx == i], color=f"C{i}")
398-
# posterior mean and HDI's
399-
ax[1].plot(
400-
xi[g == i],
401-
np.mean(get_ppy_for_group(idata, g, i), axis=(0, 1)),
402-
label=groupname,
403-
)
404-
az.plot_hdi(
405-
xi[g == i],
406-
get_ppy_for_group(idata, g, i), # pp_y[:, :, g == i],
407-
hdi_prob=0.6,
408-
color=f"C{i}",
409-
fill_kwargs={"alpha": 0.4, "linewidth": 0},
352+
plot_band(
353+
_xi,
354+
idata.posterior_predictive.y.isel(obs_id=(g == i)),
410355
ax=ax[1],
411-
)
412-
az.plot_hdi(
413-
xi[g == i],
414-
get_ppy_for_group(idata, g, i),
415-
hdi_prob=0.95,
416356
color=f"C{i}",
417-
fill_kwargs={"alpha": 0.2, "linewidth": 0},
418-
ax=ax[1],
419357
)
420358
359+
# formatting
360+
ax[0].set(xlabel="x", ylabel="y", title="Conditional mean")
421361
ax[1].set(xlabel="x", ylabel="y", title="Posterior predictive distribution")
422362
423363
# parameter space ---------------------------------------------------
@@ -428,14 +368,16 @@ def plot(idata):
428368
color=f"C{i}",
429369
alpha=0.01,
430370
rasterized=True,
371+
zorder=2,
431372
)
432373
433374
ax[2].set(xlabel="slope", ylabel="intercept", title="Parameter space")
434375
ax[2].axhline(y=0, c="k")
435376
ax[2].axvline(x=0, c="k")
377+
return ax
436378
437379
438-
plot(idata2)
380+
plot(idata2);
439381
```
440382

441383
In contrast to the plain regression model (Model 1), when we model on the group level we can see that now the evidence points toward _negative_ relationships between $x$ and $y$.
@@ -554,104 +496,21 @@ idata3 = predict(
554496
```{code-cell} ipython3
555497
:tags: [hide-input]
556498
557-
def plot(idata):
558-
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
559-
560-
# conditional mean plot ---------------------------------------------
561-
for i, groupname in enumerate(group_list):
562-
# data
563-
ax[0].scatter(data.x[data.group_idx == i], data.y[data.group_idx == i], color=f"C{i}")
564-
# conditional mean credible intervals
565-
post = az.extract(idata)
566-
_xi = xr.DataArray(
567-
np.linspace(
568-
np.min(data.x[data.group_idx == i]),
569-
np.max(data.x[data.group_idx == i]),
570-
20,
571-
),
572-
dims=["x_plot"],
573-
)
574-
y = post.β0.sel(group=groupname) + post.β1.sel(group=groupname) * _xi
575-
region = y.quantile([0.025, 0.15, 0.5, 0.85, 0.975], dim="sample")
576-
ax[0].fill_between(
577-
_xi,
578-
region.sel(quantile=0.025),
579-
region.sel(quantile=0.975),
580-
alpha=0.2,
581-
color=f"C{i}",
582-
edgecolor="w",
583-
)
584-
ax[0].fill_between(
585-
_xi,
586-
region.sel(quantile=0.15),
587-
region.sel(quantile=0.85),
588-
alpha=0.2,
589-
color=f"C{i}",
590-
edgecolor="w",
591-
)
592-
# conditional mean
593-
ax[0].plot(_xi, region.sel(quantile=0.5), color=f"C{i}", linewidth=2)
594-
# formatting
595-
ax[0].set(xlabel="x", ylabel="y", title="Conditional mean")
596-
597-
# posterior prediction ----------------------------------------------
598-
for i, groupname in enumerate(group_list):
599-
# data
600-
ax[1].scatter(data.x[data.group_idx == i], data.y[data.group_idx == i], color=f"C{i}")
601-
# posterior mean and HDI's
602-
ax[1].plot(
603-
xi[g == i],
604-
np.mean(get_ppy_for_group(idata, g, i), axis=(0, 1)),
605-
label=groupname,
606-
)
607-
az.plot_hdi(
608-
xi[g == i],
609-
get_ppy_for_group(idata, g, i),
610-
hdi_prob=0.6,
611-
color=f"C{i}",
612-
fill_kwargs={"alpha": 0.4, "linewidth": 0},
613-
ax=ax[1],
614-
)
615-
az.plot_hdi(
616-
xi[g == i],
617-
get_ppy_for_group(idata, g, i),
618-
hdi_prob=0.95,
619-
color=f"C{i}",
620-
fill_kwargs={"alpha": 0.2, "linewidth": 0},
621-
ax=ax[1],
622-
)
623-
624-
ax[1].set(xlabel="x", ylabel="y", title="Posterior Predictive")
625-
626-
# parameter space ---------------------------------------------------
627-
# plot posterior for population level slope and intercept
628-
ax[2].scatter(
629-
az.extract(idata, var_names="pop_slope"),
630-
az.extract(idata, var_names="pop_intercept"),
631-
color="k",
632-
alpha=0.05,
633-
)
634-
# plot posterior for group level slope and intercept
635-
for i, _ in enumerate(group_list):
636-
ax[2].scatter(
637-
az.extract(idata, var_names="β1")[i, :],
638-
az.extract(idata, var_names="β0")[i, :],
639-
color=f"C{i}",
640-
alpha=0.01,
641-
)
642-
643-
ax[2].set(
644-
xlabel="slope",
645-
ylabel="intercept",
646-
title="Parameter space",
647-
xlim=[-2, 1],
648-
ylim=[-5, 5],
649-
)
650-
ax[2].axhline(y=0, c="k")
651-
ax[2].axvline(x=0, c="k")
499+
ax = plot(idata3)
652500
501+
# add a KDE contour plot of the population level parameters
502+
sns.kdeplot(
503+
x=az.extract(idata3, var_names="pop_slope"),
504+
y=az.extract(idata3, var_names="pop_intercept"),
505+
thresh=0.1,
506+
levels=5,
507+
ax=ax[2],
508+
)
653509
654-
plot(idata3)
510+
ax[2].set(
511+
xlim=[-2, 1],
512+
ylim=[-5, 5],
513+
)
655514
```
656515

657516
The panel on the right shows the group level posterior of the slope and intercept parameters in black. This particular visualisation is a little unclear, however, so we can just plot the marginal distribution below to see how much belief we have in the slope being less than zero.

0 commit comments

Comments
 (0)