Skip to content

Commit 0376354

Browse files
committed
address review comments
1 parent aee4dfd commit 0376354

File tree

1 file changed

+33
-31
lines changed

1 file changed

+33
-31
lines changed

lectures/ar1_turningpts.md

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -158,20 +158,8 @@ class AR1(NamedTuple):
158158
ρ: float
159159
σ: float
160160
y0: float
161-
T0: int
162-
T1: int
163-
164-
165-
def make_ar1(ρ: float, σ: float, y0: float, T0: int = 100, T1: int = 100):
166-
"""
167-
Factory function to create an AR1 instance with default values for T0 and T1.
168-
169-
Returns
170-
-------
171-
AR1
172-
AR1 named tuple containing the specified parameters.
173-
"""
174-
return AR1(ρ=ρ, σ=σ, y0=y0, T0=T0, T1=T1)
161+
T0: int = 100
162+
T1: int = 100
175163
```
176164
177165
Using the `AR1` class, we can simulate paths more conveniently. The following function simulates an initial path with $T0$ length.
@@ -334,7 +322,7 @@ mystnb:
334322
caption: "Initial and predictive future paths \n"
335323
name: fig_path
336324
---
337-
ar1 = make_ar1(ρ=0.9, σ=1, y0=10)
325+
ar1 = AR1(ρ=0.9, σ=1, y0=10)
338326
339327
# Simulate
340328
initial_path = AR1_simulate_past(ar1)
@@ -346,9 +334,7 @@ plot_path(ar1, initial_path, future_path, ax)
346334
plt.show()
347335
```
348336
349-
As functions of forecast horizon, the coverage intervals have shapes like those described in
350-
https://python.quantecon.org/perm_income_cons.html
351-
337+
As functions of forecast horizon, the coverage intervals have shapes like those described in [Permanent Income II: LQ Techniques](perm_income_cons)
352338
353339
## Predictive Distributions of Path Properties
354340
@@ -644,16 +630,24 @@ def plot_Wecker(ar1: AR1, initial_path, ax, N=1000):
644630
future_path = AR1_simulate_future(ar1, y_T0, N=N)
645631
plot_path(ar1, initial_path, future_path, ax[0, 0])
646632
633+
next_reces = jnp.zeros(N)
634+
severe_rec = jnp.zeros(N)
635+
min_val_8q = jnp.zeros(N)
636+
next_up_turn = jnp.zeros(N)
637+
next_down_turn = jnp.zeros(N)
638+
647639
# Simulate future paths and compute statistics
648-
def step(carry, n):
640+
for n in range(N):
649641
future_temp = future_path[n, :]
650-
(next_reces, severe_rec, min_val_8q, next_up_turn, next_down_turn
642+
(next_reces_val, severe_rec_val, min_val_8q_val,
643+
next_up_turn_val, next_down_turn_val
651644
) = compute_path_statistics(initial_path, future_temp)
652645
653-
return carry, (next_reces, severe_rec, min_val_8q, next_up_turn, next_down_turn)
654-
655-
_, (next_reces, severe_rec, min_val_8q, next_up_turn, next_down_turn
656-
) = lax.scan(step, None, jnp.arange(N))
646+
next_reces = next_reces.at[n].set(next_reces_val)
647+
severe_rec = severe_rec.at[n].set(severe_rec_val)
648+
min_val_8q = min_val_8q.at[n].set(min_val_8q_val)
649+
next_up_turn = next_up_turn.at[n].set(next_up_turn_val)
650+
next_down_turn = next_down_turn.at[n].set(next_down_turn_val)
657651
658652
# Plot path statistics
659653
plot_path_stats(next_reces, severe_rec, min_val_8q,
@@ -695,20 +689,28 @@ def plot_extended_Wecker(
695689
σ_sample = post_samples['σ'][index]
696690
697691
# Compute path statistics
692+
next_reces = jnp.zeros(N)
693+
severe_rec = jnp.zeros(N)
694+
min_val_8q = jnp.zeros(N)
695+
next_up_turn = jnp.zeros(N)
696+
next_down_turn = jnp.zeros(N)
697+
698698
subkeys = random.split(key, num=N)
699699
700-
def step(carry, n):
701-
ar1_n = make_ar1(ρ=ρ_sample[n], σ=σ_sample[n], y0=y0, T1=T1)
700+
for n in range(N):
701+
ar1_n = AR1(ρ=ρ_sample[n], σ=σ_sample[n], y0=y0, T1=T1)
702702
future_temp = AR1_simulate_future(
703703
ar1_n, y_T0, N=1, key=subkeys[n]
704704
).reshape(-1)
705-
(next_reces, severe_rec, min_val_8q, next_up_turn, next_down_turn
705+
(next_reces_val, severe_rec_val, min_val_8q_val,
706+
next_up_turn_val, next_down_turn_val
706707
) = compute_path_statistics(initial_path, future_temp)
707-
return carry, (future_temp, next_reces, severe_rec,
708-
min_val_8q, next_up_turn, next_down_turn)
709708
710-
_, (future_path, next_reces, severe_rec, min_val_8q, next_up_turn, next_down_turn
711-
) = jax.lax.scan(step, None, jnp.arange(N))
709+
next_reces = next_reces.at[n].set(next_reces_val)
710+
severe_rec = severe_rec.at[n].set(severe_rec_val)
711+
min_val_8q = min_val_8q.at[n].set(min_val_8q_val)
712+
next_up_turn = next_up_turn.at[n].set(next_up_turn_val)
713+
next_down_turn = next_down_turn.at[n].set(next_down_turn_val)
712714
713715
# Plot simulated initial and future paths
714716
plot_path(ar1, initial_path, future_path, ax[0, 0])

0 commit comments

Comments
 (0)