@@ -223,66 +223,60 @@ class LakeModel(NamedTuple):
223223 d: float = 0.00822
224224```
225225
226- Now we can compute the matrices, simulate the dynamics, and find the steady state
226+ We will also use a specialized function to generate time series in an efficient
227+ JAX-compatible manner.
227228
228- ``` {code-cell} ipython3
229- @jax.jit
230- def compute_matrices(model: LakeModel):
231- """Compute the transition matrices A and A_hat for the model."""
232- λ, α, b, d = model.λ, model.α, model.b, model.d
233- g = b - d
234- A = jnp.array([[(1-d) * (1-λ) + b, (1 - d) * α + b],
235- [ (1-d) * λ, (1 - d) * (1 - α)]])
236- A_hat = A / (1 + g)
237- return A, A_hat, g
238-
239-
240- @jax.jit
241- def rate_steady_state(model: LakeModel, tol=1e-6):
242- r"""
243- Finds the steady state of the system :math:`x_{t+1} = \hat A x_{t}`
244- """
245- A, A_hat, g = compute_matrices(model)
246- x = jnp.array([A_hat[0, 1], A_hat[1, 0]])
247- x = x / x.sum()
248- return x
229+ (Iteratively generating time series is somewhat nontrivial in JAX because arrays
230+ are immutable.)
249231
250-
251- @partial(jax.jit, static_argnames=['update_fn ', 'num_steps'])
252- def generate_path(update_fn , initial_state, num_steps, **kwargs):
232+ ``` {code-cell} ipython3
233+ @partial(jax.jit, static_argnames=['f ', 'num_steps'])
234+ def generate_path(f , initial_state, num_steps, **kwargs):
253235 """
254236 Generate a time series by repeatedly applying an update rule.
255237
256- Fix an update function f, initial state x_0,
257- and a set of model parameter θ, this function computes
258- the sequence {x_t}_{t=0}^{T-1} where:
238+ Given a map f, initial state x_0, and a set of model parameter θ, this
239+ function computes and returns the sequence {x_t}_{t=0}^{T-1} when
259240
260- x_{t+1} = f(x_t, t, θ)
261-
262- for t = 0, 1, ..., T-1.
241+ x_{t+1} = f(x_t, t, θ)
263242
264243 Args:
265- update_fn: Update function f that takes
266- (x_t, t, θ) -> x_{t+1}
244+ f: Update function mapping (x_t, t, θ) -> x_{t+1}
267245 initial_state: Initial state x_0
268246 num_steps: Number of time steps T to simulate
269- **kwargs: Function arguments passed to update_fn
247+ **kwargs: Optional extra arguments passed to f
270248
271249 Returns:
272250 Array of shape (T, dim(x)) containing the time series path
273251 [x_0, x_1, x_2, ..., x_{T-1}]
274252 """
253+
275254 def update_wrapper(state, t):
276255 """
277- Wrapper function that adapts the single-return
278- update_fn for use with JAX scan.
256+ Wrapper function that adapts f for use with JAX scan.
279257 """
280- next_state = update_fn (state, t, **kwargs)
258+ next_state = f (state, t, **kwargs)
281259 return next_state, state
282260
283261 _, path = jax.lax.scan(update_wrapper,
284262 initial_state, jnp.arange(num_steps))
285263 return path
264+ ```
265+
266+
267+ Now we can compute the matrices and simulate the dynamics.
268+
269+ ``` {code-cell} ipython3
270+ @jax.jit
271+ def compute_matrices(model: LakeModel):
272+ """Compute the transition matrices A and A_hat for the model."""
273+ λ, α, b, d = model.λ, model.α, model.b, model.d
274+ g = b - d
275+ A = jnp.array([[(1-d) * (1-λ) + b, (1 - d) * α + b],
276+ [ (1-d) * λ, (1 - d) * (1 - α)]])
277+ A_hat = A / (1 + g)
278+ return A, A_hat, g
279+
286280
287281@jax.jit
288282def stock_update(current_stocks, time_step, model):
@@ -306,17 +300,17 @@ def rate_update(current_rates, time_step, model):
306300We create two instances, one with $α=0.013$ and another with $α=0.03$
307301
308302``` {code-cell} ipython3
309- lm = LakeModel()
310- lm_new = LakeModel(α=0.03)
303+ model = LakeModel()
304+ model_new = LakeModel(α=0.03)
311305
312- print(f"Default α: {lm .α}")
313- A, A_hat, g = compute_matrices(lm )
306+ print(f"Default α: {model .α}")
307+ A, A_hat, g = compute_matrices(model )
314308print(f"A matrix:\n{A}")
315309```
316310
317311``` {code-cell} ipython3
318- A_new, A_hat_new, g_new = compute_matrices(lm_new )
319- print(f"New α: {lm_new .α}")
312+ A_new, A_hat_new, g_new = compute_matrices(model_new )
313+ print(f"New α: {model_new .α}")
320314print(f"New A matrix:\n{A_new}")
321315```
322316
@@ -335,7 +329,7 @@ E_0 = e_0 * N_0
335329
336330fig, axes = plt.subplots(3, 1, figsize=(10, 8))
337331X_0 = jnp.array([U_0, E_0])
338- X_path = generate_path(stock_update, X_0, T, model=lm )
332+ X_path = generate_path(stock_update, X_0, T, model=model )
339333
340334axes[0].plot(X_path[:, 0], lw=2)
341335axes[0].set_title('unemployment')
@@ -360,30 +354,46 @@ there exists an $\bar x$ such that
360354
361355This equation tells us that a steady state level $\bar x$ is an eigenvector of $\hat A$ associated with a unit eigenvalue.
362356
363- We also have $x_t \to \bar x$ as $t \to \infty$ provided that the remaining eigenvalue of $\hat A$ has modulus less than 1.
357+ The following function can be used to compute the steady state.
358+
359+ ``` {code-cell} ipython3
360+ @jax.jit
361+ def rate_steady_state(model: LakeModel, tol=1e-6):
362+ r"""
363+ Finds the steady state of the system :math:`x_{t+1} = \hat A x_{t}`
364+ """
365+ A, A_hat, g = compute_matrices(model)
366+ x = jnp.array([A_hat[0, 1], A_hat[1, 0]])
367+ x = x / x.sum()
368+ return x
369+ ```
370+
371+
372+ We also have $x_t \to \bar x$ as $t \to \infty$ provided that the remaining
373+ eigenvalue of $\hat A$ has modulus less than 1.
364374
365375This is the case for our default parameters:
366376
367377``` {code-cell} ipython3
368- A, A_hat, g = compute_matrices(lm )
378+ A, A_hat, g = compute_matrices(model )
369379e, f = jnp.linalg.eigvals(A_hat)
370- print(f"Eigenvalue magnitudes: {abs(e):.4f }, {abs(f):.4f }")
380+ print(f"Eigenvalue magnitudes: {abs(e):.2f }, {abs(f):.2f }")
371381```
372382
373- Let's look at the convergence of the unemployment and employment rates to steady state levels (dashed red line)
383+ Let's look at the convergence of the unemployment and employment rates to steady state levels (dashed black line)
374384
375385``` {code-cell} ipython3
376- xbar = rate_steady_state(lm )
386+ xbar = rate_steady_state(model )
377387
378388fig, axes = plt.subplots(2, 1, figsize=(10, 8))
379389x_0 = jnp.array([u_0, e_0])
380- x_path = generate_path(rate_update, x_0, T, model=lm )
390+ x_path = generate_path(rate_update, x_0, T, model=model )
381391
382392titles = ['unemployment rate', 'employment rate']
383393
384394for i, title in enumerate(titles):
385395 axes[i].plot(x_path[:, i], lw=2, alpha=0.5)
386- axes[i].hlines(xbar[i], 0, T, 'r ', '--')
396+ axes[i].hlines(xbar[i], 0, T, 'black ', '--')
387397 axes[i].set_title(title)
388398
389399plt.tight_layout()
@@ -492,15 +502,15 @@ def markov_update(state, t, P, keys):
492502 p=probs)
493503 return state_new
494504
495- lm_markov = LakeModel(d=0, b=0)
505+ model_markov = LakeModel(d=0, b=0)
496506T = 5000 # Simulation length
497507
498- α, λ = lm_markov .α, lm_markov .λ
508+ α, λ = model_markov .α, model_markov .λ
499509
500510P = jnp.array([[1 - λ, λ],
501511 [ α, 1 - α]])
502512
503- xbar = rate_steady_state(lm_markov )
513+ xbar = rate_steady_state(model_markov )
504514
505515# Simulate the Markov chain
506516key = jax.random.PRNGKey(0)
@@ -694,7 +704,7 @@ def create_mccall_model(α=0.2, β=0.98, γ=0.7, c=6.0, σ=2.0,
694704
695705
696706@jax.jit
697- def update_bellman (mcm: McCallModel, V, U):
707+ def bellman (mcm: McCallModel, V, U):
698708 """
699709 Update the Bellman equations.
700710 """
@@ -718,7 +728,7 @@ def solve_mccall_model(mcm: McCallModel, tol=1e-5, max_iter=2000):
718728
719729 def body_fun(state):
720730 V, U, i, error = state
721- V_new, U_new = update_bellman (mcm, V, U)
731+ V_new, U_new = bellman (mcm, V, U)
722732 error_1 = jnp.max(jnp.abs(V_new - V))
723733 error_2 = jnp.abs(U_new - U)
724734 error_new = jnp.maximum(error_1, error_2)
@@ -804,8 +814,8 @@ def compute_steady_state_quantities(c, τ,
804814 params, w_vec, p_vec)
805815
806816 # Compute steady state employment and unemployment rates
807- lm = LakeModel(α=params.α_q, λ=λ, b=params.b, d=params.d)
808- x = rate_steady_state(lm )
817+ model = LakeModel(α=params.α_q, λ=λ, b=params.b, d=params.d)
818+ x = rate_steady_state(model )
809819 u, e = x
810820
811821 # Compute steady state welfare
@@ -888,7 +898,7 @@ The level that maximizes steady state welfare is approximately 62.
888898## Exercises
889899
890900``` {exercise}
891- :label: lm_ex1
901+ :label: model_ex1
892902
893903In the JAX implementation of the Lake Model, we use a `NamedTuple` for parameters and separate functions for computations.
894904
@@ -902,7 +912,7 @@ In this exercise, your task is to:
9029123. Plot how the steady-state unemployment rate varies with the job finding rate $\lambda$
903913```
904914
905- ``` {solution-start} lm_ex1
915+ ``` {solution-start} model_ex1
906916:class: dropdown
907917```
908918
@@ -927,15 +937,15 @@ ax.set_xlabel(r'$\lambda$')
927937ax.set_ylabel('steady-state unemployment rate')
928938plt.show()
929939
930- lm_base = LakeModel()
931- lm_ex1 = LakeModel(α=0.02, λ=0.3)
940+ model_base = LakeModel()
941+ model_ex1 = LakeModel(α=0.02, λ=0.3)
932942
933- print(f"Base model α: {lm_base .α}")
934- print(f"New model α: {lm_ex1 .α}, λ: {lm_ex1 .λ}")
943+ print(f"Base model α: {model_base .α}")
944+ print(f"New model α: {model_ex1 .α}, λ: {model_ex1 .λ}")
935945
936946# Compute steady states for both
937- base_steady_state = rate_steady_state(lm_base )
938- new_steady_state = rate_steady_state(lm_ex1 )
947+ base_steady_state = rate_steady_state(model_base )
948+ new_steady_state = rate_steady_state(model_ex1 )
939949
940950print(f"Base unemployment rate: {base_steady_state[0]:.4f}")
941951print(f"New unemployment rate: {new_steady_state[0]:.4f}")
@@ -945,7 +955,7 @@ print(f"New unemployment rate: {new_steady_state[0]:.4f}")
945955```
946956
947957``` {exercise-start}
948- :label: lm_ex2
958+ :label: model_ex2
949959```
950960
951961Consider an economy with an initial stock of workers $N_0 = 100$ at the
@@ -972,16 +982,16 @@ What is the new steady state level of employment?
972982```
973983
974984
975- ``` {solution-start} lm_ex2
985+ ``` {solution-start} model_ex2
976986:class: dropdown
977987```
978988
979989We begin by constructing the model with default parameters and finding the
980990initial steady state
981991
982992``` {code-cell} ipython3
983- lm_initial = LakeModel()
984- x0 = rate_steady_state(lm_initial )
993+ model_initial = LakeModel()
994+ x0 = rate_steady_state(model_initial )
985995print(f"Initial Steady State: {x0}")
986996```
987997
@@ -995,12 +1005,12 @@ T = 50
9951005New legislation changes $\lambda$ to $0.2$
9961006
9971007``` {code-cell} ipython3
998- lm_ex2 = LakeModel(λ=0.2)
999- xbar = rate_steady_state(lm_ex2 ) # new steady state
1008+ model_ex2 = LakeModel(λ=0.2)
1009+ xbar = rate_steady_state(model_ex2 ) # new steady state
10001010
10011011# Simulate paths
1002- X_path = generate_path(stock_update, x0 * N0, T, model=lm_ex2 )
1003- x_path = generate_path(rate_update, x0, T, model=lm_ex2 )
1012+ X_path = generate_path(stock_update, x0 * N0, T, model=model_ex2 )
1013+ x_path = generate_path(rate_update, x0, T, model=model_ex2 )
10041014print(f"New Steady State: {xbar}")
10051015```
10061016
@@ -1046,7 +1056,7 @@ steady state levels.
10461056
10471057
10481058``` {exercise}
1049- :label: lm_ex3
1059+ :label: model_ex3
10501060
10511061Consider an economy with an initial stock of workers $N_0 = 100$ at the
10521062steady state level of employment in the baseline parameterization.
@@ -1060,7 +1070,7 @@ Plot the transition dynamics for the rates.
10601070How long does the economy take to return to its original steady state?
10611071```
10621072
1063- ``` {solution-start} lm_ex3
1073+ ``` {solution-start} model_ex3
10641074:class: dropdown
10651075```
10661076
@@ -1073,8 +1083,8 @@ Let's start off at the baseline parameterization and record the steady
10731083state
10741084
10751085``` {code-cell} ipython3
1076- lm_baseline = LakeModel()
1077- x0 = rate_steady_state(lm_baseline )
1086+ model_baseline = LakeModel()
1087+ x0 = rate_steady_state(model_baseline )
10781088N0 = 100
10791089T = 50
10801090```
@@ -1089,11 +1099,11 @@ T_hat = 20
10891099Let's increase $b$ to the new value and simulate for 20 periods
10901100
10911101``` {code-cell} ipython3
1092- lm_high_b = LakeModel(b=b_hat)
1102+ model_high_b = LakeModel(b=b_hat)
10931103
10941104# Simulate stocks and rates for first 20 periods
1095- X_path1 = generate_path(stock_update, x0 * N0, T_hat, model=lm_high_b )
1096- x_path1 = generate_path(rate_update, x0, T_hat, model=lm_high_b )
1105+ X_path1 = generate_path(stock_update, x0 * N0, T_hat, model=model_high_b )
1106+ x_path1 = generate_path(rate_update, x0, T_hat, model=model_high_b )
10971107```
10981108
10991109Now we reset $b$ to the original value and then, using the state
@@ -1103,9 +1113,9 @@ additional 30 periods
11031113``` {code-cell} ipython3
11041114# Use final state from period 20 as initial condition
11051115X_path2 = generate_path(stock_update, X_path1[-1, :], T-T_hat,
1106- model=lm_baseline )
1116+ model=model_baseline )
11071117x_path2 = generate_path(rate_update, x_path1[-1, :], T-T_hat,
1108- model=lm_baseline )
1118+ model=model_baseline )
11091119```
11101120
11111121Finally, we combine these two paths and plot
0 commit comments