Skip to content

Commit 547121a

Browse files
committed
minor updates
1 parent 3e2f6cf commit 547121a

File tree

1 file changed

+17
-21
lines changed

1 file changed

+17
-21
lines changed

lectures/ar1_bayes.md

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,8 @@ We'll begin with some Python imports.
3030
import numpyro
3131
from numpyro import distributions as dist
3232
33-
import numpy as np
3433
import jax.numpy as jnp
35-
from jax import random
34+
from jax import random, lax
3635
import matplotlib.pyplot as plt
3736
```
3837

@@ -129,25 +128,23 @@ How we select the initial value $y_0$ matters.
129128
To illustrate the issue, we'll begin by choosing an initial $y_0$ that is far out in a tail of the stationary distribution.
130129
131130
```{code-cell} ipython3
132-
def ar1_simulate(ρ, σ, y0, T):
131+
def ar1_simulate(ρ, σ, y0, T, key):
132+
ε = random.normal(key, shape=(T,)) * σ
133133
134-
# Allocate space and draw epsilons
135-
y = np.empty(T)
136-
eps = np.random.normal(0., σ, T)
134+
def scan_fn(y_prev, ε_t):
135+
y_t = ρ * y_prev + ε_t
136+
return y_t, y_t
137137
138-
# Initial condition and step forward
139-
y[0] = y0
140-
for t in range(1, T):
141-
y[t] = ρ * y[t-1] + eps[t]
138+
_, y = lax.scan(scan_fn, y0, ε)
142139
143140
return y
144141
145142
σ = 1.0
146143
ρ = 0.5
147144
T = 50
148145
149-
np.random.seed(145353452)
150-
y = ar1_simulate(ρ, σ, 10, T)
146+
key = random.PRNGKey(145353452)
147+
y = ar1_simulate(ρ, σ, 10.0, T, key)
151148
```
152149
153150
```{code-cell} ipython3
@@ -168,15 +165,14 @@ def plot_posterior(sample):
168165
"""
169166
Plot trace and histogram
170167
"""
171-
# To np array
172168
ρs = sample['ρ']
173169
σs = sample['σ']
174-
ρs, σs = np.array(ρs), np.array(σs)
175170
176171
fig, axs = plt.subplots(2, 2, figsize=(17, 6))
172+
177173
# Plot trace
178174
axs[0, 0].plot(ρs) # ρ
179-
axs[1, 0].plot(σs) # σ
175+
axs[1, 0].plot(σs) # σ
180176
181177
# Plot posterior
182178
axs[0, 1].hist(ρs, bins=50, density=True, alpha=0.7)
@@ -195,7 +191,7 @@ def plot_posterior(sample):
195191
def AR1_model(data):
196192
# set prior
197193
ρ = numpyro.sample('ρ', dist.Uniform(low=-1., high=1.))
198-
σ = numpyro.sample('σ', dist.HalfNormal(scale=np.sqrt(10)))
194+
σ = numpyro.sample('σ', dist.HalfNormal(scale=jnp.sqrt(10)))
199195
200196
# Expected value of y at the next period (ρ * y)
201197
yhat = ρ * data[:-1]
@@ -209,14 +205,14 @@ def AR1_model(data):
209205
```{code-cell} ipython3
210206
:tags: [hide-output]
211207
212-
# Make jnp array
213208
y = jnp.array(y)
214209
215210
# Set NUTS kernel
216211
NUTS_kernel = numpyro.infer.NUTS(AR1_model)
217212
218213
# Run MCMC
219-
mcmc = numpyro.infer.MCMC(NUTS_kernel, num_samples=50000, num_warmup=10000, progress_bar=False)
214+
mcmc = numpyro.infer.MCMC(NUTS_kernel,
215+
num_samples=50000, num_warmup=10000, progress_bar=False)
220216
mcmc.run(rng_key=random.PRNGKey(1), data=y)
221217
```
222218
@@ -250,7 +246,7 @@ Here's the new code to achieve this.
250246
def AR1_model_y0(data):
251247
# Set prior
252248
ρ = numpyro.sample('ρ', dist.Uniform(low=-1., high=1.))
253-
σ = numpyro.sample('σ', dist.HalfNormal(scale=np.sqrt(10)))
249+
σ = numpyro.sample('σ', dist.HalfNormal(scale=jnp.sqrt(10)))
254250
255251
# Standard deviation of ergodic y
256252
y_sd = σ / jnp.sqrt(1 - ρ**2)
@@ -268,14 +264,14 @@ def AR1_model_y0(data):
268264
```{code-cell} ipython3
269265
:tags: [hide-output]
270266
271-
# Make jnp array
272267
y = jnp.array(y)
273268
274269
# Set NUTS kernel
275270
NUTS_kernel = numpyro.infer.NUTS(AR1_model_y0)
276271
277272
# Run MCMC
278-
mcmc2 = numpyro.infer.MCMC(NUTS_kernel, num_samples=50000, num_warmup=10000, progress_bar=False)
273+
mcmc2 = numpyro.infer.MCMC(NUTS_kernel,
274+
num_samples=50000, num_warmup=10000, progress_bar=False)
279275
mcmc2.run(rng_key=random.PRNGKey(1), data=y)
280276
```
281277

0 commit comments

Comments
 (0)