@@ -30,9 +30,8 @@ We'll begin with some Python imports.
3030import numpyro
3131from numpyro import distributions as dist
3232
33- import numpy as np
3433import jax.numpy as jnp
35- from jax import random
34+ from jax import random, lax
3635import matplotlib.pyplot as plt
3736```
3837
@@ -129,25 +128,23 @@ How we select the initial value $y_0$ matters.
129128To illustrate the issue, we'll begin by choosing an initial $y_0$ that is far out in a tail of the stationary distribution.
130129
131130```{code-cell} ipython3
132- def ar1_simulate(ρ, σ, y0, T):
131+ def ar1_simulate(ρ, σ, y0, T, key):
132+ ε = random.normal(key, shape=(T,)) * σ
133133
134- # Allocate space and draw epsilons
135- y = np.empty(T)
136- eps = np.random.normal(0., σ, T)
134+ def scan_fn(y_prev, ε_t):
135+ y_t = ρ * y_prev + ε_t
136+ return y_t, y_t
137137
138- # Initial condition and step forward
139- y[0] = y0
140- for t in range(1, T):
141- y[t] = ρ * y[t-1] + eps[t]
138+ _, y = lax.scan(scan_fn, y0, ε)
142139
143140 return y
144141
145142σ = 1.0
146143ρ = 0.5
147144T = 50
148145
149- np. random.seed (145353452)
150- y = ar1_simulate(ρ, σ, 10, T)
146+ key = random.PRNGKey (145353452)
147+ y = ar1_simulate(ρ, σ, 10.0 , T, key )
151148```
152149
153150```{code-cell} ipython3
@@ -168,15 +165,14 @@ def plot_posterior(sample):
168165 """
169166 Plot trace and histogram
170167 """
171- # To np array
172168 ρs = sample['ρ']
173169 σs = sample['σ']
174- ρs, σs = np.array(ρs), np.array(σs)
175170
176171 fig, axs = plt.subplots(2, 2, figsize=(17, 6))
172+
177173 # Plot trace
178174 axs[0, 0].plot(ρs) # ρ
179- axs[1, 0].plot(σs) # σ
175+ axs[1, 0].plot(σs) # σ
180176
181177 # Plot posterior
182178 axs[0, 1].hist(ρs, bins=50, density=True, alpha=0.7)
@@ -195,7 +191,7 @@ def plot_posterior(sample):
195191def AR1_model(data):
196192 # set prior
197193 ρ = numpyro.sample('ρ', dist.Uniform(low=-1., high=1.))
198- σ = numpyro.sample('σ', dist.HalfNormal(scale=np .sqrt(10)))
194+ σ = numpyro.sample('σ', dist.HalfNormal(scale=jnp .sqrt(10)))
199195
200196 # Expected value of y at the next period (ρ * y)
201197 yhat = ρ * data[:-1]
@@ -209,14 +205,14 @@ def AR1_model(data):
209205```{code-cell} ipython3
210206:tags: [hide-output]
211207
212- # Make jnp array
213208y = jnp.array(y)
214209
215210# Set NUTS kernel
216211NUTS_kernel = numpyro.infer.NUTS(AR1_model)
217212
218213# Run MCMC
219- mcmc = numpyro.infer.MCMC(NUTS_kernel, num_samples=50000, num_warmup=10000, progress_bar=False)
214+ mcmc = numpyro.infer.MCMC(NUTS_kernel,
215+ num_samples=50000, num_warmup=10000, progress_bar=False)
220216mcmc.run(rng_key=random.PRNGKey(1), data=y)
221217```
222218
@@ -250,7 +246,7 @@ Here's the new code to achieve this.
250246def AR1_model_y0(data):
251247 # Set prior
252248 ρ = numpyro.sample('ρ', dist.Uniform(low=-1., high=1.))
253- σ = numpyro.sample('σ', dist.HalfNormal(scale=np .sqrt(10)))
249+ σ = numpyro.sample('σ', dist.HalfNormal(scale=jnp .sqrt(10)))
254250
255251 # Standard deviation of ergodic y
256252 y_sd = σ / jnp.sqrt(1 - ρ**2)
@@ -268,14 +264,14 @@ def AR1_model_y0(data):
268264```{code-cell} ipython3
269265:tags: [hide-output]
270266
271- # Make jnp array
272267y = jnp.array(y)
273268
274269# Set NUTS kernel
275270NUTS_kernel = numpyro.infer.NUTS(AR1_model_y0)
276271
277272# Run MCMC
278- mcmc2 = numpyro.infer.MCMC(NUTS_kernel, num_samples=50000, num_warmup=10000, progress_bar=False)
273+ mcmc2 = numpyro.infer.MCMC(NUTS_kernel,
274+ num_samples=50000, num_warmup=10000, progress_bar=False)
279275mcmc2.run(rng_key=random.PRNGKey(1), data=y)
280276```
281277
0 commit comments