@@ -4,7 +4,7 @@ jupytext:
44 extension : .md
55 format_name : myst
66 format_version : 0.13
7- jupytext_version : 1.17.3
7+ jupytext_version : 1.17.1
88kernelspec :
99 display_name : Python 3 (ipykernel)
1010 language : python
@@ -53,7 +53,6 @@ from typing import NamedTuple
5353import numpyro
5454import numpyro.distributions as dist
5555from numpyro.infer import MCMC, NUTS
56- numpyro.set_host_device_count(4)
5756
5857# jax
5958import jax
@@ -67,9 +66,6 @@ import arviz as az
6766sns.set_style('white')
6867colors = sns.color_palette()
6968key = random.PRNGKey(0)
70-
71- print(f"Available devices: {jax.local_device_count()}")
72- print(f"Device list: {jax.devices()}")
7369```
7470
7571## A Univariate First-Order Autoregressive Process
@@ -456,16 +452,17 @@ Note that in defining the likelihood function, we choose to condition on the ini
456452---
457453mystnb:
458454 figure:
459- caption: " AR(1) model"
455+ caption: AR(1) model
460456 name: fig_trace
461457---
462458def draw_from_posterior(data, size=10000, bins=20, dis_plot=1, key=key):
463- """Draw a sample of size from the posterior distribution."""
459+ """Draw a sample of size from the posterior distribution."""
460+
464461 def model(data):
465462 # Start with priors
466- ρ = numpyro.sample('rho ', dist.Uniform(-1, 1)) # Assume stable ρ
467- σ = numpyro.sample('sigma ', dist.HalfNormal(jnp.sqrt(10)))
468-
463+ ρ = numpyro.sample('ρ ', dist.Uniform(-1, 1)) # Assume stable ρ
464+ σ = numpyro.sample('σ ', dist.HalfNormal(jnp.sqrt(10)))
465+
469466 # Define likelihood recursively
470467 for t in range(1, len(data)):
471468 # Expectation of y_t
@@ -476,39 +473,40 @@ def draw_from_posterior(data, size=10000, bins=20, dis_plot=1, key=key):
476473
477474 # Compute posterior distribution of parameters
478475 nuts_kernel = NUTS(model)
479-
476+
480477 # Define MCMC class to compute the posteriors
481478 mcmc = MCMC(
482479 nuts_kernel,
483480 num_warmup=5000,
484481 num_samples=size,
485- num_chains=4, # plot 4 chains in the trace
486- progress_bar=False)
487-
482+ num_chains=4, # plot 4 chains in the trace
483+ progress_bar=False,
484+ chain_method='vectorized'
485+ )
486+
488487 # Run MCMC
489488 mcmc.run(key, data=data)
490-
489+
491490 # Get posterior samples
492491 post_sample = {
493- 'rho ': mcmc.get_samples()['rho '],
494- 'sigma ': mcmc.get_samples()['sigma']
492+ 'ρ ': mcmc.get_samples()['ρ '],
493+ 'σ ': mcmc.get_samples()['σ'],
495494 }
496-
495+
497496 # Plot posterior distributions and trace plots
498497 if dis_plot == 1:
499498 plot_data = az.from_numpyro(posterior=mcmc)
500499 axes = az.plot_trace(
501500 data=plot_data,
502501 compact=True,
503502 lines=[
504- ("rho ", {}, ar1.ρ),
505- ("sigma ", {}, ar1.σ),
503+ ("σ ", {}, ar1.ρ),
504+ ("ρ ", {}, ar1.σ),
506505 ],
507506 backend_kwargs={"figsize": (10, 6), "layout": "constrained"},
508507 )
509-
510- return post_sample
511508
509+ return post_sample
512510
513511post_samples = draw_from_posterior(initial_path)
514512```
@@ -688,10 +686,10 @@ def plot_extended_Wecker(
688686
689687 # Select a parameter sample
690688 index = random.choice(
691- key, jnp.arange(len(post_samples['rho '])), (N + 1,), replace=False
689+ key, jnp.arange(len(post_samples['ρ '])), (N + 1,), replace=False
692690 )
693- ρ_sample = post_samples['rho '][index]
694- σ_sample = post_samples['sigma '][index]
691+ ρ_sample = post_samples['ρ '][index]
692+ σ_sample = post_samples['σ '][index]
695693
696694 # Compute path statistics
697695 subkeys = random.split(key, num=N)
0 commit comments