Skip to content

Commit 8cce1a7

Browse files
jstacclaude
andcommitted
Refactor lake_model.md: Improve model structure and notation
Major improvements: - Added create_lake_model() function to generate model instances with precomputed matrices A, R, and g - Replaced A_hat notation with R throughout (code and LaTeX) for cleaner notation - Updated LakeModel NamedTuple to store computed matrices A and R - Modified all functions to unpack model using tuple unpacking for efficiency - Added type hints to stock_update(), rate_update(), and create_lake_model() - Simplified generate_path() function by removing unused time parameter - Updated rate_steady_state() to use Perron-Frobenius theorem (argmax instead of searching for eigenvalue near 1) - Converted all LakeModel() instantiations to use create_lake_model() - Updated markov simulation to use dedicated simulate_markov() function Benefits: - Matrices computed once at model creation instead of repeatedly - Cleaner mathematical notation using R instead of \hat{A} - More efficient code with direct tuple unpacking - Better type safety with added annotations - More mathematically rigorous using Perron-Frobenius theorem 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 27bb29b commit 8cce1a7

File tree

1 file changed

+117
-73
lines changed

1 file changed

+117
-73
lines changed

lectures/lake_model.md

Lines changed: 117 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,14 @@ $$
197197
we can also write this as
198198

199199
$$
200-
x_{t+1} = \hat A x_t
200+
x_{t+1} = R x_t
201201
\quad \text{where} \quad
202-
\hat A := \frac{1}{1 + g} A
202+
R := \frac{1}{1 + g} A
203203
$$
204204

205205
You can check that $e_t + u_t = 1$ implies that $e_{t+1}+u_{t+1} = 1$.
206206

207-
This follows from the fact that the columns of $\hat A$ sum to 1.
207+
This follows from the fact that the columns of $R$ sum to 1.
208208

209209
## Implementation
210210

@@ -221,6 +221,50 @@ class LakeModel(NamedTuple):
221221
α: float = 0.013
222222
b: float = 0.0124
223223
d: float = 0.00822
224+
A: jnp.ndarray = None
225+
R: jnp.ndarray = None
226+
g: float = None
227+
228+
229+
def create_lake_model(λ: float = 0.283,
230+
α: float = 0.013,
231+
b: float = 0.0124,
232+
d: float = 0.00822) -> LakeModel:
233+
"""
234+
Create a LakeModel instance with default parameters.
235+
236+
Computes and stores the transition matrices A and R,
237+
and the labor force growth rate g.
238+
239+
Parameters
240+
----------
241+
λ : float, optional
242+
Job finding rate (default: 0.283)
243+
α : float, optional
244+
Job separation rate (default: 0.013)
245+
b : float, optional
246+
Entry rate into labor force (default: 0.0124)
247+
d : float, optional
248+
Exit rate from labor force (default: 0.00822)
249+
250+
Returns
251+
-------
252+
LakeModel
253+
A LakeModel instance with computed matrices A, R, and growth rate g
254+
"""
255+
# Compute growth rate
256+
g = b - d
257+
258+
# Compute transition matrix A
259+
A = jnp.array([
260+
[(1-d) * (1-λ) + b, (1-d) * α + b],
261+
[(1-d) * λ, (1-d) * (1-α)]
262+
])
263+
264+
# Compute normalized transition matrix R
265+
R = A / (1 + g)
266+
267+
return LakeModel(λ=λ, α=α, b=b, d=d, A=A, R=R, g=g)
224268
```
225269

226270
We will also use a specialized function to generate time series in an efficient
@@ -235,13 +279,13 @@ def generate_path(f, initial_state, num_steps, **kwargs):
235279
"""
236280
Generate a time series by repeatedly applying an update rule.
237281
238-
Given a map f, initial state x_0, and a set of model parameter θ, this
282+
Given a map f, initial state x_0, and model parameters, this
239283
function computes and returns the sequence {x_t}_{t=0}^{T-1} when
240284
241-
x_{t+1} = f(x_t, t, θ)
285+
x_{t+1} = f(x_t, **kwargs)
242286
243287
Args:
244-
f: Update function mapping (x_t, t, θ) -> x_{t+1}
288+
f: Update function mapping (x_t, **kwargs) -> x_{t+1}
245289
initial_state: Initial state x_0
246290
num_steps: Number of time steps T to simulate
247291
**kwargs: Optional extra arguments passed to f
@@ -255,59 +299,43 @@ def generate_path(f, initial_state, num_steps, **kwargs):
255299
"""
256300
Wrapper function that adapts f for use with JAX scan.
257301
"""
258-
next_state = f(state, t, **kwargs)
302+
next_state = f(state, **kwargs)
259303
return next_state, state
260304
261305
_, path = jax.lax.scan(update_wrapper,
262306
initial_state, jnp.arange(num_steps))
263307
return path.T
264308
```
265309

