@@ -172,62 +172,40 @@ idata1 = predict(
172
172
``` {code-cell} ipython3
173
173
:tags: [hide-input]
174
174
175
- def plot(idata):
176
- fig, ax = plt.subplots(1, 3, figsize=(12, 4))
177
-
178
- # conditional mean plot ---------------------------------------------
179
- # data
180
- ax[0].scatter(data.x, data.y, color="k")
181
- # conditional mean credible intervals
182
- post = az.extract(idata)
183
- xi = xr.DataArray(np.linspace(np.min(data.x), np.max(data.x), 20), dims=["x_plot"])
184
- y = post.β0 + post.β1 * xi
185
- region = y.quantile([0.025, 0.15, 0.5, 0.85, 0.975], dim="sample")
186
- ax[0].fill_between(
187
- xi,
188
- region.sel(quantile=0.025),
189
- region.sel(quantile=0.975),
190
- alpha=0.2,
191
- color="k",
192
- edgecolor="w",
193
- )
194
- ax[0].fill_between(
195
- xi,
196
- region.sel(quantile=0.15),
197
- region.sel(quantile=0.85),
198
- alpha=0.2,
199
- color="k",
200
- edgecolor="w",
201
- )
202
- # conditional mean
203
- ax[0].plot(xi, region.sel(quantile=0.5), "k", linewidth=2)
204
- # formatting
205
- ax[0].set(xlabel="x", ylabel="y", title="Conditional mean")
206
-
207
- # posterior prediction ----------------------------------------------
208
- # data
209
- ax[1].scatter(data.x, data.y, color="k")
210
- # posterior mean and HDI's
211
-
212
- ax[1].plot(xi, idata.posterior_predictive.y.mean(["chain", "draw"]), "k")
175
+ def plot_band(xi, var: xr.DataArray, ax, color: str):
176
+ ax.plot(xi, var.mean(["chain", "draw"]), color=color)
213
177
214
178
az.plot_hdi(
215
179
xi,
216
- idata.posterior_predictive.y ,
180
+ var ,
217
181
hdi_prob=0.6,
218
- color="k" ,
182
+ color=color ,
219
183
fill_kwargs={"alpha": 0.2, "linewidth": 0},
220
- ax=ax[1] ,
184
+ ax=ax,
221
185
)
222
186
az.plot_hdi(
223
187
xi,
224
- idata.posterior_predictive.y ,
188
+ var ,
225
189
hdi_prob=0.95,
226
- color="k" ,
190
+ color=color ,
227
191
fill_kwargs={"alpha": 0.2, "linewidth": 0},
228
- ax=ax[1] ,
192
+ ax=ax,
229
193
)
230
- # formatting
194
+
195
+
196
+ def plot(idata: az.InferenceData):
197
+ fig, ax = plt.subplots(1, 3, figsize=(12, 4))
198
+
199
+ xi = xr.DataArray(np.linspace(np.min(data.x), np.max(data.x), 20), dims=["x_plot"])
200
+
201
+ # conditional mean plot ---------------------------------------------
202
+ ax[0].scatter(data.x, data.y, color="k")
203
+ plot_band(xi, idata.posterior_predictive.μ, ax=ax[0], color="k")
204
+ ax[0].set(xlabel="x", ylabel="y", title="Conditional mean")
205
+
206
+ # posterior prediction ----------------------------------------------
207
+ ax[1].scatter(data.x, data.y, color="k")
208
+ plot_band(xi, idata.posterior_predictive.y, ax=ax[1], color="k")
231
209
ax[1].set(xlabel="x", ylabel="y", title="Posterior predictive distribution")
232
210
233
211
# parameter space ---------------------------------------------------
@@ -346,78 +324,40 @@ idata2 = predict(
346
324
``` {code-cell} ipython3
347
325
:tags: [hide-input]
348
326
349
- def get_ppy_for_group(idata, group_list, group):
350
- """Get posterior predictive outcomes for observations from a given group"""
351
- return idata.posterior_predictive.y.data[:, :, group_list == group]
352
-
353
-
354
327
def plot(idata):
355
328
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
356
329
357
- # conditional mean plot ---------------------------------------------
358
- for i, groupname in enumerate(group_list):
359
- # data
360
- ax[0].scatter(data.x[data.group_idx == i], data.y[data.group_idx == i], color=f"C{i}")
361
- # conditional mean credible intervals
362
- post = az.extract(idata)
330
+ for i in range(len(group_list)):
331
+
363
332
_xi = xr.DataArray(
364
333
np.linspace(
365
334
np.min(data.x[data.group_idx == i]),
366
335
np.max(data.x[data.group_idx == i]),
367
- 20 ,
336
+ 10 ,
368
337
),
369
338
dims=["x_plot"],
370
339
)
371
- y = post.β0.sel(group=groupname) + post.β1.sel(group=groupname) * _xi
372
- region = y.quantile([0.025, 0.15, 0.5, 0.85, 0.975], dim="sample")
373
- ax[0].fill_between(
374
- _xi,
375
- region.sel(quantile=0.025),
376
- region.sel(quantile=0.975),
377
- alpha=0.2,
378
- color=f"C{i}",
379
- edgecolor="w",
380
- )
381
- ax[0].fill_between(
340
+
341
+ # conditional mean plot ---------------------------------------------
342
+ ax[0].scatter(data.x[data.group_idx == i], data.y[data.group_idx == i], color=f"C{i}")
343
+ plot_band(
382
344
_xi,
383
- region.sel(quantile=0.15),
384
- region.sel(quantile=0.85),
385
- alpha=0.2,
345
+ idata.posterior_predictive.μ.isel(obs_id=(g == i)),
346
+ ax=ax[0],
386
347
color=f"C{i}",
387
- edgecolor="w",
388
348
)
389
- # conditional mean
390
- ax[0].plot(_xi, region.sel(quantile=0.5), color=f"C{i}", linewidth=2)
391
- # formatting
392
- ax[0].set(xlabel="x", ylabel="y", title="Conditional mean")
393
349
394
- # posterior prediction ----------------------------------------------
395
- for i, groupname in enumerate(group_list):
396
- # data
350
+ # posterior prediction ----------------------------------------------
397
351
ax[1].scatter(data.x[data.group_idx == i], data.y[data.group_idx == i], color=f"C{i}")
398
- # posterior mean and HDI's
399
- ax[1].plot(
400
- xi[g == i],
401
- np.mean(get_ppy_for_group(idata, g, i), axis=(0, 1)),
402
- label=groupname,
403
- )
404
- az.plot_hdi(
405
- xi[g == i],
406
- get_ppy_for_group(idata, g, i), # pp_y[:, :, g == i],
407
- hdi_prob=0.6,
408
- color=f"C{i}",
409
- fill_kwargs={"alpha": 0.4, "linewidth": 0},
352
+ plot_band(
353
+ _xi,
354
+ idata.posterior_predictive.y.isel(obs_id=(g == i)),
410
355
ax=ax[1],
411
- )
412
- az.plot_hdi(
413
- xi[g == i],
414
- get_ppy_for_group(idata, g, i),
415
- hdi_prob=0.95,
416
356
color=f"C{i}",
417
- fill_kwargs={"alpha": 0.2, "linewidth": 0},
418
- ax=ax[1],
419
357
)
420
358
359
+ # formatting
360
+ ax[0].set(xlabel="x", ylabel="y", title="Conditional mean")
421
361
ax[1].set(xlabel="x", ylabel="y", title="Posterior predictive distribution")
422
362
423
363
# parameter space ---------------------------------------------------
@@ -428,14 +368,16 @@ def plot(idata):
428
368
color=f"C{i}",
429
369
alpha=0.01,
430
370
rasterized=True,
371
+ zorder=2,
431
372
)
432
373
433
374
ax[2].set(xlabel="slope", ylabel="intercept", title="Parameter space")
434
375
ax[2].axhline(y=0, c="k")
435
376
ax[2].axvline(x=0, c="k")
377
+ return ax
436
378
437
379
438
- plot(idata2)
380
+ plot(idata2);
439
381
```
440
382
441
383
In contrast to the plain regression model (Model 1), when we model at the group level we can see that the evidence now points toward _negative_ relationships between $x$ and $y$.
@@ -554,104 +496,21 @@ idata3 = predict(
554
496
``` {code-cell} ipython3
555
497
:tags: [hide-input]
556
498
557
- def plot(idata):
558
- fig, ax = plt.subplots(1, 3, figsize=(12, 4))
559
-
560
- # conditional mean plot ---------------------------------------------
561
- for i, groupname in enumerate(group_list):
562
- # data
563
- ax[0].scatter(data.x[data.group_idx == i], data.y[data.group_idx == i], color=f"C{i}")
564
- # conditional mean credible intervals
565
- post = az.extract(idata)
566
- _xi = xr.DataArray(
567
- np.linspace(
568
- np.min(data.x[data.group_idx == i]),
569
- np.max(data.x[data.group_idx == i]),
570
- 20,
571
- ),
572
- dims=["x_plot"],
573
- )
574
- y = post.β0.sel(group=groupname) + post.β1.sel(group=groupname) * _xi
575
- region = y.quantile([0.025, 0.15, 0.5, 0.85, 0.975], dim="sample")
576
- ax[0].fill_between(
577
- _xi,
578
- region.sel(quantile=0.025),
579
- region.sel(quantile=0.975),
580
- alpha=0.2,
581
- color=f"C{i}",
582
- edgecolor="w",
583
- )
584
- ax[0].fill_between(
585
- _xi,
586
- region.sel(quantile=0.15),
587
- region.sel(quantile=0.85),
588
- alpha=0.2,
589
- color=f"C{i}",
590
- edgecolor="w",
591
- )
592
- # conditional mean
593
- ax[0].plot(_xi, region.sel(quantile=0.5), color=f"C{i}", linewidth=2)
594
- # formatting
595
- ax[0].set(xlabel="x", ylabel="y", title="Conditional mean")
596
-
597
- # posterior prediction ----------------------------------------------
598
- for i, groupname in enumerate(group_list):
599
- # data
600
- ax[1].scatter(data.x[data.group_idx == i], data.y[data.group_idx == i], color=f"C{i}")
601
- # posterior mean and HDI's
602
- ax[1].plot(
603
- xi[g == i],
604
- np.mean(get_ppy_for_group(idata, g, i), axis=(0, 1)),
605
- label=groupname,
606
- )
607
- az.plot_hdi(
608
- xi[g == i],
609
- get_ppy_for_group(idata, g, i),
610
- hdi_prob=0.6,
611
- color=f"C{i}",
612
- fill_kwargs={"alpha": 0.4, "linewidth": 0},
613
- ax=ax[1],
614
- )
615
- az.plot_hdi(
616
- xi[g == i],
617
- get_ppy_for_group(idata, g, i),
618
- hdi_prob=0.95,
619
- color=f"C{i}",
620
- fill_kwargs={"alpha": 0.2, "linewidth": 0},
621
- ax=ax[1],
622
- )
623
-
624
- ax[1].set(xlabel="x", ylabel="y", title="Posterior Predictive")
625
-
626
- # parameter space ---------------------------------------------------
627
- # plot posterior for population level slope and intercept
628
- ax[2].scatter(
629
- az.extract(idata, var_names="pop_slope"),
630
- az.extract(idata, var_names="pop_intercept"),
631
- color="k",
632
- alpha=0.05,
633
- )
634
- # plot posterior for group level slope and intercept
635
- for i, _ in enumerate(group_list):
636
- ax[2].scatter(
637
- az.extract(idata, var_names="β1")[i, :],
638
- az.extract(idata, var_names="β0")[i, :],
639
- color=f"C{i}",
640
- alpha=0.01,
641
- )
642
-
643
- ax[2].set(
644
- xlabel="slope",
645
- ylabel="intercept",
646
- title="Parameter space",
647
- xlim=[-2, 1],
648
- ylim=[-5, 5],
649
- )
650
- ax[2].axhline(y=0, c="k")
651
- ax[2].axvline(x=0, c="k")
499
+ ax = plot(idata3)
652
500
501
+ # add a KDE contour plot of the population level parameters
502
+ sns.kdeplot(
503
+ x=az.extract(idata3, var_names="pop_slope"),
504
+ y=az.extract(idata3, var_names="pop_intercept"),
505
+ thresh=0.1,
506
+ levels=5,
507
+ ax=ax[2],
508
+ )
653
509
654
- plot(idata3)
510
+ ax[2].set(
511
+ xlim=[-2, 1],
512
+ ylim=[-5, 5],
513
+ )
655
514
```
656
515
657
516
The panel on the right shows the group level posterior of the slope and intercept parameters in black. This particular visualisation is a little unclear, however, so we can just plot the marginal distribution below to see how much belief we have in the slope being less than zero.
0 commit comments