Skip to content

Commit 27bb29b

Browse files
jstacclaude
andcommitted
Refine lake_model.md: Remove redundant decorators and improve organization
This commit improves the Lake Model lecture with several refinements: **Code improvements:** - Remove redundant `@jax.jit` decorators from `compute_matrices`, `stock_update`, `rate_update`, and `markov_update` (these functions are only called from within jitted functions, so the decorators are unnecessary and can inhibit compiler optimization) - Refactor aggregate dynamics plot to use a for loop instead of repetitive code - Remove hardcoded colors ('r') from plots to use matplotlib's default color cycle **Content organization:** - Move rate definitions ($e_t$, $u_t$) to "Laws of motion for rates" section where they logically belong - Relocate Exercise 1 to appear immediately before "Dynamics of an individual worker" section for better flow - Simplify Exercise 1 to focus on the pedagogically interesting `vmap` usage, removing less interesting parameter comparison parts These changes improve code clarity, performance, and pedagogical flow without changing functionality. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 9992ef5 commit 27bb29b

File tree

1 file changed

+51
-79
lines changed

1 file changed

+51
-79
lines changed

lectures/lake_model.md

Lines changed: 51 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,6 @@ We want to derive the dynamics of the following aggregates:
108108
* $U_t$, the total number of unemployed workers at $t$
109109
* $N_t$, the number of workers in the labor force at $t$
110110

111-
We also want to know the values of the following objects:
112-
113-
* The employment rate $e_t := E_t/N_t$.
114-
* The unemployment rate $u_t := U_t/N_t$.
115-
116-
(Here and below, capital letters represent aggregates and lowercase letters represent rates)
117-
118111
### Laws of motion for stock variables
119112

120113
We begin by constructing laws of motion for the aggregate variables $E_t,U_t, N_t$.
@@ -167,6 +160,13 @@ This law tells us how total employment and unemployment evolve over time.
167160

168161
Now let's derive the law of motion for rates.
169162

163+
We want to track the values of the following objects:
164+
165+
* The employment rate $e_t := E_t/N_t$.
166+
* The unemployment rate $u_t := U_t/N_t$.
167+
168+
(Here and below, capital letters represent aggregates and lowercase letters represent rates)
169+
170170
To get these we can divide both sides of $X_{t+1} = A X_t$ by $N_{t+1}$ to get
171171

172172
$$
@@ -266,7 +266,6 @@ def generate_path(f, initial_state, num_steps, **kwargs):
266266
Now we can compute the matrices and simulate the dynamics.
267267

268268
```{code-cell} ipython3
269-
@jax.jit
270269
def compute_matrices(model: LakeModel):
271270
"""Compute the transition matrices A and A_hat for the model."""
272271
λ, α, b, d = model.λ, model.α, model.b, model.d
@@ -277,7 +276,6 @@ def compute_matrices(model: LakeModel):
277276
return A, A_hat, g
278277
279278
280-
@jax.jit
281279
def stock_update(current_stocks, time_step, model):
282280
"""
283281
Apply transition matrix to get next period's stocks.
@@ -286,7 +284,6 @@ def stock_update(current_stocks, time_step, model):
286284
next_stocks = A @ current_stocks
287285
return next_stocks
288286
289-
@jax.jit
290287
def rate_update(current_rates, time_step, model):
291288
"""
292289
Apply normalized transition matrix for next period's rates.
@@ -330,14 +327,12 @@ fig, axes = plt.subplots(3, 1, figsize=(10, 8))
330327
X_0 = jnp.array([U_0, E_0])
331328
X_path = generate_path(stock_update, X_0, T, model=model)
332329
333-
axes[0].plot(X_path[0, :], lw=2)
334-
axes[0].set_title('unemployment')
335-
336-
axes[1].plot(X_path[1, :], lw=2)
337-
axes[1].set_title('employment')
330+
titles = ['unemployment', 'employment', 'labor force']
331+
data = [X_path[0, :], X_path[1, :], X_path.sum(0)]
338332
339-
axes[2].plot(X_path.sum(0), lw=2)
340-
axes[2].set_title('labor force')
333+
for ax, title, series in zip(axes, titles, data):
334+
ax.plot(series, lw=2)
335+
ax.set_title(title)
341336
342337
plt.tight_layout()
343338
plt.show()
@@ -409,6 +404,41 @@ plt.tight_layout()
409404
plt.show()
410405
```
411406

407+
```{exercise}
408+
:label: model_ex1
409+
410+
Use JAX's `vmap` to compute steady-state unemployment rates for a range of job finding rates $\lambda$ (from 0.1 to 0.5), and plot the relationship.
411+
```
412+
413+
```{solution-start} model_ex1
414+
:class: dropdown
415+
```
416+
417+
Here is one solution
418+
419+
```{code-cell} ipython3
420+
@jax.jit
421+
def compute_unemployment_rate(λ_val):
422+
"""Computes steady-state unemployment for a given λ"""
423+
model = LakeModel(λ=λ_val)
424+
steady_state = rate_steady_state(model)
425+
return steady_state[0]
426+
427+
# Use vmap to compute for multiple λ values
428+
λ_values = jnp.linspace(0.1, 0.5, 50)
429+
unemployment_rates = jax.vmap(compute_unemployment_rate)(λ_values)
430+
431+
# Plot the results
432+
fig, ax = plt.subplots(figsize=(10, 6))
433+
ax.plot(λ_values, unemployment_rates, lw=2)
434+
ax.set_xlabel(r'$\lambda$')
435+
ax.set_ylabel('steady-state unemployment rate')
436+
plt.show()
437+
```
438+
439+
```{solution-end}
440+
```
441+
412442
(dynamics_workers)=
413443
## Dynamics of an individual worker
414444

