@@ -4,7 +4,7 @@ jupytext:
44 extension : .md
55 format_name : myst
66 format_version : 0.13
7- jupytext_version : 1.17.1
7+ jupytext_version : 1.17.2
88kernelspec :
99 display_name : Python 3 (ipykernel)
1010 language : python
@@ -248,34 +248,59 @@ def rate_steady_state(model: LakeModel, tol=1e-6):
248248 return x
249249
250250
251- @partial(jax.jit, static_argnames=['T '])
252- def simulate_stock_path(model: LakeModel, X0, T ):
251+ @partial(jax.jit, static_argnames=['update_fn', 'num_steps '])
252+ def generate_path(update_fn, initial_state, num_steps, **kwargs ):
253253 """
254- Simulates the sequence of employment and unemployment stocks.
254+ Generate a time series by repeatedly applying an update rule.
255+
256+ Fix an update function f, initial state x_0,
257+ and a set of model parameter θ, this function computes
258+ the sequence {x_t}_{t=0}^{T-1} where:
259+
260+ x_{t+1} = f(x_t, t, θ)
261+
262+ for t = 0, 1, ..., T-1.
263+
264+ Args:
265+ update_fn: Update function f that takes
266+ (x_t, t, θ) -> x_{t+1}
267+ initial_state: Initial state x_0
268+ num_steps: Number of time steps T to simulate
269+ **kwargs: Function arguments passed to update_fn
270+
271+ Returns:
272+ Array of shape (T, dim(x)) containing the time series path
273+ [x_0, x_1, x_2, ..., x_{T-1}]
274+ """
275+ def update_wrapper(state, t):
276+ """
277+ Wrapper function that adapts the single-return
278+ update_fn for use with JAX scan.
279+ """
280+ next_state = update_fn(state, t, **kwargs)
281+ return next_state, state
282+
283+ _, path = jax.lax.scan(update_wrapper,
284+ initial_state, jnp.arange(num_steps))
285+ return path
286+
287+ @jax.jit
288+ def stock_update(current_stocks, time_step, model):
289+ """
290+ Apply transition matrix to get next period's stocks.
255291 """
256292 A, A_hat, g = compute_matrices(model)
257-
258- def update_X(X, _):
259- X_new = A @ X
260- return X_new, X
261-
262- X0 = jnp.atleast_1d(X0)
263- _, X_path = jax.lax.scan(update_X, X0, jnp.arange(T))
264- return X_path
293+ next_stocks = A @ current_stocks
294+ return next_stocks
265295
266- @partial( jax.jit, static_argnames=['T'])
267- def simulate_rate_path(model: LakeModel, x0, T ):
296+ @jax.jit
297+ def rate_update(current_rates, time_step, model ):
268298 """
269- Simulates the sequence of employment and unemployment rates.
299+ Apply normalized transition matrix for next period's rates.
270300 """
271301 A, A_hat, g = compute_matrices(model)
272-
273- def update_x(x, _):
274- x_new = A_hat @ x
275- return x_new, x
276-
277- _, x_path = jax.lax.scan(update_x, x0, jnp.arange(T))
278- return x_path
302+ next_rates = A_hat @ current_rates
303+ return next_rates
279304```
280305
281306We create two instances, one with $α=0.013$ and another with $α=0.03$
@@ -310,7 +335,7 @@ E_0 = e_0 * N_0
310335
311336fig, axes = plt.subplots(3, 1, figsize=(10, 8))
312337X_0 = jnp.array([U_0, E_0])
313- X_path = simulate_stock_path(lm , X_0, T)
338+ X_path = generate_path(stock_update , X_0, T, model=lm )
314339
315340axes[0].plot(X_path[:, 0], lw=2)
316341axes[0].set_title('unemployment')
@@ -352,7 +377,7 @@ xbar = rate_steady_state(lm)
352377
353378fig, axes = plt.subplots(2, 1, figsize=(10, 8))
354379x_0 = jnp.array([u_0, e_0])
355- x_path = simulate_rate_path(lm , x_0, T)
380+ x_path = generate_path(rate_update , x_0, T, model=lm )
356381
357382titles = ['unemployment rate', 'employment rate']
358383
@@ -456,18 +481,16 @@ We can investigate this by simulating the Markov chain.
456481Let's plot the path of the sample averages over 5,000 periods
457482
458483``` {code-cell} ipython3
459- @partial(jax.jit, static_argnames=['T'])
460- def simulate_markov_chain(P, T: int, init_state, key):
461- """Simulate a Markov chain."""
462- def step(state, key):
463- probs = P[state]
464- state_new = jax.random.choice(key,
465- a=jnp.arange(len(probs)), p=probs)
466- return state_new, state
467-
468- keys = jax.random.split(key, T)
469- _, states = jax.lax.scan(step, init_state, keys)
470- return states
484+ @jax.jit
485+ def markov_update(state, t, P, keys):
486+ """
487+ Sample next state from transition probabilities.
488+ """
489+ probs = P[state]
490+ state_new = jax.random.choice(keys[t],
491+ a=jnp.arange(len(probs)),
492+ p=probs)
493+ return state_new
471494
472495lm_markov = LakeModel(d=0, b=0)
473496T = 5000 # Simulation length
@@ -480,8 +503,9 @@ P = jnp.array([[1 - λ, λ],
480503xbar = rate_steady_state(lm_markov)
481504
482505# Simulate the Markov chain
483- key = jax.random.PRNGKey(42)
484- s_path = simulate_markov_chain(P, T, 1, key)
506+ key = jax.random.PRNGKey(0)
507+ keys = jax.random.split(key, T)
508+ s_path = generate_path(markov_update, 1, T, P=P, keys=keys)
485509
486510fig, axes = plt.subplots(2, 1, figsize=(10, 8))
487511s_bar_e = jnp.cumsum(s_path) / jnp.arange(1, T+1)
@@ -975,8 +999,8 @@ lm_ex2 = LakeModel(λ=0.2)
975999xbar = rate_steady_state(lm_ex2) # new steady state
9761000
9771001# Simulate paths
978- X_path = simulate_stock_path(lm_ex2 , x0 * N0, T)
979- x_path = simulate_rate_path(lm_ex2 , x0, T)
1002+ X_path = generate_path(stock_update , x0 * N0, T, model=lm_ex2 )
1003+ x_path = generate_path(rate_update , x0, T, model=lm_ex2 )
9801004print(f"New Steady State: {xbar}")
9811005```
9821006
@@ -1068,8 +1092,8 @@ Let's increase $b$ to the new value and simulate for 20 periods
10681092lm_high_b = LakeModel(b=b_hat)
10691093
10701094# Simulate stocks and rates for first 20 periods
1071- X_path1 = simulate_stock_path(lm_high_b , x0 * N0, T_hat)
1072- x_path1 = simulate_rate_path(lm_high_b , x0, T_hat)
1095+ X_path1 = generate_path(stock_update , x0 * N0, T_hat, model=lm_high_b )
1096+ x_path1 = generate_path(rate_update , x0, T_hat, model=lm_high_b )
10731097```
10741098
10751099Now we reset $b$ to the original value and then, using the state
@@ -1078,15 +1102,18 @@ additional 30 periods
10781102
10791103``` {code-cell} ipython3
10801104# Use final state from period 20 as initial condition
1081- X_path2 = simulate_stock_path(lm_baseline, X_path1[-1, :], T-T_hat)
1082- x_path2 = simulate_rate_path(lm_baseline, x_path1[-1, :], T-T_hat)
1105+ X_path2 = generate_path(stock_update, X_path1[-1, :], T-T_hat,
1106+ model=lm_baseline)
1107+ x_path2 = generate_path(rate_update, x_path1[-1, :], T-T_hat,
1108+ model=lm_baseline)
10831109```
10841110
10851111Finally, we combine these two paths and plot
10861112
10871113``` {code-cell} ipython3
10881114# Combine paths
10891115X_path = jnp.vstack([X_path1, X_path2[1:]])
1116+ x_path = jnp.vstack([x_path1, x_path2[1:]])
10901117
10911118fig, axes = plt.subplots(3, 1, figsize=[10, 9])
10921119
0 commit comments