Skip to content

Commit 4a9c909

Browse files
committed
updates according to feedback
1 parent 0e6e327 commit 4a9c909

File tree

1 file changed

+71
-44
lines changed

1 file changed

+71
-44
lines changed

lectures/lake_model.md

Lines changed: 71 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ jupytext:
44
extension: .md
55
format_name: myst
66
format_version: 0.13
7-
jupytext_version: 1.17.1
7+
jupytext_version: 1.17.2
88
kernelspec:
99
display_name: Python 3 (ipykernel)
1010
language: python
@@ -248,34 +248,59 @@ def rate_steady_state(model: LakeModel, tol=1e-6):
248248
return x
249249
250250
251-
@partial(jax.jit, static_argnames=['T'])
252-
def simulate_stock_path(model: LakeModel, X0, T):
251+
@partial(jax.jit, static_argnames=['update_fn', 'num_steps'])
252+
def generate_path(update_fn, initial_state, num_steps, **kwargs):
253253
"""
254-
Simulates the sequence of employment and unemployment stocks.
254+
Generate a time series by repeatedly applying an update rule.
255+
256+
Fix an update function f, initial state x_0,
257+
and a set of model parameter θ, this function computes
258+
the sequence {x_t}_{t=0}^{T-1} where:
259+
260+
x_{t+1} = f(x_t, t, θ)
261+
262+
for t = 0, 1, ..., T-1.
263+
264+
Args:
265+
update_fn: Update function f that takes
266+
(x_t, t, θ) -> x_{t+1}
267+
initial_state: Initial state x_0
268+
num_steps: Number of time steps T to simulate
269+
**kwargs: Function arguments passed to update_fn
270+
271+
Returns:
272+
Array of shape (T, dim(x)) containing the time series path
273+
[x_0, x_1, x_2, ..., x_{T-1}]
274+
"""
275+
def update_wrapper(state, t):
276+
"""
277+
Wrapper function that adapts the single-return
278+
update_fn for use with JAX scan.
279+
"""
280+
next_state = update_fn(state, t, **kwargs)
281+
return next_state, state
282+
283+
_, path = jax.lax.scan(update_wrapper,
284+
initial_state, jnp.arange(num_steps))
285+
return path
286+
287+
@jax.jit
288+
def stock_update(current_stocks, time_step, model):
289+
"""
290+
Apply transition matrix to get next period's stocks.
255291
"""
256292
A, A_hat, g = compute_matrices(model)
257-
258-
def update_X(X, _):
259-
X_new = A @ X
260-
return X_new, X
261-
262-
X0 = jnp.atleast_1d(X0)
263-
_, X_path = jax.lax.scan(update_X, X0, jnp.arange(T))
264-
return X_path
293+
next_stocks = A @ current_stocks
294+
return next_stocks
265295
266-
@partial(jax.jit, static_argnames=['T'])
267-
def simulate_rate_path(model: LakeModel, x0, T):
296+
@jax.jit
297+
def rate_update(current_rates, time_step, model):
268298
"""
269-
Simulates the sequence of employment and unemployment rates.
299+
Apply normalized transition matrix for next period's rates.
270300
"""
271301
A, A_hat, g = compute_matrices(model)
272-
273-
def update_x(x, _):
274-
x_new = A_hat @ x
275-
return x_new, x
276-
277-
_, x_path = jax.lax.scan(update_x, x0, jnp.arange(T))
278-
return x_path
302+
next_rates = A_hat @ current_rates
303+
return next_rates
279304
```
280305

281306
We create two instances, one with $α=0.013$ and another with $α=0.03$
@@ -310,7 +335,7 @@ E_0 = e_0 * N_0
310335
311336
fig, axes = plt.subplots(3, 1, figsize=(10, 8))
312337
X_0 = jnp.array([U_0, E_0])
313-
X_path = simulate_stock_path(lm, X_0, T)
338+
X_path = generate_path(stock_update, X_0, T, model=lm)
314339
315340
axes[0].plot(X_path[:, 0], lw=2)
316341
axes[0].set_title('unemployment')
@@ -352,7 +377,7 @@ xbar = rate_steady_state(lm)
352377
353378
fig, axes = plt.subplots(2, 1, figsize=(10, 8))
354379
x_0 = jnp.array([u_0, e_0])
355-
x_path = simulate_rate_path(lm, x_0, T)
380+
x_path = generate_path(rate_update, x_0, T, model=lm)
356381
357382
titles = ['unemployment rate', 'employment rate']
358383
@@ -456,18 +481,16 @@ We can investigate this by simulating the Markov chain.
456481
Let's plot the path of the sample averages over 5,000 periods
457482

458483
```{code-cell} ipython3
459-
@partial(jax.jit, static_argnames=['T'])
460-
def simulate_markov_chain(P, T: int, init_state, key):
461-
"""Simulate a Markov chain."""
462-
def step(state, key):
463-
probs = P[state]
464-
state_new = jax.random.choice(key,
465-
a=jnp.arange(len(probs)), p=probs)
466-
return state_new, state
467-
468-
keys = jax.random.split(key, T)
469-
_, states = jax.lax.scan(step, init_state, keys)
470-
return states
484+
@jax.jit
485+
def markov_update(state, t, P, keys):
486+
"""
487+
Sample next state from transition probabilities.
488+
"""
489+
probs = P[state]
490+
state_new = jax.random.choice(keys[t],
491+
a=jnp.arange(len(probs)),
492+
p=probs)
493+
return state_new
471494
472495
lm_markov = LakeModel(d=0, b=0)
473496
T = 5000 # Simulation length
@@ -480,8 +503,9 @@ P = jnp.array([[1 - λ, λ],
480503
xbar = rate_steady_state(lm_markov)
481504
482505
# Simulate the Markov chain
483-
key = jax.random.PRNGKey(42)
484-
s_path = simulate_markov_chain(P, T, 1, key)
506+
key = jax.random.PRNGKey(0)
507+
keys = jax.random.split(key, T)
508+
s_path = generate_path(markov_update, 1, T, P=P, keys=keys)
485509
486510
fig, axes = plt.subplots(2, 1, figsize=(10, 8))
487511
s_bar_e = jnp.cumsum(s_path) / jnp.arange(1, T+1)
@@ -975,8 +999,8 @@ lm_ex2 = LakeModel(λ=0.2)
975999
xbar = rate_steady_state(lm_ex2) # new steady state
9761000
9771001
# Simulate paths
978-
X_path = simulate_stock_path(lm_ex2, x0 * N0, T)
979-
x_path = simulate_rate_path(lm_ex2, x0, T)
1002+
X_path = generate_path(stock_update, x0 * N0, T, model=lm_ex2)
1003+
x_path = generate_path(rate_update, x0, T, model=lm_ex2)
9801004
print(f"New Steady State: {xbar}")
9811005
```
9821006

@@ -1068,8 +1092,8 @@ Let's increase $b$ to the new value and simulate for 20 periods
10681092
lm_high_b = LakeModel(b=b_hat)
10691093
10701094
# Simulate stocks and rates for first 20 periods
1071-
X_path1 = simulate_stock_path(lm_high_b, x0 * N0, T_hat)
1072-
x_path1 = simulate_rate_path(lm_high_b, x0, T_hat)
1095+
X_path1 = generate_path(stock_update, x0 * N0, T_hat, model=lm_high_b)
1096+
x_path1 = generate_path(rate_update, x0, T_hat, model=lm_high_b)
10731097
```
10741098

10751099
Now we reset $b$ to the original value and then, using the state
@@ -1078,15 +1102,18 @@ additional 30 periods
10781102

10791103
```{code-cell} ipython3
10801104
# Use final state from period 20 as initial condition
1081-
X_path2 = simulate_stock_path(lm_baseline, X_path1[-1, :], T-T_hat)
1082-
x_path2 = simulate_rate_path(lm_baseline, x_path1[-1, :], T-T_hat)
1105+
X_path2 = generate_path(stock_update, X_path1[-1, :], T-T_hat,
1106+
model=lm_baseline)
1107+
x_path2 = generate_path(rate_update, x_path1[-1, :], T-T_hat,
1108+
model=lm_baseline)
10831109
```
10841110

10851111
Finally, we combine these two paths and plot
10861112

10871113
```{code-cell} ipython3
10881114
# Combine paths
10891115
X_path = jnp.vstack([X_path1, X_path2[1:]])
1116+
x_path = jnp.vstack([x_path1, x_path2[1:]])
10901117
10911118
fig, axes = plt.subplots(3, 1, figsize=[10, 9])
10921119

0 commit comments

Comments
 (0)