266-
Now we can compute the matrices and simulate the dynamics.
310+
Now we can simulate the dynamics.
267311

268312
```{code-cell} ipython3
269-
def compute_matrices(model: LakeModel):
270-
"""Compute the transition matrices A and A_hat for the model."""
271-
λ, α, b, d = model.λ, model.α, model.b, model.d
272-
g = b - d
273-
A = jnp.array([[(1-d) * (1-λ) + b, (1 - d) * α + b],
274-
[ (1-d) * λ, (1 - d) * (1 - α)]])
275-
A_hat = A / (1 + g)
276-
return A, A_hat, g
277-
278-
279-
def stock_update(current_stocks, time_step, model):
280-
"""
281-
Apply transition matrix to get next period's stocks.
282-
"""
283-
A, A_hat, g = compute_matrices(model)
284-
next_stocks = A @ current_stocks
285-
return next_stocks
286-
287-
def rate_update(current_rates, time_step, model):
288-
"""
289-
Apply normalized transition matrix for next period's rates.
290-
"""
291-
A, A_hat, g = compute_matrices(model)
292-
next_rates = A_hat @ current_rates
293-
return next_rates
313+
def stock_update(X: jnp.ndarray, model: LakeModel) -> jnp.ndarray:
314+
"""Apply transition matrix to get next period's stocks."""
315+
λ, α, b, d, A, R, g = model
316+
return A @ X
317+
318+
def rate_update(x: jnp.ndarray, model: LakeModel) -> jnp.ndarray:
319+
"""Apply normalized transition matrix for next period's rates."""
320+
λ, α, b, d, A, R, g = model
321+
return R @ x
294322
```
295323

296324
We create two instances, one with $α=0.013$ and another with $α=0.03$
297325

298326
```{code-cell} ipython3
299-
model = LakeModel()
300-
model_new = LakeModel(α=0.03)
327+
model = create_lake_model()
328+
model_new = create_lake_model(α=0.03)
301329
302330
print(f"Default α: {model.α}")
303-
A, A_hat, g = compute_matrices(model)
304-
print(f"A matrix:\n{A}")
331+
print(f"A matrix:\n{model.A}")
332+
print(f"R matrix:\n{model.R}")
305333
```
306334

307335
```{code-cell} ipython3
308-
A_new, A_hat_new, g_new = compute_matrices(model_new)
309336
print(f"New α: {model_new.α}")
310-
print(f"New A matrix:\n{A_new}")
337+
print(f"New A matrix:\n{model_new.A}")
338+
print(f"New R matrix:\n{model_new.R}")
311339
```
312340

313341
### Aggregate dynamics
@@ -343,44 +371,49 @@ The aggregates $E_t$ and $U_t$ don't converge because their sum $E_t + U_t$ grow
343371
On the other hand, the vector of employment and unemployment rates $x_t$ can be in a steady state $\bar x$ if
344372
there exists an $\bar x$ such that
345373

346-
* $\bar x = \hat A \bar x$
374+
* $\bar x = R \bar x$
347375
* the components satisfy $\bar e + \bar u = 1$
348376

349-
This equation tells us that a steady state level $\bar x$ is an eigenvector of $\hat A$ associated with a unit eigenvalue.
377+
This equation tells us that a steady state level $\bar x$ is an eigenvector of $R$ associated with a unit eigenvalue.
350378

