Skip to content

Commit dad7bb1

Browse files
committed
updates according to feedback
1 parent 3ed4ac7 commit dad7bb1

File tree

1 file changed

+35
-27
lines changed

1 file changed

+35
-27
lines changed

lectures/lake_model.md

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ jupytext:
44
extension: .md
55
format_name: myst
66
format_version: 0.13
7-
jupytext_version: 1.17.2
7+
jupytext_version: 1.17.1
88
kernelspec:
99
display_name: Python 3 (ipykernel)
1010
language: python
@@ -247,7 +247,7 @@ def generate_path(f, initial_state, num_steps, **kwargs):
247247
**kwargs: Optional extra arguments passed to f
248248
249249
Returns:
250-
Array of shape (T, dim(x)) containing the time series path
250+
Array of shape (dim(x), T) containing the time series path
251251
[x_0, x_1, x_2, ..., x_{T-1}]
252252
"""
253253
@@ -260,10 +260,9 @@ def generate_path(f, initial_state, num_steps, **kwargs):
260260
261261
_, path = jax.lax.scan(update_wrapper,
262262
initial_state, jnp.arange(num_steps))
263-
return path
263+
return path.T
264264
```
265265

266-
267266
Now we can compute the matrices and simulate the dynamics.
268267

269268
```{code-cell} ipython3
@@ -331,13 +330,13 @@ fig, axes = plt.subplots(3, 1, figsize=(10, 8))
331330
X_0 = jnp.array([U_0, E_0])
332331
X_path = generate_path(stock_update, X_0, T, model=model)
333332
334-
axes[0].plot(X_path[:, 0], lw=2)
333+
axes[0].plot(X_path[0, :], lw=2)
335334
axes[0].set_title('unemployment')
336335
337-
axes[1].plot(X_path[:, 1], lw=2)
336+
axes[1].plot(X_path[1, :], lw=2)
338337
axes[1].set_title('employment')
339338
340-
axes[2].plot(X_path.sum(1), lw=2)
339+
axes[2].plot(X_path.sum(0), lw=2)
341340
axes[2].set_title('labor force')
342341
343342
plt.tight_layout()
@@ -358,16 +357,26 @@ The following function can be used to compute the steady state.
358357

359358
```{code-cell} ipython3
360359
@jax.jit
361-
def rate_steady_state(model: LakeModel, tol=1e-6):
360+
def rate_steady_state(model: LakeModel):
362361
r"""
363362
Finds the steady state of the system :math:`x_{t+1} = \hat A x_{t}`
363+
by computing the eigenvector corresponding to the unit eigenvalue.
364364
"""
365365
A, A_hat, g = compute_matrices(model)
366-
x = jnp.array([A_hat[0, 1], A_hat[1, 0]])
367-
x = x / x.sum()
368-
return x
369-
```
366+
eigenvals, eigenvec = jnp.linalg.eig(A_hat)
367+
368+
# Find the eigenvector corresponding to eigenvalue 1
369+
unit_idx = jnp.argmin(jnp.abs(eigenvals - 1.0))
370370
371+
# Get the corresponding eigenvector
372+
steady_state = jnp.real(eigenvec[:, unit_idx])
373+
374+
# Normalize to ensure positive values and sum to 1
375+
steady_state = jnp.abs(steady_state)
376+
steady_state = steady_state / jnp.sum(steady_state)
377+
378+
return steady_state
379+
```
371380

372381
We also have $x_t \to \bar x$ as $t \to \infty$ provided that the remaining
373382
eigenvalue of $\hat A$ has modulus less than 1.
@@ -392,7 +401,7 @@ x_path = generate_path(rate_update, x_0, T, model=model)
392401
titles = ['unemployment rate', 'employment rate']
393402
394403
for i, title in enumerate(titles):
395-
axes[i].plot(x_path[:, i], lw=2, alpha=0.5)
404+
axes[i].plot(x_path[i, :], lw=2, alpha=0.5)
396405
axes[i].hlines(xbar[i], 0, T, 'black', '--')
397406
axes[i].set_title(title)
398407
@@ -815,8 +824,7 @@ def compute_steady_state_quantities(c, τ,
815824
816825
# Compute steady state employment and unemployment rates
817826
model = LakeModel(α=params.α_q, λ=λ, b=params.b, d=params.d)
818-
x = rate_steady_state(model)
819-
u, e = x
827+
u, e = rate_steady_state(model)
820828
821829
# Compute steady state welfare
822830
mask = (w_vec - τ > w_bar)
@@ -1019,13 +1027,13 @@ Now plot stocks
10191027
```{code-cell} ipython3
10201028
fig, axes = plt.subplots(3, 1, figsize=[10, 9])
10211029
1022-
axes[0].plot(X_path[:, 0])
1030+
axes[0].plot(X_path[0, :])
10231031
axes[0].set_title('unemployment')
10241032
1025-
axes[1].plot(X_path[:, 1])
1033+
axes[1].plot(X_path[1, :])
10261034
axes[1].set_title('employment')
10271035
1028-
axes[2].plot(X_path.sum(1))
1036+
axes[2].plot(X_path.sum(0))
10291037
axes[2].set_title('labor force')
10301038
10311039
plt.tight_layout()
@@ -1040,7 +1048,7 @@ fig, axes = plt.subplots(2, 1, figsize=(10, 8))
10401048
titles = ['unemployment rate', 'employment rate']
10411049
10421050
for i, title in enumerate(titles):
1043-
axes[i].plot(x_path[:, i])
1051+
axes[i].plot(x_path[i, :])
10441052
axes[i].hlines(xbar[i], 0, T, 'r', '--')
10451053
axes[i].set_title(title)
10461054
@@ -1112,28 +1120,28 @@ additional 30 periods
11121120

11131121
```{code-cell} ipython3
11141122
# Use final state from period 20 as initial condition
1115-
X_path2 = generate_path(stock_update, X_path1[-1, :], T-T_hat,
1123+
X_path2 = generate_path(stock_update, X_path1[:, -1], T-T_hat,
11161124
model=model_baseline)
1117-
x_path2 = generate_path(rate_update, x_path1[-1, :], T-T_hat,
1125+
x_path2 = generate_path(rate_update, x_path1[:, -1], T-T_hat,
11181126
model=model_baseline)
11191127
```
11201128

11211129
Finally, we combine these two paths and plot
11221130

11231131
```{code-cell} ipython3
11241132
# Combine paths
1125-
X_path = jnp.vstack([X_path1, X_path2[1:]])
1126-
x_path = jnp.vstack([x_path1, x_path2[1:]])
1133+
X_path = jnp.hstack([X_path1, X_path2[:, 1:]])
1134+
x_path = jnp.hstack([x_path1, x_path2[:, 1:]])
11271135
11281136
fig, axes = plt.subplots(3, 1, figsize=[10, 9])
11291137
1130-
axes[0].plot(X_path[:, 0])
1138+
axes[0].plot(X_path[0, :])
11311139
axes[0].set_title('unemployment')
11321140
1133-
axes[1].plot(X_path[:, 1])
1141+
axes[1].plot(X_path[1, :])
11341142
axes[1].set_title('employment')
11351143
1136-
axes[2].plot(X_path.sum(1))
1144+
axes[2].plot(X_path.sum(0))
11371145
axes[2].set_title('labor force')
11381146
11391147
plt.tight_layout()
@@ -1148,7 +1156,7 @@ fig, axes = plt.subplots(2, 1, figsize=[10, 6])
11481156
titles = ['unemployment rate', 'employment rate']
11491157
11501158
for i, title in enumerate(titles):
1151-
axes[i].plot(x_path[:, i])
1159+
axes[i].plot(x_path[i, :])
11521160
axes[i].hlines(x0[i], 0, T, 'r', '--')
11531161
axes[i].set_title(title)
11541162

0 commit comments

Comments
 (0)