
Commit c842de3

conditional posterior

1 parent 465d762 commit c842de3

File tree: 2 files changed, +297 −34 lines


examples/time_series/Time_Series_Generative_Graph.ipynb

Lines changed: 215 additions & 33 deletions (large diffs are not rendered by default)

examples/time_series/Time_Series_Generative_Graph.myst.md

Lines changed: 82 additions & 1 deletion
```diff
@@ -267,7 +267,7 @@ for i, hdi_prob in enumerate((0.94, 0.64), 1):
         color="C0",
         label=f"{hdi_prob:.0%} HDI",
     )
-ax.plot(prior.prior["ar"].mean(("chain", "draw")), color="C0", label="Mean")
+ax.plot(post_pred_ar.mean(("chain", "draw")), color="C0", label="Mean")
 ax.plot(ar_obs, color="black", label="Observed")
 ax.legend(loc="upper right")
 ax.set_xlabel("time")
```
@@ -307,6 +307,87 @@ $$

+++

Let's see how to do this in PyMC! The key observation is that we need to pass the observed data explicitly into our "for loop" in the generative graph. That is, we need to pass it into the {meth}`scan <pytensor.scan.basic.scan>` function.
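To make the `taps` mechanics concrete before reading the PyMC code, here is a plain-Python sketch (an illustration, not the notebook's pytensor code, with hypothetical values): with `taps=[-2, -1]` and `lags = 2`, step `t` of the scan receives `(y[t-2], y[t-1])` from the *observed* sequence, so each innovation mean is computed from observed history rather than from previously simulated values.

```python
# Hypothetical sketch of scan's `taps` semantics for an AR(2):
# with taps [-2, -1], step t receives (y[t-2], y[t-1]) from the observed
# sequence, so the mean uses observed lags, never simulated ones.
def conditional_means(y, rho):
    # rho[0] weights the most recent lag, matching the ar_step below
    return [rho[0] * y[t - 1] + rho[1] * y[t - 2] for t in range(2, len(y))]


print(conditional_means([0.0, 1.0, 2.0, 3.0], rho=(0.5, 0.25)))  # → [0.5, 1.25]
```

Note that the output has `len(y) - lags` entries, mirroring the `n_steps=trials - lags` argument in the scan below.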
```{code-cell} ipython3
def conditional_ar_dist(y_data, ar_init, rho, sigma, size):
    def ar_step(x_tm2, x_tm1, rho, sigma):
        mu = x_tm1 * rho[0] + x_tm2 * rho[1]
        x = mu + pm.Normal.dist(sigma=sigma)
        return x, collect_default_updates([x])

    # Here we condition on the observed data by passing it through the `sequences` argument.
    ar_innov, _ = pytensor.scan(
        fn=ar_step,
        sequences=[{"input": y_data, "taps": list(range(-lags, 0))}],
        non_sequences=[rho, sigma],
        n_steps=trials - lags,
        strict=True,
    )

    return ar_innov
```
Then we can simply generate samples from the posterior predictive distribution. Observe that we need to "rewrite" the generative graph to include the conditioned transition step. Nevertheless, we will use the posterior samples from the model above, which means we could put *any* "prior" distributions on the parameters we learned. For a detailed explanation of these types of cross-model predictions, see the great blog post [Out of model predictions with PyMC](https://www.pymc-labs.com/blog-posts/out-of-model-predictions-with-pymc/).
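The mechanics of reusing posterior draws inside a new generative rule can be sketched in plain NumPy (all names here are hypothetical stand-ins: in the notebook the draws come from `trace` and the observed series is `ar_obs`):

```python
import numpy as np

rng = np.random.default_rng(42)

# Stand-ins for stored posterior draws (in the notebook these come from `trace`).
rho_draws = rng.normal(loc=[0.5, 0.25], scale=0.05, size=(500, 2))
sigma_draws = np.abs(rng.normal(loc=1.0, scale=0.1, size=500))

y = rng.normal(size=100)  # stand-in for the observed series

# For each posterior draw, sample innovations conditioned on the observed lags:
# x_t ~ Normal(rho[0] * y[t-1] + rho[1] * y[t-2], sigma)
mu = rho_draws[:, :1] * y[None, 1:-1] + rho_draws[:, 1:] * y[None, :-2]
samples = rng.normal(loc=mu, scale=sigma_draws[:, None])
samples.shape  # one conditioned trajectory of length len(y) - 2 per draw
```

This is exactly what `pm.sample_posterior_predictive` does for us below: it pushes each posterior draw of `rho` and `sigma` through the conditioned graph.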
```{code-cell} ipython3
coords = {
    "lags": range(-lags, 0),
    "steps": range(trials - lags),
    "trials": range(trials),
}
with pm.Model(coords=coords, check_bounds=False) as conditional_model:
    y_data = pm.Data("y_data", ar_obs)
    rho = pm.Flat(name="rho", dims=("lags",))
    sigma = pm.Flat(name="sigma")
    ar_init = pm.Flat(name="ar_init", dims=("lags",))

    ar_innov = pm.CustomDist(
        "ar_dist",
        y_data,
        ar_init,
        rho,
        sigma,
        dist=conditional_ar_dist,
        dims=("steps",),
    )

    ar = pm.Deterministic(
        name="ar", var=pt.concatenate([ar_init, ar_innov], axis=-1), dims=("trials",)
    )

    post_pred_conditional = pm.sample_posterior_predictive(trace, var_names=["ar"], random_seed=rng)
```

Let's visualize the conditional posterior predictive distribution:

```{code-cell} ipython3
conditional_post_pred_ar = post_pred_conditional.posterior_predictive["ar"]

_, ax = plt.subplots()
for i, hdi_prob in enumerate((0.94, 0.64), 1):
    hdi = az.hdi(conditional_post_pred_ar, hdi_prob=hdi_prob)["ar"]
    lower = hdi.sel(hdi="lower")
    upper = hdi.sel(hdi="higher")
    ax.fill_between(
        x=np.arange(trials),
        y1=lower,
        y2=upper,
        alpha=(i - 0.2) * 0.2,
        color="C1",
        label=f"{hdi_prob:.0%} HDI",
    )
ax.plot(conditional_post_pred_ar.mean(("chain", "draw")), color="C1", label="Mean")
ax.plot(ar_obs, color="black", label="Observed")
ax.legend(loc="upper right")
ax.set_xlabel("time")
ax.set_title("AR(2) Conditional Posterior Predictive Samples", fontsize=18, fontweight="bold");
```

We indeed see that these credible intervals are tighter than the unconditional ones.
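This is expected: with the parameters held fixed, each conditioned innovation is a one-step-ahead prediction whose variance is just $\sigma^2$, whereas the unconditional predictive spread approaches the stationary AR(2) variance, which is strictly larger for a stationary process. A quick check with hypothetical parameter values (not the notebook's posterior means), using the standard stationary-variance formula for $x_t = \rho_1 x_{t-1} + \rho_2 x_{t-2} + \epsilon_t$:

```python
def ar2_stationary_variance(rho1, rho2, sigma):
    # Stationary variance of x_t = rho1 * x_{t-1} + rho2 * x_{t-2} + eps,
    # with eps ~ Normal(0, sigma); standard AR(2) result from the
    # Yule-Walker equations.
    return sigma**2 * (1 - rho2) / ((1 + rho2) * ((1 - rho2) ** 2 - rho1**2))


v = ar2_stationary_variance(rho1=0.5, rho2=0.25, sigma=1.0)
print(v)  # → 1.92, larger than the one-step conditional variance sigma**2 == 1.0
```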

+++
## Authors
- Authored by [Jesse Grabowski](https://github.com/jessegrabowski), [Juan Orduz](https://juanitorduz.github.io/) and [Ricardo Vieira](https://github.com/ricardoV94) in March 2024

0 commit comments