351379
The following function can be used to compute the steady state.
352380

353381
```{code-cell} ipython3
354382
@jax.jit
355-
def rate_steady_state(model: LakeModel):
383+
def rate_steady_state(model: LakeModel) -> jnp.ndarray:
356384
r"""
357-
Finds the steady state of the system :math:`x_{t+1} = \hat A x_{t}`
358-
by computing the eigenvector corresponding to the unit eigenvalue.
385+
Finds the steady state of the system :math:`x_{t+1} = R x_{t}`
386+
by computing the eigenvector corresponding to the largest eigenvalue.
387+
388+
By the Perron-Frobenius theorem, since :math:`R` is a non-negative
389+
matrix with columns summing to 1 (a stochastic matrix), the largest
390+
eigenvalue equals 1 and the corresponding eigenvector gives the steady state.
359391
"""
360-
A, A_hat, g = compute_matrices(model)
361-
eigenvals, eigenvec = jnp.linalg.eig(A_hat)
362-
363-
# Find the eigenvector corresponding to eigenvalue 1
364-
unit_idx = jnp.argmin(jnp.abs(eigenvals - 1.0))
392+
λ, α, b, d, A, R, g = model
393+
eigenvals, eigenvec = jnp.linalg.eig(R)
394+
395+
# Find the eigenvector corresponding to the largest eigenvalue
396+
# (which is 1 for a stochastic matrix by Perron-Frobenius theorem)
397+
max_idx = jnp.argmax(jnp.abs(eigenvals))
365398
366399
# Get the corresponding eigenvector
367-
steady_state = jnp.real(eigenvec[:, unit_idx])
368-
400+
steady_state = jnp.real(eigenvec[:, max_idx])
401+
369402
# Normalize to ensure positive values and sum to 1
370403
steady_state = jnp.abs(steady_state)
371404
steady_state = steady_state / jnp.sum(steady_state)
372-
405+
373406
return steady_state
374407
```
375408

376409
We also have $x_t \to \bar x$ as $t \to \infty$ provided that the remaining
377-
eigenvalue of $\hat A$ has modulus less than 1.
410+
eigenvalue of $R$ has modulus less than 1.
378411

379412
This is the case for our default parameters:
380413

381414
```{code-cell} ipython3
382-
A, A_hat, g = compute_matrices(model)
383-
e, f = jnp.linalg.eigvals(A_hat)
415+
model = create_lake_model()
416+
e, f = jnp.linalg.eigvals(model.R)
384417
print(f"Eigenvalue magnitudes: {abs(e):.2f}, {abs(f):.2f}")
385418
```
386419

@@ -420,7 +453,7 @@ Here is one solution
420453
@jax.jit
421454
def compute_unemployment_rate(λ_val):
422455
"""Computes steady-state unemployment for a given λ"""
423-
model = LakeModel(λ=λ_val)
456+
model = create_lake_model(λ=λ_val)
424457
steady_state = rate_steady_state(model)
425458
return steady_state[0]
426459
@@ -517,7 +550,7 @@ $$
517550

518551
with probability one.
519552

520-
Inspection tells us that $P$ is exactly the transpose of $\hat A$ under the assumption $b=d=0$.
553+
Inspection tells us that $P$ is exactly the transpose of $R$ under the assumption $b=d=0$.
521554

522555
Thus, the percentages of time that an infinitely lived worker spends employed and unemployed equal the fractions of workers employed and unemployed in the steady state distribution.
523556

@@ -530,17 +563,17 @@ We can investigate this by simulating the Markov chain.
530563
Let's plot the path of the sample averages over 5,000 periods
531564

