@@ -4,7 +4,7 @@ jupytext:
44 extension : .md
55 format_name : myst
66 format_version : 0.13
7- jupytext_version : 1.16.4
7+ jupytext_version : 1.17.3
88kernelspec :
99 display_name : Python 3 (ipykernel)
1010 language : python
@@ -19,7 +19,7 @@ kernelspec:
1919``` {code-cell} ipython3
2020:tags: [hide-output]
2121
22- !pip install numpyro jax
22+ !pip install numpyro jax arviz
2323```
2424
2525This lecture describes methods for forecasting statistics that are functions of future values of a univariate autoregressive process.
@@ -60,6 +60,9 @@ import numpyro
6060import numpyro.distributions as dist
6161from numpyro.infer import MCMC, NUTS
6262
63+ # arviz
64+ import arviz as az
65+
6366sns.set_style('white')
6467colors = sns.color_palette()
6568key = random.PRNGKey(0)
@@ -141,9 +144,9 @@ class AR1(NamedTuple):
141144
142145 Parameters
143146 ----------
144- rho : float
145- Autoregressive coefficient, must satisfy |rho | < 1 for stationarity.
146- sigma : float
147+ ρ : float
148+ Autoregressive coefficient, must satisfy |ρ | < 1 for stationarity.
149+ σ : float
147150 Standard deviation of the error term.
148151 y0 : float
149152 Initial value of the process at time t=0.
@@ -152,11 +155,23 @@ class AR1(NamedTuple):
152155 T1 : int, optional
153156 Length of the future path to simulate (default is 100).
154157 """
155- rho : float
156- sigma : float
158+ ρ : float
159+ σ : float
157160 y0: float
158- T0: int=100
159- T1: int=100
161+ T0: int
162+ T1: int
163+
164+
165+ def make_ar1(ρ: float, σ: float, y0: float, T0: int = 100, T1: int = 100):
166+ """
167+ Factory function to create an AR1 instance with default values for T0 and T1.
168+
169+ Returns
170+ -------
171+ AR1
172+ AR1 named tuple containing the specified parameters.
173+ """
174+ return AR1(ρ=ρ, σ=σ, y0=y0, T0=T0, T1=T1)
160175
161176
162177def AR1_simulate_past(ar1: AR1, key=key):
@@ -166,7 +181,7 @@ def AR1_simulate_past(ar1: AR1, key=key):
166181 Parameters
167182 ----------
168183 ar1 : AR1
169- AR1 named tuple containing parameters (rho, sigma , y0, T0, T1).
184+ AR1 named tuple containing parameters (ρ, σ , y0, T0, T1).
170185 key : jax.random.PRNGKey
171186 JAX random key for generating random noise.
172187
@@ -175,18 +190,18 @@ def AR1_simulate_past(ar1: AR1, key=key):
175190 initial_path : jax.numpy.ndarray
176191 Simulated path of the AR(1) process and the initial y0.
177192 """
178- rho, sigma , y0, T0 = ar1.rho , ar1.sigma , ar1.y0, ar1.T0
179- # Draw epsilons
180- eps = sigma * random.normal(key, (T0,))
193+ ρ, σ , y0, T0 = ar1.ρ , ar1.σ , ar1.y0, ar1.T0
194+ # Draw εs
195+ ε = σ * random.normal(key, (T0,))
181196
182197 # Set step function
183- def ar1_step(y_prev, t_rho_eps ):
184- rho, eps_t = t_rho_eps
185- y_t = rho * y_prev + eps_t
198+ def ar1_step(y_prev, t_ρ_ε ):
199+ ρ, ε_t = t_ρ_ε
200+ y_t = ρ * y_prev + ε_t
186201 return y_t, y_t
187202
188203 # Scan over time steps
189- _, y_seq = lax.scan(ar1_step, y0, (jnp.full(T0, rho ), eps ))
204+ _, y_seq = lax.scan(ar1_step, y0, (jnp.full(T0, ρ ), ε ))
190205
191206 # Concatenate initial value
192207 initial_path = jnp.concatenate([jnp.array([y0]), y_seq])
@@ -204,7 +219,7 @@ def AR1_simulate_future(ar1: AR1, y_T0, N=10, key=key):
204219 Parameters
205220 ----------
206221 ar1 : AR1
207- AR1 named tuple containing parameters (rho, sigma , y0, T0, T1).
222+ AR1 named tuple containing parameters (ρ, σ , y0, T0, T1).
208223 y_T0 : float
209224 Value of the process at time T0.
210225 N: int
@@ -217,16 +232,16 @@ def AR1_simulate_future(ar1: AR1, y_T0, N=10, key=key):
217232 future_path : jax.numpy.ndarray
218233 Simulated N paths of the AR(1) process of length T1.
219234 """
220- rho, sigma , T1 = ar1.rho , ar1.sigma , ar1.T1
235+ ρ, σ , T1 = ar1.ρ , ar1.σ , ar1.T1
221236
222237 def single_path_scan(y_T0, subkey):
223- eps = sigma * random.normal(subkey, (T1,))
238+ ε = σ * random.normal(subkey, (T1,))
224239
225- def ar1_step(y_prev, t_rho_eps ):
226- rho, eps_t = t_rho_eps
227- y_t = rho * y_prev + eps_t
240+ def ar1_step(y_prev, t_ρ_ε ):
241+ ρ, ε_t = t_ρ_ε
242+ y_t = ρ * y_prev + ε_t
228243 return y_t, y_t
229- _, y = lax.scan(ar1_step, y_T0, (jnp.full(T1, rho ), eps ))
244+ _, y = lax.scan(ar1_step, y_T0, (jnp.full(T1, ρ ), ε ))
230245 return y
231246
232247 # Split key to generate different paths
@@ -249,7 +264,7 @@ def plot_path(ar1, initial_path, future_path, ax, key=key):
249264 Parameters
250265 ----------
251266 ar1 : AR1
252- AR1 named tuple containing process parameters (rho, sigma , T0, T1).
267+ AR1 named tuple containing process parameters (ρ, σ , T0, T1).
253268 initial_path : array-like
254269 Simulated initial path of the AR(1) process, shape(T0+1,).
255270 future_path : array-like
@@ -266,13 +281,13 @@ def plot_path(ar1, initial_path, future_path, ax, key=key):
266281 - 90% and 95% predictive confidence intervals
267282 - Expected future path
268283 """
269- rho, sigma , T0, T1 = ar1.rho , ar1.sigma , ar1.T0, ar1.T1
284+ ρ, σ , T0, T1 = ar1.ρ , ar1.σ , ar1.T0, ar1.T1
270285
271286 # Compute moments and confidence intervals
272287 y_T0 = initial_path[-1]
273288 j = jnp.arange(1, T1+1)
274- center = rho **j * y_T0
275- vars = sigma **2 * (1 - rho **(2 * j)) / (1 - rho **2)
289+ center = ρ **j * y_T0
290+ vars = σ **2 * (1 - ρ **(2 * j)) / (1 - ρ **2)
276291
277292 # 95% CI
278293 y_upper_c95 = center + 1.96 * jnp.sqrt(vars)
@@ -313,11 +328,10 @@ def plot_path(ar1, initial_path, future_path, ax, key=key):
313328---
314329mystnb:
315330 figure:
316- caption: |
317- Initial and predictive future paths
331+ caption: "Initial and predictive future paths \n"
318332 name: fig_path
319333---
320- ar1 = AR1(rho =0.9, sigma =1, y0=10, T0=100, T1=100 )
334+ ar1 = make_ar1(ρ =0.9, σ =1, y0=10)
321335
322336# Simulate
323337initial_path = AR1_simulate_past(ar1)
@@ -438,24 +452,23 @@ Note that in defining the likelihood function, we choose to condition on the ini
438452---
439453mystnb:
440454 figure:
441- caption: |
442- Posterior distributions (rho, sigma)
443- name: fig_post
455+ caption: "AR(1) model"
456+ name: fig_trace
444457---
445458def draw_from_posterior(data, size=10000, bins=20, dis_plot=1, key=key):
446459 """Draw a sample of size from the posterior distribution."""
447460 def model(data):
448461 # Start with priors
449- rho = numpyro.sample('rho', dist.Uniform(-1, 1)) # Assume stable rho
450- sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.sqrt(10)))
462+ ρ = numpyro.sample('rho', dist.Uniform(-1, 1)) # Assume stable ρ
463+ σ = numpyro.sample('sigma', dist.HalfNormal(jnp.sqrt(10)))
451464
452465 # Define likelihood recursively
453466 for t in range(1, len(data)):
454467 # Expectation of y_t
455- mu = rho * data[t-1]
468+ μ = ρ * data[t-1]
456469
457470 # Likelihood of the actual realization.
458- numpyro.sample(f'y_{t}', dist.Normal(mu, sigma ), obs=data[t])
471+ numpyro.sample(f'y_{t}', dist.Normal(μ, σ ), obs=data[t])
459472
460473 # Compute posterior distribution of parameters
461474 nuts_kernel = NUTS(model)
@@ -465,6 +478,7 @@ def draw_from_posterior(data, size=10000, bins=20, dis_plot=1, key=key):
465478 nuts_kernel,
466479 num_warmup=5000,
467480 num_samples=size,
481+ num_chains=4, # plot 4 chains in the trace
468482 progress_bar=False)
469483
470484 # Run MCMC
@@ -476,24 +490,26 @@ def draw_from_posterior(data, size=10000, bins=20, dis_plot=1, key=key):
476490 'sigma': mcmc.get_samples()['sigma']
477491 }
478492
479- # Plot posterior distributions
493+ # Plot posterior distributions and trace plots
480494 if dis_plot == 1:
481- fig, ax = plt.subplots(1, 2, figsize=(12, 5))
482- sns.histplot(
483- post_sample['rho'], kde=True, stat="density", bins=bins, ax=ax[0]
495+ plot_data = az.from_numpyro(posterior=mcmc)
496+ axes = az.plot_trace(
497+ data=plot_data,
498+ compact=True,
499+ lines=[
500+ ("ρ", {}, ar1.ρ),
501+ ("σ", {}, ar1.σ),
502+ ],
503+ backend_kwargs={"figsize": (10, 6), "layout": "constrained"},
484504 )
485- ax[0].set_xlabel(r"$\rho$")
486- sns.histplot(
487- post_sample['sigma'], kde=True, stat="density", bins=bins, ax=ax[1]
488- )
489- ax[1].set_xlabel(r"$\sigma$")
505+
490506 return post_sample
491507
492508
493509post_samples = draw_from_posterior(initial_path)
494510```
495511
496- The graphs above portray posterior distributions.
512+ The graphs above portray posterior distributions and trace plots. The posterior distributions (top row) show the marginal distributions of the parameters after observing the data, while the trace plots (bottom row) help diagnose MCMC convergence by showing how the sampler explored the parameter space over iterations .
497513
498514## Calculating Sample Path Statistics
499515
@@ -567,6 +583,7 @@ def compute_path_statistics(initial_path, future_path):
567583 )
568584 return path_stats
569585```
586+
570587The following function creates visualizations of the path statistics in a subplot grid.
571588
572589```{code-cell} ipython3
@@ -611,7 +628,7 @@ def plot_Wecker(ar1: AR1, initial_path, ax, N=1000):
611628 Parameters
612629 ----------
613630 ar1 : AR1
614- An AR1 named tuple containing the process parameters (rho, sigma , T0, T1).
631+ An AR1 named tuple containing the process parameters (ρ, σ , T0, T1).
615632 initial_path : array-like
616633 The initial observed path of the AR(1) process.
617634 N : int
@@ -669,14 +686,14 @@ def plot_extended_Wecker(
669686 index = random.choice(
670687 key, jnp.arange(len(post_samples['rho'])), (N + 1,), replace=False
671688 )
672- rho_sample = post_samples['rho'][index]
673- sigma_sample = post_samples['sigma'][index]
689+ ρ_sample = post_samples['rho'][index]
690+ σ_sample = post_samples['sigma'][index]
674691
675692 # Compute path statistics
676693 subkeys = random.split(key, num=N)
677694
678695 def step(carry, n):
679- ar1_n = AR1(rho=rho_sample [n], sigma=sigma_sample [n], y0=y0)
696+ ar1_n = make_ar1(ρ=ρ_sample [n], σ=σ_sample [n], y0=y0, T1=T1 )
680697 future_temp = AR1_simulate_future(
681698 ar1_n, y_T0, N=1, key=subkeys[n]
682699 ).reshape(-1)
0 commit comments