@@ -500,7 +530,6 @@ We can investigate this by simulating the Markov chain.
500530
Let's plot the path of the sample averages over 5,000 periods
501531

502532
```{code-cell} ipython3
503-
@jax.jit
504533
def markov_update(state, t, P, keys):
505534
"""
506535
Sample next state from transition probabilities.
@@ -535,14 +564,14 @@ titles = ['percent of time unemployed', 'percent of time employed']
535564
536565
for i, plot in enumerate(to_plot):
537566
axes[i].plot(plot, lw=2, alpha=0.5)
538-
axes[i].hlines(xbar[i], 0, T, 'r', '--')
567+
axes[i].hlines(xbar[i], 0, T, linestyles='--')
539568
axes[i].set_title(titles[i])
540569
541570
plt.tight_layout()
542571
plt.show()
543572
```
544573

545-
The stationary probabilities are given by the dashed red line.
574+
The stationary probabilities are given by the dashed line.
546575

547576
In this case it takes much of the sample for these two objects to converge.
548577

@@ -905,63 +934,6 @@ The level that maximizes steady state welfare is approximately 62.
905934

906935
## Exercises
907936

908-
```{exercise}
909-
:label: model_ex1
910-
911-
In the JAX implementation of the Lake Model, we use a `NamedTuple` for parameters and separate functions for computations.
912-
913-
This approach has several advantages:
914-
1. It's immutable, which aligns with JAX's functional programming paradigm
915-
2. Functions can be JIT-compiled for better performance
916-
917-
In this exercise, your task is to:
918-
1. Update parameters by creating a new instance of the model with the parameters (`α=0.02, λ=0.3`).
919-
2. Use JAX's `vmap` to compute steady states for different parameter values
920-
3. Plot how the steady-state unemployment rate varies with the job finding rate $\lambda$
921-
```
922-
923-
```{solution-start} model_ex1
924-
:class: dropdown
925-
```
926-
927-
Here is one solution
928-
929-
```{code-cell} ipython3
930-
@jax.jit
931-
def compute_unemployment_rate(λ_val):
932-
"""Computes steady-state unemployment for a given λ"""
933-
model = LakeModel(λ=λ_val)
934-
steady_state = rate_steady_state(model)
935-
return steady_state[0]
936-
937-
# Use vmap to compute for multiple λ values
938-
λ_values = jnp.linspace(0.1, 0.5, 50)
939-
unemployment_rates = jax.vmap(compute_unemployment_rate)(λ_values)
940-
941-
# Plot the results
942-
fig, ax = plt.subplots(figsize=(10, 6))
943-
ax.plot(λ_values, unemployment_rates, lw=2)
944-
ax.set_xlabel(r'$\lambda$')
945-
ax.set_ylabel('steady-state unemployment rate')
946-
plt.show()
947-
948-
model_base = LakeModel()
949-
model_ex1 = LakeModel(α=0.02, λ=0.3)
950-
951-
print(f"Base model α: {model_base.α}")
952-
print(f"New model α: {model_ex1.α}, λ: {model_ex1.λ}")
953-
954-
# Compute steady states for both
955-
base_steady_state = rate_steady_state(model_base)
956-
new_steady_state = rate_steady_state(model_ex1)
957-
958-
print(f"Base unemployment rate: {base_steady_state[0]:.4f}")
959-
print(f"New unemployment rate: {new_steady_state[0]:.4f}")
960-
```
961-
962-
```{solution-end}
963-
```
964-
965937
```{exercise-start}
966938
:label: model_ex2
967939
```
@@ -1049,7 +1021,7 @@ titles = ['unemployment rate', 'employment rate']
10491021
10501022
for i, title in enumerate(titles):
10511023
axes[i].plot(x_path[i, :])
1052-
axes[i].hlines(xbar[i], 0, T, 'r', '--')
1024+
axes[i].hlines(xbar[i], 0, T, linestyles='--')
10531025
axes[i].set_title(title)
10541026
10551027
plt.tight_layout()
@@ -1157,7 +1129,7 @@ titles = ['unemployment rate', 'employment rate']
11571129
11581130
for i, title in enumerate(titles):
11591131
axes[i].plot(x_path[i, :])
1160-
axes[i].hlines(x0[i], 0, T, 'r', '--')
1132+
axes[i].hlines(x0[i], 0, T, linestyles='--')
11611133
axes[i].set_title(title)
11621134
11631135
plt.tight_layout()

0 commit comments

Comments
 (0)