Skip to content

Commit dbd3605

Browse files
committed
use Greek letters and arviz plot
1 parent d06a556 commit dbd3605

File tree

1 file changed

+69
-52
lines changed

1 file changed

+69
-52
lines changed

lectures/ar1_turningpts.md

Lines changed: 69 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ jupytext:
44
extension: .md
55
format_name: myst
66
format_version: 0.13
7-
jupytext_version: 1.16.4
7+
jupytext_version: 1.17.3
88
kernelspec:
99
display_name: Python 3 (ipykernel)
1010
language: python
@@ -19,7 +19,7 @@ kernelspec:
1919
```{code-cell} ipython3
2020
:tags: [hide-output]
2121
22-
!pip install numpyro jax
22+
!pip install numpyro jax arviz
2323
```
2424

2525
This lecture describes methods for forecasting statistics that are functions of future values of a univariate autoregressive process.
@@ -60,6 +60,9 @@ import numpyro
6060
import numpyro.distributions as dist
6161
from numpyro.infer import MCMC, NUTS
6262
63+
# arviz
64+
import arviz as az
65+
6366
sns.set_style('white')
6467
colors = sns.color_palette()
6568
key = random.PRNGKey(0)
@@ -141,9 +144,9 @@ class AR1(NamedTuple):
141144
142145
Parameters
143146
----------
144-
rho : float
145-
Autoregressive coefficient, must satisfy |rho| < 1 for stationarity.
146-
sigma : float
147+
ρ : float
148+
Autoregressive coefficient, must satisfy |ρ| < 1 for stationarity.
149+
σ : float
147150
Standard deviation of the error term.
148151
y0 : float
149152
Initial value of the process at time t=0.
@@ -152,11 +155,23 @@ class AR1(NamedTuple):
152155
T1 : int, optional
153156
Length of the future path to simulate (default is 100).
154157
"""
155-
rho: float
156-
sigma: float
158+
ρ: float
159+
σ: float
157160
y0: float
158-
T0: int=100
159-
T1: int=100
161+
T0: int
162+
T1: int
163+
164+
165+
def make_ar1(ρ: float, σ: float, y0: float, T0: int = 100, T1: int = 100):
166+
"""
167+
Factory function to create an AR1 instance with default values for T0 and T1.
168+
169+
Returns
170+
-------
171+
AR1
172+
AR1 named tuple containing the specified parameters.
173+
"""
174+
return AR1(ρ=ρ, σ=σ, y0=y0, T0=T0, T1=T1)
160175
161176
162177
def AR1_simulate_past(ar1: AR1, key=key):
@@ -166,7 +181,7 @@ def AR1_simulate_past(ar1: AR1, key=key):
166181
Parameters
167182
----------
168183
ar1 : AR1
169-
AR1 named tuple containing parameters (rho, sigma, y0, T0, T1).
184+
AR1 named tuple containing parameters (ρ, σ, y0, T0, T1).
170185
key : jax.random.PRNGKey
171186
JAX random key for generating random noise.
172187
@@ -175,18 +190,18 @@ def AR1_simulate_past(ar1: AR1, key=key):
175190
initial_path : jax.numpy.ndarray
176191
Simulated path of the AR(1) process and the initial y0.
177192
"""
178-
rho, sigma, y0, T0 = ar1.rho, ar1.sigma, ar1.y0, ar1.T0
179-
# Draw epsilons
180-
eps = sigma * random.normal(key, (T0,))
193+
ρ, σ, y0, T0 = ar1.ρ, ar1.σ, ar1.y0, ar1.T0
194+
# Draw εs
195+
ε = σ * random.normal(key, (T0,))
181196
182197
# Set step function
183-
def ar1_step(y_prev, t_rho_eps):
184-
rho, eps_t = t_rho_eps
185-
y_t = rho * y_prev + eps_t
198+
def ar1_step(y_prev, t_ρ_ε):
199+
ρ, ε_t = t_ρ_ε
200+
y_t = ρ * y_prev + ε_t
186201
return y_t, y_t
187202
188203
# Scan over time steps
189-
_, y_seq = lax.scan(ar1_step, y0, (jnp.full(T0, rho), eps))
204+
_, y_seq = lax.scan(ar1_step, y0, (jnp.full(T0, ρ), ε))
190205
191206
# Concatenate initial value
192207
initial_path = jnp.concatenate([jnp.array([y0]), y_seq])
@@ -204,7 +219,7 @@ def AR1_simulate_future(ar1: AR1, y_T0, N=10, key=key):
204219
Parameters
205220
----------
206221
ar1 : AR1
207-
AR1 named tuple containing parameters (rho, sigma, y0, T0, T1).
222+
AR1 named tuple containing parameters (ρ, σ, y0, T0, T1).
208223
y_T0 : float
209224
Value of the process at time T0.
210225
N: int
@@ -217,16 +232,16 @@ def AR1_simulate_future(ar1: AR1, y_T0, N=10, key=key):
217232
future_path : jax.numpy.ndarray
218233
Simulated N paths of the AR(1) process of length T1.
219234
"""
220-
rho, sigma, T1 = ar1.rho, ar1.sigma, ar1.T1
235+
ρ, σ, T1 = ar1.ρ, ar1.σ, ar1.T1
221236
222237
def single_path_scan(y_T0, subkey):
223-
eps = sigma * random.normal(subkey, (T1,))
238+
ε = σ * random.normal(subkey, (T1,))
224239
225-
def ar1_step(y_prev, t_rho_eps):
226-
rho, eps_t = t_rho_eps
227-
y_t = rho * y_prev + eps_t
240+
def ar1_step(y_prev, t_ρ_ε):
241+
ρ, ε_t = t_ρ_ε
242+
y_t = ρ * y_prev + ε_t
228243
return y_t, y_t
229-
_, y = lax.scan(ar1_step, y_T0, (jnp.full(T1, rho), eps))
244+
_, y = lax.scan(ar1_step, y_T0, (jnp.full(T1, ρ), ε))
230245
return y
231246
232247
# Split key to generate different paths
@@ -249,7 +264,7 @@ def plot_path(ar1, initial_path, future_path, ax, key=key):
249264
Parameters
250265
----------
251266
ar1 : AR1
252-
AR1 named tuple containing process parameters (rho, sigma, T0, T1).
267+
AR1 named tuple containing process parameters (ρ, σ, T0, T1).
253268
initial_path : array-like
254269
Simulated initial path of the AR(1) process, shape(T0+1,).
255270
future_path : array-like
@@ -266,13 +281,13 @@ def plot_path(ar1, initial_path, future_path, ax, key=key):
266281
- 90% and 95% predictive confidence intervals
267282
- Expected future path
268283
"""
269-
rho, sigma, T0, T1 = ar1.rho, ar1.sigma, ar1.T0, ar1.T1
284+
ρ, σ, T0, T1 = ar1.ρ, ar1.σ, ar1.T0, ar1.T1
270285
271286
# Compute moments and confidence intervals
272287
y_T0 = initial_path[-1]
273288
j = jnp.arange(1, T1+1)
274-
center = rho**j * y_T0
275-
vars = sigma**2 * (1 - rho**(2 * j)) / (1 - rho**2)
289+
center = ρ**j * y_T0
290+
vars = σ**2 * (1 - ρ**(2 * j)) / (1 - ρ**2)
276291
277292
# 95% CI
278293
y_upper_c95 = center + 1.96 * jnp.sqrt(vars)
@@ -313,11 +328,10 @@ def plot_path(ar1, initial_path, future_path, ax, key=key):
313328
---
314329
mystnb:
315330
figure:
316-
caption: |
317-
Initial and predictive future paths
331+
caption: "Initial and predictive future paths \n"
318332
name: fig_path
319333
---
320-
ar1 = AR1(rho=0.9, sigma=1, y0=10, T0=100, T1=100)
334+
ar1 = make_ar1(ρ=0.9, σ=1, y0=10)
321335
322336
# Simulate
323337
initial_path = AR1_simulate_past(ar1)
@@ -438,24 +452,23 @@ Note that in defining the likelihood function, we choose to condition on the ini
438452
---
439453
mystnb:
440454
figure:
441-
caption: |
442-
Posterior distributions (rho, sigma)
443-
name: fig_post
455+
caption: "AR(1) model"
456+
name: fig_trace
444457
---
445458
def draw_from_posterior(data, size=10000, bins=20, dis_plot=1, key=key):
446459
"""Draw a sample of size from the posterior distribution."""
447460
def model(data):
448461
# Start with priors
449-
rho = numpyro.sample('rho', dist.Uniform(-1, 1)) # Assume stable rho
450-
sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.sqrt(10)))
462+
ρ = numpyro.sample('rho', dist.Uniform(-1, 1)) # Assume stable ρ
463+
σ = numpyro.sample('sigma', dist.HalfNormal(jnp.sqrt(10)))
451464
452465
# Define likelihood recursively
453466
for t in range(1, len(data)):
454467
# Expectation of y_t
455-
mu = rho * data[t-1]
468+
μ = ρ * data[t-1]
456469
457470
# Likelihood of the actual realization.
458-
numpyro.sample(f'y_{t}', dist.Normal(mu, sigma), obs=data[t])
471+
numpyro.sample(f'y_{t}', dist.Normal(μ, σ), obs=data[t])
459472
460473
# Compute posterior distribution of parameters
461474
nuts_kernel = NUTS(model)
@@ -465,6 +478,7 @@ def draw_from_posterior(data, size=10000, bins=20, dis_plot=1, key=key):
465478
nuts_kernel,
466479
num_warmup=5000,
467480
num_samples=size,
481+
num_chains=4, # plot 4 chains in the trace
468482
progress_bar=False)
469483
470484
# Run MCMC
@@ -476,24 +490,26 @@ def draw_from_posterior(data, size=10000, bins=20, dis_plot=1, key=key):
476490
'sigma': mcmc.get_samples()['sigma']
477491
}
478492
479-
# Plot posterior distributions
493+
# Plot posterior distributions and trace plots
480494
if dis_plot == 1:
481-
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
482-
sns.histplot(
483-
post_sample['rho'], kde=True, stat="density", bins=bins, ax=ax[0]
495+
plot_data = az.from_numpyro(posterior=mcmc)
496+
axes = az.plot_trace(
497+
data=plot_data,
498+
compact=True,
499+
lines=[
500+
("ρ", {}, ar1.ρ),
501+
("σ", {}, ar1.σ),
502+
],
503+
backend_kwargs={"figsize": (10, 6), "layout": "constrained"},
484504
)
485-
ax[0].set_xlabel(r"$\rho$")
486-
sns.histplot(
487-
post_sample['sigma'], kde=True, stat="density", bins=bins, ax=ax[1]
488-
)
489-
ax[1].set_xlabel(r"$\sigma$")
505+
490506
return post_sample
491507
492508
493509
post_samples = draw_from_posterior(initial_path)
494510
```
495511
496-
The graphs above portray posterior distributions.
512+
The graphs above portray posterior distributions and trace plots. The posterior distributions (top row) show the marginal distributions of the parameters after observing the data, while the trace plots (bottom row) help diagnose MCMC convergence by showing how the sampler explored the parameter space over iterations.
497513
498514
## Calculating Sample Path Statistics
499515
@@ -567,6 +583,7 @@ def compute_path_statistics(initial_path, future_path):
567583
)
568584
return path_stats
569585
```
586+
570587
The following function creates visualizations of the path statistics in a subplot grid.
571588
572589
```{code-cell} ipython3
@@ -611,7 +628,7 @@ def plot_Wecker(ar1: AR1, initial_path, ax, N=1000):
611628
Parameters
612629
----------
613630
ar1 : AR1
614-
An AR1 named tuple containing the process parameters (rho, sigma, T0, T1).
631+
An AR1 named tuple containing the process parameters (ρ, σ, T0, T1).
615632
initial_path : array-like
616633
The initial observed path of the AR(1) process.
617634
N : int
@@ -669,14 +686,14 @@ def plot_extended_Wecker(
669686
index = random.choice(
670687
key, jnp.arange(len(post_samples['rho'])), (N + 1,), replace=False
671688
)
672-
rho_sample = post_samples['rho'][index]
673-
sigma_sample = post_samples['sigma'][index]
689+
ρ_sample = post_samples['rho'][index]
690+
σ_sample = post_samples['sigma'][index]
674691
675692
# Compute path statistics
676693
subkeys = random.split(key, num=N)
677694
678695
def step(carry, n):
679-
ar1_n = AR1(rho=rho_sample[n], sigma=sigma_sample[n], y0=y0)
696+
ar1_n = make_ar1(ρ=ρ_sample[n], σ=σ_sample[n], y0=y0, T1=T1)
680697
future_temp = AR1_simulate_future(
681698
ar1_n, y_T0, N=1, key=subkeys[n]
682699
).reshape(-1)

0 commit comments

Comments
 (0)