Skip to content

Commit 3ed4ac7

Browse files
committed
misc
1 parent 4a9c909 commit 3ed4ac7

File tree

1 file changed

+93
-83
lines changed

1 file changed

+93
-83
lines changed

lectures/lake_model.md

Lines changed: 93 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -223,66 +223,60 @@ class LakeModel(NamedTuple):
223223
d: float = 0.00822
224224
```
225225

226-
Now we can compute the matrices, simulate the dynamics, and find the steady state
226+
We will also use a specialized function to generate time series in an efficient
227+
JAX-compatible manner.
227228

228-
```{code-cell} ipython3
229-
@jax.jit
230-
def compute_matrices(model: LakeModel):
231-
"""Compute the transition matrices A and A_hat for the model."""
232-
λ, α, b, d = model.λ, model.α, model.b, model.d
233-
g = b - d
234-
A = jnp.array([[(1-d) * (1-λ) + b, (1 - d) * α + b],
235-
[ (1-d) * λ, (1 - d) * (1 - α)]])
236-
A_hat = A / (1 + g)
237-
return A, A_hat, g
238-
239-
240-
@jax.jit
241-
def rate_steady_state(model: LakeModel, tol=1e-6):
242-
r"""
243-
Finds the steady state of the system :math:`x_{t+1} = \hat A x_{t}`
244-
"""
245-
A, A_hat, g = compute_matrices(model)
246-
x = jnp.array([A_hat[0, 1], A_hat[1, 0]])
247-
x = x / x.sum()
248-
return x
229+
(Iteratively generating time series is somewhat nontrivial in JAX because arrays
230+
are immutable.)
249231

