@@ -4,7 +4,7 @@ jupytext:
44 extension : .md
55 format_name : myst
66 format_version : 0.13
7- jupytext_version : 1.17.2
7+ jupytext_version : 1.17.1
88kernelspec :
99 display_name : Python 3 (ipykernel)
1010 language : python
@@ -247,7 +247,7 @@ def generate_path(f, initial_state, num_steps, **kwargs):
247247 **kwargs: Optional extra arguments passed to f
248248
249249 Returns:
250- Array of shape (T, dim(x)) containing the time series path
250+ Array of shape (dim(x), T ) containing the time series path
251251 [x_0, x_1, x_2, ..., x_{T-1}]
252252 """
253253
@@ -260,10 +260,9 @@ def generate_path(f, initial_state, num_steps, **kwargs):
260260
261261 _, path = jax.lax.scan(update_wrapper,
262262 initial_state, jnp.arange(num_steps))
263- return path
263+ return path.T
264264```
265265
266-
267266Now we can compute the matrices and simulate the dynamics.
268267
269268``` {code-cell} ipython3
@@ -331,13 +330,13 @@ fig, axes = plt.subplots(3, 1, figsize=(10, 8))
331330X_0 = jnp.array([U_0, E_0])
332331X_path = generate_path(stock_update, X_0, T, model=model)
333332
334- axes[0].plot(X_path[:, 0 ], lw=2)
333+ axes[0].plot(X_path[0, : ], lw=2)
335334axes[0].set_title('unemployment')
336335
337- axes[1].plot(X_path[:, 1 ], lw=2)
336+ axes[1].plot(X_path[1, : ], lw=2)
338337axes[1].set_title('employment')
339338
340- axes[2].plot(X_path.sum(1 ), lw=2)
339+ axes[2].plot(X_path.sum(0 ), lw=2)
341340axes[2].set_title('labor force')
342341
343342plt.tight_layout()
@@ -358,16 +357,26 @@ The following function can be used to compute the steady state.
358357
359358``` {code-cell} ipython3
360359@jax.jit
361- def rate_steady_state(model: LakeModel, tol=1e-6 ):
360+ def rate_steady_state(model: LakeModel):
362361 r"""
363362 Finds the steady state of the system :math:`x_{t+1} = \hat A x_{t}`
363+ by computing the eigenvector corresponding to the unit eigenvalue.
364364 """
365365 A, A_hat, g = compute_matrices(model)
366- x = jnp.array([ A_hat[0, 1], A_hat[1, 0]] )
367- x = x / x.sum()
368- return x
369- ```
366+ eigenvals, eigenvec = jnp.linalg.eig( A_hat)
367+
368+ # Find the eigenvector corresponding to eigenvalue 1
369+ unit_idx = jnp.argmin(jnp.abs(eigenvals - 1.0))
370370
371+ # Get the corresponding eigenvector
372+ steady_state = jnp.real(eigenvec[:, unit_idx])
373+
374+ # Normalize to ensure positive values and sum to 1
375+ steady_state = jnp.abs(steady_state)
376+ steady_state = steady_state / jnp.sum(steady_state)
377+
378+ return steady_state
379+ ```
371380
372381We also have $x_t \to \bar x$ as $t \to \infty$ provided that the remaining
373382eigenvalue of $\hat A$ has modulus less than 1.
@@ -392,7 +401,7 @@ x_path = generate_path(rate_update, x_0, T, model=model)
392401titles = ['unemployment rate', 'employment rate']
393402
394403for i, title in enumerate(titles):
395- axes[i].plot(x_path[:, i ], lw=2, alpha=0.5)
404+ axes[i].plot(x_path[i, : ], lw=2, alpha=0.5)
396405 axes[i].hlines(xbar[i], 0, T, 'black', '--')
397406 axes[i].set_title(title)
398407
@@ -815,8 +824,7 @@ def compute_steady_state_quantities(c, τ,
815824
816825 # Compute steady state employment and unemployment rates
817826 model = LakeModel(α=params.α_q, λ=λ, b=params.b, d=params.d)
818- x = rate_steady_state(model)
819- u, e = x
827+ u, e = rate_steady_state(model)
820828
821829 # Compute steady state welfare
822830 mask = (w_vec - τ > w_bar)
@@ -1019,13 +1027,13 @@ Now plot stocks
10191027``` {code-cell} ipython3
10201028fig, axes = plt.subplots(3, 1, figsize=[10, 9])
10211029
1022- axes[0].plot(X_path[:, 0 ])
1030+ axes[0].plot(X_path[0, : ])
10231031axes[0].set_title('unemployment')
10241032
1025- axes[1].plot(X_path[:, 1 ])
1033+ axes[1].plot(X_path[1, : ])
10261034axes[1].set_title('employment')
10271035
1028- axes[2].plot(X_path.sum(1 ))
1036+ axes[2].plot(X_path.sum(0 ))
10291037axes[2].set_title('labor force')
10301038
10311039plt.tight_layout()
@@ -1040,7 +1048,7 @@ fig, axes = plt.subplots(2, 1, figsize=(10, 8))
10401048titles = ['unemployment rate', 'employment rate']
10411049
10421050for i, title in enumerate(titles):
1043- axes[i].plot(x_path[:, i ])
1051+ axes[i].plot(x_path[i, : ])
10441052 axes[i].hlines(xbar[i], 0, T, 'r', '--')
10451053 axes[i].set_title(title)
10461054
@@ -1112,28 +1120,28 @@ additional 30 periods
11121120
11131121``` {code-cell} ipython3
11141122# Use final state from period 20 as initial condition
1115- X_path2 = generate_path(stock_update, X_path1[-1, : ], T-T_hat,
1123+ X_path2 = generate_path(stock_update, X_path1[:, -1 ], T-T_hat,
11161124 model=model_baseline)
1117- x_path2 = generate_path(rate_update, x_path1[-1, : ], T-T_hat,
1125+ x_path2 = generate_path(rate_update, x_path1[:, -1 ], T-T_hat,
11181126 model=model_baseline)
11191127```
11201128
11211129Finally, we combine these two paths and plot
11221130
11231131``` {code-cell} ipython3
11241132# Combine paths
1125- X_path = jnp.vstack ([X_path1, X_path2[1:]])
1126- x_path = jnp.vstack ([x_path1, x_path2[1:]])
1133+ X_path = jnp.hstack ([X_path1, X_path2[:, 1:]])
1134+ x_path = jnp.hstack ([x_path1, x_path2[:, 1:]])
11271135
11281136fig, axes = plt.subplots(3, 1, figsize=[10, 9])
11291137
1130- axes[0].plot(X_path[:, 0 ])
1138+ axes[0].plot(X_path[0, : ])
11311139axes[0].set_title('unemployment')
11321140
1133- axes[1].plot(X_path[:, 1 ])
1141+ axes[1].plot(X_path[1, : ])
11341142axes[1].set_title('employment')
11351143
1136- axes[2].plot(X_path.sum(1 ))
1144+ axes[2].plot(X_path.sum(0 ))
11371145axes[2].set_title('labor force')
11381146
11391147plt.tight_layout()
@@ -1148,7 +1156,7 @@ fig, axes = plt.subplots(2, 1, figsize=[10, 6])
11481156titles = ['unemployment rate', 'employment rate']
11491157
11501158for i, title in enumerate(titles):
1151- axes[i].plot(x_path[:, i ])
1159+ axes[i].plot(x_path[i, : ])
11521160 axes[i].hlines(x0[i], 0, T, 'r', '--')
11531161 axes[i].set_title(title)
11541162
0 commit comments