@@ -108,13 +108,6 @@ We want to derive the dynamics of the following aggregates:
108108* $U_t$, the total number of unemployed workers at $t$
109109* $N_t$, the number of workers in the labor force at $t$
110110
111- We also want to know the values of the following objects:
112-
113- * The employment rate $e_t := E_t/N_t$.
114- * The unemployment rate $u_t := U_t/N_t$.
115-
116- (Here and below, capital letters represent aggregates and lowercase letters represent rates)
117-
118111### Laws of motion for stock variables
119112
120113We begin by constructing laws of motion for the aggregate variables $E_t,U_t, N_t$.
@@ -167,6 +160,13 @@ This law tells us how total employment and unemployment evolve over time.
167160
168161Now let's derive the law of motion for rates.
169162
163+ We want to track the values of the following objects:
164+
165+ * The employment rate $e_t := E_t/N_t$.
166+ * The unemployment rate $u_t := U_t/N_t$.
167+
168+ (Here and below, capital letters represent aggregates and lowercase letters represent rates)
169+
170170To get these we can divide both sides of $X_ {t+1} = A X_t$ by $N_ {t+1}$ to get
171171
172172$$
@@ -266,7 +266,6 @@ def generate_path(f, initial_state, num_steps, **kwargs):
266266Now we can compute the matrices and simulate the dynamics.
267267
268268``` {code-cell} ipython3
269- @jax.jit
270269def compute_matrices(model: LakeModel):
271270 """Compute the transition matrices A and A_hat for the model."""
272271 λ, α, b, d = model.λ, model.α, model.b, model.d
@@ -277,7 +276,6 @@ def compute_matrices(model: LakeModel):
277276 return A, A_hat, g
278277
279278
280- @jax.jit
281279def stock_update(current_stocks, time_step, model):
282280 """
283281 Apply transition matrix to get next period's stocks.
@@ -286,7 +284,6 @@ def stock_update(current_stocks, time_step, model):
286284 next_stocks = A @ current_stocks
287285 return next_stocks
288286
289- @jax.jit
290287def rate_update(current_rates, time_step, model):
291288 """
292289 Apply normalized transition matrix for next period's rates.
@@ -330,14 +327,12 @@ fig, axes = plt.subplots(3, 1, figsize=(10, 8))
330327X_0 = jnp.array([U_0, E_0])
331328X_path = generate_path(stock_update, X_0, T, model=model)
332329
333- axes[0].plot(X_path[0, :], lw=2)
334- axes[0].set_title('unemployment')
335-
336- axes[1].plot(X_path[1, :], lw=2)
337- axes[1].set_title('employment')
330+ titles = ['unemployment', 'employment', 'labor force']
331+ data = [X_path[0, :], X_path[1, :], X_path.sum(0)]
338332
339- axes[2].plot(X_path.sum(0), lw=2)
340- axes[2].set_title('labor force')
333+ for ax, title, series in zip(axes, titles, data):
334+ ax.plot(series, lw=2)
335+ ax.set_title(title)
341336
342337plt.tight_layout()
343338plt.show()
@@ -409,6 +404,41 @@ plt.tight_layout()
409404plt.show()
410405```
411406
407+ ``` {exercise}
408+ :label: model_ex1
409+
410+ Use JAX's `vmap` to compute steady-state unemployment rates for a range of job finding rates $\lambda$ (from 0.1 to 0.5), and plot the relationship.
411+ ```
412+
413+ ``` {solution-start} model_ex1
414+ :class: dropdown
415+ ```
416+
417+ Here is one solution
418+
419+ ``` {code-cell} ipython3
420+ @jax.jit
421+ def compute_unemployment_rate(λ_val):
422+ """Computes steady-state unemployment for a given λ"""
423+ model = LakeModel(λ=λ_val)
424+ steady_state = rate_steady_state(model)
425+ return steady_state[0]
426+
427+ # Use vmap to compute for multiple λ values
428+ λ_values = jnp.linspace(0.1, 0.5, 50)
429+ unemployment_rates = jax.vmap(compute_unemployment_rate)(λ_values)
430+
431+ # Plot the results
432+ fig, ax = plt.subplots(figsize=(10, 6))
433+ ax.plot(λ_values, unemployment_rates, lw=2)
434+ ax.set_xlabel(r'$\lambda$')
435+ ax.set_ylabel('steady-state unemployment rate')
436+ plt.show()
437+ ```
438+
439+ ``` {solution-end}
440+ ```
441+
412442(dynamics_workers)=
413443## Dynamics of an individual worker
414444
@@ -500,7 +530,6 @@ We can investigate this by simulating the Markov chain.
500530Let's plot the path of the sample averages over 5,000 periods
501531
502532``` {code-cell} ipython3
503- @jax.jit
504533def markov_update(state, t, P, keys):
505534 """
506535 Sample next state from transition probabilities.
@@ -535,14 +564,14 @@ titles = ['percent of time unemployed', 'percent of time employed']
535564
536565for i, plot in enumerate(to_plot):
537566 axes[i].plot(plot, lw=2, alpha=0.5)
538- axes[i].hlines(xbar[i], 0, T, 'r', '--')
567+ axes[i].hlines(xbar[i], 0, T, linestyles= '--')
539568 axes[i].set_title(titles[i])
540569
541570plt.tight_layout()
542571plt.show()
543572```
544573
545- The stationary probabilities are given by the dashed red line.
574+ The stationary probabilities are given by the dashed line.
546575
547576In this case it takes much of the sample for these two objects to converge.
548577
@@ -905,63 +934,6 @@ The level that maximizes steady state welfare is approximately 62.
905934
906935## Exercises
907936
908- ``` {exercise}
909- :label: model_ex1
910-
911- In the JAX implementation of the Lake Model, we use a `NamedTuple` for parameters and separate functions for computations.
912-
913- This approach has several advantages:
914- 1. It's immutable, which aligns with JAX's functional programming paradigm
915- 2. Functions can be JIT-compiled for better performance
916-
917- In this exercise, your task is to:
918- 1. Update parameters by creating a new instance of the model with the parameters (`α=0.02, λ=0.3`).
919- 2. Use JAX's `vmap` to compute steady states for different parameter values
920- 3. Plot how the steady-state unemployment rate varies with the job finding rate $\lambda$
921- ```
922-
923- ``` {solution-start} model_ex1
924- :class: dropdown
925- ```
926-
927- Here is one solution
928-
929- ``` {code-cell} ipython3
930- @jax.jit
931- def compute_unemployment_rate(λ_val):
932- """Computes steady-state unemployment for a given λ"""
933- model = LakeModel(λ=λ_val)
934- steady_state = rate_steady_state(model)
935- return steady_state[0]
936-
937- # Use vmap to compute for multiple λ values
938- λ_values = jnp.linspace(0.1, 0.5, 50)
939- unemployment_rates = jax.vmap(compute_unemployment_rate)(λ_values)
940-
941- # Plot the results
942- fig, ax = plt.subplots(figsize=(10, 6))
943- ax.plot(λ_values, unemployment_rates, lw=2)
944- ax.set_xlabel(r'$\lambda$')
945- ax.set_ylabel('steady-state unemployment rate')
946- plt.show()
947-
948- model_base = LakeModel()
949- model_ex1 = LakeModel(α=0.02, λ=0.3)
950-
951- print(f"Base model α: {model_base.α}")
952- print(f"New model α: {model_ex1.α}, λ: {model_ex1.λ}")
953-
954- # Compute steady states for both
955- base_steady_state = rate_steady_state(model_base)
956- new_steady_state = rate_steady_state(model_ex1)
957-
958- print(f"Base unemployment rate: {base_steady_state[0]:.4f}")
959- print(f"New unemployment rate: {new_steady_state[0]:.4f}")
960- ```
961-
962- ``` {solution-end}
963- ```
964-
965937``` {exercise-start}
966938:label: model_ex2
967939```
@@ -1049,7 +1021,7 @@ titles = ['unemployment rate', 'employment rate']
10491021
10501022for i, title in enumerate(titles):
10511023 axes[i].plot(x_path[i, :])
1052- axes[i].hlines(xbar[i], 0, T, 'r', '--')
1024+ axes[i].hlines(xbar[i], 0, T, linestyles= '--')
10531025 axes[i].set_title(title)
10541026
10551027plt.tight_layout()
@@ -1157,7 +1129,7 @@ titles = ['unemployment rate', 'employment rate']
11571129
11581130for i, title in enumerate(titles):
11591131 axes[i].plot(x_path[i, :])
1160- axes[i].hlines(x0[i], 0, T, 'r', '--')
1132+ axes[i].hlines(x0[i], 0, T, linestyles= '--')
11611133 axes[i].set_title(title)
11621134
11631135plt.tight_layout()
0 commit comments