250-
251-
@partial(jax.jit, static_argnames=['update_fn', 'num_steps'])
252-
def generate_path(update_fn, initial_state, num_steps, **kwargs):
232+
```{code-cell} ipython3
233+
@partial(jax.jit, static_argnames=['f', 'num_steps'])
234+
def generate_path(f, initial_state, num_steps, **kwargs):
253235
"""
254236
Generate a time series by repeatedly applying an update rule.
255237
256-
Fix an update function f, initial state x_0,
257-
and a set of model parameter θ, this function computes
258-
the sequence {x_t}_{t=0}^{T-1} where:
238+
Given a map f, initial state x_0, and a set of model parameter θ, this
239+
function computes and returns the sequence {x_t}_{t=0}^{T-1} when
259240
260-
x_{t+1} = f(x_t, t, θ)
261-
262-
for t = 0, 1, ..., T-1.
241+
x_{t+1} = f(x_t, t, θ)
263242
264243
Args:
265-
update_fn: Update function f that takes
266-
(x_t, t, θ) -> x_{t+1}
244+
f: Update function mapping (x_t, t, θ) -> x_{t+1}
267245
initial_state: Initial state x_0
268246
num_steps: Number of time steps T to simulate
269-
**kwargs: Function arguments passed to update_fn
247+
**kwargs: Optional extra arguments passed to f
270248
271249
Returns:
272250
Array of shape (T, dim(x)) containing the time series path
273251
[x_0, x_1, x_2, ..., x_{T-1}]
274252
"""
253+
275254
def update_wrapper(state, t):
276255
"""
277-
Wrapper function that adapts the single-return
278-
update_fn for use with JAX scan.
256+
Wrapper function that adapts f for use with JAX scan.
279257
"""
280-
next_state = update_fn(state, t, **kwargs)
258+
next_state = f(state, t, **kwargs)
281259
return next_state, state
282260
283261
_, path = jax.lax.scan(update_wrapper,
284262
initial_state, jnp.arange(num_steps))
285263
return path
264+
```
265+
266+
267+
Now we can compute the matrices and simulate the dynamics.
268+
269+
```{code-cell} ipython3
270+
@jax.jit
271+
def compute_matrices(model: LakeModel):
272+
"""Compute the transition matrices A and A_hat for the model."""
273+
λ, α, b, d = model.λ, model.α, model.b, model.d
274+
g = b - d
275+
A = jnp.array([[(1-d) * (1-λ) + b, (1 - d) * α + b],
276+
[ (1-d) * λ, (1 - d) * (1 - α)]])
277+
A_hat = A / (1 + g)
278+
return A, A_hat, g
279+
286280
287281
@jax.jit
288282
def stock_update(current_stocks, time_step, model):
@@ -306,17 +300,17 @@ def rate_update(current_rates, time_step, model):
306300
We create two instances, one with $α=0.013$ and another with $α=0.03$
307301

308302
```{code-cell} ipython3
309-
lm = LakeModel()
310-
lm_new = LakeModel(α=0.03)
303+
model = LakeModel()
304+
model_new = LakeModel(α=0.03)
311305
312-
print(f"Default α: {lm.α}")
313-
A, A_hat, g = compute_matrices(lm)
306+
print(f"Default α: {model.α}")
307+
A, A_hat, g = compute_matrices(model)
314308
print(f"A matrix:\n{A}")
315309
```
316310

317311
```{code-cell} ipython3
318-
A_new, A_hat_new, g_new = compute_matrices(lm_new)
319-
print(f"New α: {lm_new.α}")
312+
A_new, A_hat_new, g_new = compute_matrices(model_new)
313+
print(f"New α: {model_new.α}")
320314
print(f"New A matrix:\n{A_new}")
321315
```
322316

@@ -335,7 +329,7 @@ E_0 = e_0 * N_0
335329
336330
fig, axes = plt.subplots(3, 1, figsize=(10, 8))
337331
X_0 = jnp.array([U_0, E_0])
338-
X_path = generate_path(stock_update, X_0, T, model=lm)
332+
X_path = generate_path(stock_update, X_0, T, model=model)
339333
340334
axes[0].plot(X_path[:, 0], lw=2)
341335
axes[0].set_title('unemployment')
@@ -360,30 +354,46 @@ there exists an $\bar x$ such that
360354

361355
This equation tells us that a steady state level $\bar x$ is an eigenvector of $\hat A$ associated with a unit eigenvalue.
362356

363-
We also have $x_t \to \bar x$ as $t \to \infty$ provided that the remaining eigenvalue of $\hat A$ has modulus less than 1.
357+
The following function can be used to compute the steady state.
358+
359+
```{code-cell} ipython3
360+
@jax.jit
361+
def rate_steady_state(model: LakeModel, tol=1e-6):
362+
r"""
363+
Finds the steady state of the system :math:`x_{t+1} = \hat A x_{t}`
364+
"""
365+
A, A_hat, g = compute_matrices(model)
366+
x = jnp.array([A_hat[0, 1], A_hat[1, 0]])
367+
x = x / x.sum()
368+
return x
369+
```
370+
371+
372+
We also have $x_t \to \bar x$ as $t \to \infty$ provided that the remaining
373+
eigenvalue of $\hat A$ has modulus less than 1.
364374

365375
This is the case for our default parameters:
366376

367377
```{code-cell} ipython3
368-
A, A_hat, g = compute_matrices(lm)
378+
A, A_hat, g = compute_matrices(model)
369379
e, f = jnp.linalg.eigvals(A_hat)
370-
print(f"Eigenvalue magnitudes: {abs(e):.4f}, {abs(f):.4f}")
380+
print(f"Eigenvalue magnitudes: {abs(e):.2f}, {abs(f):.2f}")
371381
```
372382

373-
Let's look at the convergence of the unemployment and employment rates to steady state levels (dashed red line)
383+
Let's look at the convergence of the unemployment and employment rates to steady state levels (dashed black line)
374384

375385
```{code-cell} ipython3
376-
xbar = rate_steady_state(lm)
386+
xbar = rate_steady_state(model)
377387
378388
fig, axes = plt.subplots(2, 1, figsize=(10, 8))
379389
x_0 = jnp.array([u_0, e_0])
380-
x_path = generate_path(rate_update, x_0, T, model=lm)
390+
x_path = generate_path(rate_update, x_0, T, model=model)
381391
382392
titles = ['unemployment rate', 'employment rate']
383393
384394
for i, title in enumerate(titles):
385395
axes[i].plot(x_path[:, i], lw=2, alpha=0.5)
386-
axes[i].hlines(xbar[i], 0, T, 'r', '--')
396+
axes[i].hlines(xbar[i], 0, T, 'black', '--')
387397
axes[i].set_title(title)
388398
389399
plt.tight_layout()
@@ -492,15 +502,15 @@ def markov_update(state, t, P, keys):
492502
p=probs)
493503
return state_new
494504
495-
lm_markov = LakeModel(d=0, b=0)
505+
model_markov = LakeModel(d=0, b=0)
496506
T = 5000 # Simulation length
497507
498-
α, λ = lm_markov.α, lm_markov
508+
α, λ = model_markov.α, model_markov
499509
500510
P = jnp.array([[1 - λ, λ],
501511
[ α, 1 - α]])
502512
503-
xbar = rate_steady_state(lm_markov)
513+
xbar = rate_steady_state(model_markov)
504514
505515
# Simulate the Markov chain
506516
key = jax.random.PRNGKey(0)
@@ -694,7 +704,7 @@ def create_mccall_model(α=0.2, β=0.98, γ=0.7, c=6.0, σ=2.0,
694704
695705
696706
@jax.jit
697-
def update_bellman(mcm: McCallModel, V, U):
707+
def bellman(mcm: McCallModel, V, U):
698708
"""
699709
Update the Bellman equations.
700710
"""
@@ -718,7 +728,7 @@ def solve_mccall_model(mcm: McCallModel, tol=1e-5, max_iter=2000):
718728
719729
def body_fun(state):
720730
V, U, i, error = state
721-
V_new, U_new = update_bellman(mcm, V, U)
731+
V_new, U_new = bellman(mcm, V, U)
722732
error_1 = jnp.max(jnp.abs(V_new - V))
723733
error_2 = jnp.abs(U_new - U)
724734
error_new = jnp.maximum(error_1, error_2)
@@ -804,8 +814,8 @@ def compute_steady_state_quantities(c, τ,
804814
params, w_vec, p_vec)
805815
806816
# Compute steady state employment and unemployment rates
807-
lm = LakeModel(α=params.α_q, λ=λ, b=params.b, d=params.d)
808-
x = rate_steady_state(lm)
817+
model = LakeModel(α=params.α_q, λ=λ, b=params.b, d=params.d)
818+
x = rate_steady_state(model)
809819
u, e = x
810820
811821
# Compute steady state welfare
@@ -888,7 +898,7 @@ The level that maximizes steady state welfare is approximately 62.
888898
## Exercises
889899

890900
```{exercise}
891-
:label: lm_ex1
901+
:label: model_ex1
892902
893903
In the JAX implementation of the Lake Model, we use a `NamedTuple` for parameters and separate functions for computations.
894904
@@ -902,7 +912,7 @@ In this exercise, your task is to:
902912
3. Plot how the steady-state unemployment rate varies with the job finding rate $\lambda$
903913
```
904914

905-
```{solution-start} lm_ex1
915+
```{solution-start} model_ex1
906916
:class: dropdown
907917
```
908918

@@ -927,15 +937,15 @@ ax.set_xlabel(r'$\lambda$')
927937
ax.set_ylabel('steady-state unemployment rate')
928938
plt.show()
929939
930-
lm_base = LakeModel()
931-
lm_ex1 = LakeModel(α=0.02, λ=0.3)
940+
model_base = LakeModel()
941+
model_ex1 = LakeModel(α=0.02, λ=0.3)
932942
933-
print(f"Base model α: {lm_base.α}")
934-
print(f"New model α: {lm_ex1.α}, λ: {lm_ex1.λ}")
943+
print(f"Base model α: {model_base.α}")
944+
print(f"New model α: {model_ex1.α}, λ: {model_ex1.λ}")
935945
936946
# Compute steady states for both
937-
base_steady_state = rate_steady_state(lm_base)
938-
new_steady_state = rate_steady_state(lm_ex1)
947+
base_steady_state = rate_steady_state(model_base)
948+
new_steady_state = rate_steady_state(model_ex1)
939949
940950
print(f"Base unemployment rate: {base_steady_state[0]:.4f}")
941951
print(f"New unemployment rate: {new_steady_state[0]:.4f}")
@@ -945,7 +955,7 @@ print(f"New unemployment rate: {new_steady_state[0]:.4f}")
945955
```
946956

947957
```{exercise-start}
948-
:label: lm_ex2
958+
:label: model_ex2
949959
```
950960

951961
Consider an economy with an initial stock of workers $N_0 = 100$ at the
@@ -972,16 +982,16 @@ What is the new steady state level of employment?
972982
```
973983

974984

975-
```{solution-start} lm_ex2
985+
```{solution-start} model_ex2
976986
:class: dropdown
977987
```
978988

979989
We begin by constructing the model with default parameters and finding the
980990
initial steady state
981991

982992
```{code-cell} ipython3
983-
lm_initial = LakeModel()
984-
x0 = rate_steady_state(lm_initial)
993+
model_initial = LakeModel()
994+
x0 = rate_steady_state(model_initial)
985995
print(f"Initial Steady State: {x0}")
986996
```
987997

@@ -995,12 +1005,12 @@ T = 50
9951005
New legislation changes $\lambda$ to $0.2$
9961006

9971007
```{code-cell} ipython3
998-
lm_ex2 = LakeModel(λ=0.2)
999-
xbar = rate_steady_state(lm_ex2) # new steady state
1008+
model_ex2 = LakeModel(λ=0.2)
1009+
xbar = rate_steady_state(model_ex2) # new steady state
10001010
10011011
# Simulate paths
1002-
X_path = generate_path(stock_update, x0 * N0, T, model=lm_ex2)
1003-
x_path = generate_path(rate_update, x0, T, model=lm_ex2)
1012+
X_path = generate_path(stock_update, x0 * N0, T, model=model_ex2)
1013+
x_path = generate_path(rate_update, x0, T, model=model_ex2)
10041014
print(f"New Steady State: {xbar}")
10051015
```
10061016

@@ -1046,7 +1056,7 @@ steady state levels.
10461056

10471057

10481058
```{exercise}
1049-
:label: lm_ex3
1059+
:label: model_ex3
10501060
10511061
Consider an economy with an initial stock of workers $N_0 = 100$ at the
10521062
steady state level of employment in the baseline parameterization.
@@ -1060,7 +1070,7 @@ Plot the transition dynamics for the rates.
10601070
How long does the economy take to return to its original steady state?
10611071
```
10621072

1063-
```{solution-start} lm_ex3
1073+
```{solution-start} model_ex3
10641074
:class: dropdown
10651075
```
10661076

@@ -1073,8 +1083,8 @@ Let's start off at the baseline parameterization and record the steady
10731083
state
10741084

10751085
```{code-cell} ipython3
1076-
lm_baseline = LakeModel()
1077-
x0 = rate_steady_state(lm_baseline)
1086+
model_baseline = LakeModel()
1087+
x0 = rate_steady_state(model_baseline)
10781088
N0 = 100
10791089
T = 50
10801090
```
@@ -1089,11 +1099,11 @@ T_hat = 20
10891099
Let's increase $b$ to the new value and simulate for 20 periods
10901100

10911101
```{code-cell} ipython3
1092-
lm_high_b = LakeModel(b=b_hat)
1102+
model_high_b = LakeModel(b=b_hat)
10931103
10941104
# Simulate stocks and rates for first 20 periods
1095-
X_path1 = generate_path(stock_update, x0 * N0, T_hat, model=lm_high_b)
1096-
x_path1 = generate_path(rate_update, x0, T_hat, model=lm_high_b)
1105+
X_path1 = generate_path(stock_update, x0 * N0, T_hat, model=model_high_b)
1106+
x_path1 = generate_path(rate_update, x0, T_hat, model=model_high_b)
10971107
```
10981108

10991109
Now we reset $b$ to the original value and then, using the state
@@ -1103,9 +1113,9 @@ additional 30 periods
11031113
```{code-cell} ipython3
11041114
# Use final state from period 20 as initial condition
11051115
X_path2 = generate_path(stock_update, X_path1[-1, :], T-T_hat,
1106-
model=lm_baseline)
1116+
model=model_baseline)
11071117
x_path2 = generate_path(rate_update, x_path1[-1, :], T-T_hat,
1108-
model=lm_baseline)
1118+
model=model_baseline)
11091119
```
11101120

11111121
Finally, we combine these two paths and plot

0 commit comments

Comments
 (0)