532565
```{code-cell} ipython3
533-
def markov_update(state, t, P, keys):
566+
def markov_update(state, P, key):
534567
"""
535568
Sample next state from transition probabilities.
536569
"""
537570
probs = P[state]
538-
state_new = jax.random.choice(keys[t],
571+
state_new = jax.random.choice(key,
539572
a=jnp.arange(len(probs)),
540573
p=probs)
541574
return state_new
542575
543-
model_markov = LakeModel(d=0, b=0)
576+
model_markov = create_lake_model(d=0, b=0)
544577
T = 5000 # Simulation length
545578
546579
α, λ = model_markov.α, model_markov.λ
@@ -550,10 +583,21 @@ P = jnp.array([[1 - λ, λ],
550583
551584
xbar = rate_steady_state(model_markov)
552585
553-
# Simulate the Markov chain
586+
# Simulate the Markov chain - we need a different approach for random updates
554587
key = jax.random.PRNGKey(0)
555-
keys = jax.random.split(key, T)
556-
s_path = generate_path(markov_update, 1, T, P=P, keys=keys)
588+
589+
def simulate_markov(P, initial_state, T, key):
590+
"""Simulate Markov chain for T periods"""
591+
keys = jax.random.split(key, T)
592+
593+
def scan_fn(state, key):
594+
next_state = markov_update(state, P, key)
595+
return next_state, state
596+
597+
_, path = jax.lax.scan(scan_fn, initial_state, keys)
598+
return path
599+
600+
s_path = simulate_markov(P, 1, T, key)
557601
558602
fig, axes = plt.subplots(2, 1, figsize=(10, 8))
559603
s_bar_e = jnp.cumsum(s_path) / jnp.arange(1, T+1)
@@ -841,25 +885,25 @@ def compute_optimal_quantities(c, τ,
841885
842886
843887
@jax.jit
844-
def compute_steady_state_quantities(c, τ,
888+
def compute_steady_state_quantities(c, τ,
845889
params: EconomyParameters, w_vec, p_vec):
846890
"""
847891
Compute the steady state unemployment rate given c and τ using optimal
848892
quantities from the McCall model and computing corresponding steady
849893
state quantities
850894
"""
851-
w_bar, λ, V, U = compute_optimal_quantities(c, τ,
895+
w_bar, λ, V, U = compute_optimal_quantities(c, τ,
852896
params, w_vec, p_vec)
853-
897+
854898
# Compute steady state employment and unemployment rates
855-
model = LakeModel(α=params.α_q, λ=λ, b=params.b, d=params.d)
899+
model = create_lake_model(λ=λ, α=params.α_q, b=params.b, d=params.d)
856900
u, e = rate_steady_state(model)
857-
901+
858902
# Compute steady state welfare
859903
mask = (w_vec - τ > w_bar)
860904
w = jnp.sum(V * p_vec * mask) / jnp.sum(p_vec * mask)
861905
welfare = e * w + u * U
862-
906+
863907
return e, u, welfare
864908
865909
@@ -970,7 +1014,7 @@ We begin by constructing the model with default parameters and finding the
9701014
initial steady state
9711015

9721016
```{code-cell} ipython3
973-
model_initial = LakeModel()
1017+
model_initial = create_lake_model()
9741018
x0 = rate_steady_state(model_initial)
9751019
print(f"Initial Steady State: {x0}")
9761020
```
@@ -985,7 +1029,7 @@ T = 50
9851029
New legislation changes $\lambda$ to $0.2$
9861030

9871031
```{code-cell} ipython3
988-
model_ex2 = LakeModel(λ=0.2)
1032+
model_ex2 = create_lake_model(λ=0.2)
9891033
xbar = rate_steady_state(model_ex2) # new steady state
9901034
9911035
# Simulate paths
@@ -1063,7 +1107,7 @@ Let's start off at the baseline parameterization and record the steady
10631107
state
10641108

10651109
```{code-cell} ipython3
1066-
model_baseline = LakeModel()
1110+
model_baseline = create_lake_model()
10671111
x0 = rate_steady_state(model_baseline)
10681112
N0 = 100
10691113
T = 50
@@ -1079,7 +1123,7 @@ T_hat = 20
10791123
Let's increase $b$ to the new value and simulate for 20 periods
10801124

10811125
```{code-cell} ipython3
1082-
model_high_b = LakeModel(b=b_hat)
1126+
model_high_b = create_lake_model(b=b_hat)
10831127
10841128
# Simulate stocks and rates for first 20 periods
10851129
X_path1 = generate_path(stock_update, x0 * N0, T_hat, model=model_high_b)

0 commit comments

Comments
 (0)