Skip to content

Commit 986b445

Browse files
committed
update the vectorization strategy
1 parent 2a37d97 commit 986b445

File tree

1 file changed

+23
-25
lines changed

1 file changed

+23
-25
lines changed

lectures/ar1_turningpts.md

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ jupytext:
44
extension: .md
55
format_name: myst
66
format_version: 0.13
7-
jupytext_version: 1.17.3
7+
jupytext_version: 1.17.1
88
kernelspec:
99
display_name: Python 3 (ipykernel)
1010
language: python
@@ -53,7 +53,6 @@ from typing import NamedTuple
5353
import numpyro
5454
import numpyro.distributions as dist
5555
from numpyro.infer import MCMC, NUTS
56-
numpyro.set_host_device_count(4)
5756
5857
# jax
5958
import jax
@@ -67,9 +66,6 @@ import arviz as az
6766
sns.set_style('white')
6867
colors = sns.color_palette()
6968
key = random.PRNGKey(0)
70-
71-
print(f"Available devices: {jax.local_device_count()}")
72-
print(f"Device list: {jax.devices()}")
7369
```
7470

7571
## A Univariate First-Order Autoregressive Process
@@ -456,16 +452,17 @@ Note that in defining the likelihood function, we choose to condition on the ini
456452
---
457453
mystnb:
458454
figure:
459-
caption: "AR(1) model"
455+
caption: AR(1) model
460456
name: fig_trace
461457
---
462458
def draw_from_posterior(data, size=10000, bins=20, dis_plot=1, key=key):
463-
"""Draw a sample of size from the posterior distribution."""
459+
"""Draw a sample of size from the posterior distribution."""
460+
464461
def model(data):
465462
# Start with priors
466-
ρ = numpyro.sample('rho', dist.Uniform(-1, 1)) # Assume stable ρ
467-
σ = numpyro.sample('sigma', dist.HalfNormal(jnp.sqrt(10)))
468-
463+
ρ = numpyro.sample('ρ', dist.Uniform(-1, 1)) # Assume stable ρ
464+
σ = numpyro.sample('σ', dist.HalfNormal(jnp.sqrt(10)))
465+
469466
# Define likelihood recursively
470467
for t in range(1, len(data)):
471468
# Expectation of y_t
@@ -476,39 +473,40 @@ def draw_from_posterior(data, size=10000, bins=20, dis_plot=1, key=key):
476473
477474
# Compute posterior distribution of parameters
478475
nuts_kernel = NUTS(model)
479-
476+
480477
# Define MCMC class to compute the posteriors
481478
mcmc = MCMC(
482479
nuts_kernel,
483480
num_warmup=5000,
484481
num_samples=size,
485-
num_chains=4, # plot 4 chains in the trace
486-
progress_bar=False)
487-
482+
num_chains=4, # plot 4 chains in the trace
483+
progress_bar=False,
484+
chain_method='vectorized'
485+
)
486+
488487
# Run MCMC
489488
mcmc.run(key, data=data)
490-
489+
491490
# Get posterior samples
492491
post_sample = {
493-
'rho': mcmc.get_samples()['rho'],
494-
'sigma': mcmc.get_samples()['sigma']
492+
'ρ': mcmc.get_samples()['ρ'],
493+
'σ': mcmc.get_samples()['σ'],
495494
}
496-
495+
497496
# Plot posterior distributions and trace plots
498497
if dis_plot == 1:
499498
plot_data = az.from_numpyro(posterior=mcmc)
500499
axes = az.plot_trace(
501500
data=plot_data,
502501
compact=True,
503502
lines=[
504-
("rho", {}, ar1.ρ),
505-
("sigma", {}, ar1.σ),
503+
("σ", {}, ar1.ρ),
504+
("ρ", {}, ar1.σ),
506505
],
507506
backend_kwargs={"figsize": (10, 6), "layout": "constrained"},
508507
)
509-
510-
return post_sample
511508
509+
return post_sample
512510
513511
post_samples = draw_from_posterior(initial_path)
514512
```
@@ -688,10 +686,10 @@ def plot_extended_Wecker(
688686
689687
# Select a parameter sample
690688
index = random.choice(
691-
key, jnp.arange(len(post_samples['rho'])), (N + 1,), replace=False
689+
key, jnp.arange(len(post_samples['ρ'])), (N + 1,), replace=False
692690
)
693-
ρ_sample = post_samples['rho'][index]
694-
σ_sample = post_samples['sigma'][index]
691+
ρ_sample = post_samples['ρ'][index]
692+
σ_sample = post_samples['σ'][index]
695693
696694
# Compute path statistics
697695
subkeys = random.split(key, num=N)

0 commit comments

Comments
 (0)