Skip to content

Commit a6625f0

Browse files
jstacclaude
andcommitted
Merge branch 'main' into mmsm_a
Resolve merge conflicts in mccall_model_with_sep_markov.md by combining variable naming improvements from both branches: - Use 'v_e' and 'continuation_values' for clarity (plural since it's an array) - Use 'accept_indices' for the boolean array - Keep improved comments from both versions 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
2 parents 2fb466c + 83b7d39 commit a6625f0

File tree

2 files changed

+130
-128
lines changed

2 files changed

+130
-128
lines changed

lectures/mccall_model_with_sep_markov.md

Lines changed: 45 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ kernelspec:
3434
This lecture builds on the job search model with separation presented in the
3535
{doc}`previous lecture <mccall_model_with_separation>`.
3636

37-
The key difference is that wage offers now follow a **Markov chain** rather than
38-
being independent and identically distributed (IID).
37+
The key difference is that wage offers now follow a {doc}`Markov chain <finite_markov>` rather than
38+
being IID.
3939

4040
This modification adds persistence to the wage offer process, meaning that
4141
today's wage offer provides information about tomorrow's offer.
@@ -266,8 +266,7 @@ def T(v: jnp.ndarray, model: Model) -> jnp.ndarray:
266266
return jnp.maximum(v_e, h)
267267
```
268268

269-
Here's a routine for value function iteration, as well as a second routine that
270-
computes the reservation wage directly from the value function.
269+
Here's a routine for value function iteration.
271270

272271
```{code-cell} ipython3
273272
@jax.jit
@@ -295,22 +294,20 @@ def vfi(
295294
v_final, error, i = final_loop_state
296295
297296
return v_final
297+
```
298298

299+
Here is a routine that computes the reservation wage from the value function.
299300

301+
```{code-cell} ipython3
300302
@jax.jit
301303
def get_reservation_wage(v: jnp.ndarray, model: Model) -> float:
302304
"""
303-
Calculate the reservation wage directly from the value function.
305+
Calculate the reservation wage from the unemployed agents
306+
value function v := v_u.
304307
305308
The reservation wage is the lowest wage w where accepting (v_e(w))
306-
is at least as good as rejecting (u(c) + β(Pv)(w)).
307-
308-
Parameters:
309-
- v: Value function v_u
310-
- model: Model instance containing parameters
309+
is at least as good as rejecting (u(c) + β(Pv_u)(w)).
311310
312-
Returns:
313-
- Reservation wage (lowest wage for which acceptance is optimal)
314311
"""
315312
n, w_vals, P, P_cumsum, β, c, α, γ = model
316313
@@ -320,42 +317,42 @@ def get_reservation_wage(v: jnp.ndarray, model: Model) -> float:
320317
continuation_values = u(c, γ) + β * P @ v
321318
322319
# Find where acceptance becomes optimal
323-
should_accept = v_e >= continuation_values
324-
first_accept_idx = jnp.argmax(should_accept) # first True in Boolean array
320+
accept_indices = v_e >= continuation_values
321+
first_accept_idx = jnp.argmax(accept_indices) # index of first True
325322
326323
# If no acceptance (all False), return infinity
327324
# Otherwise return the wage at the first acceptance index
328-
return jnp.where(jnp.any(should_accept), w_vals[first_accept_idx], jnp.inf)
325+
return jnp.where(jnp.any(accept_indices), w_vals[first_accept_idx], jnp.inf)
329326
```
330327

328+
331329
## Computing the Solution
332330

333331
Let's solve the model:
334332

335333
```{code-cell} ipython3
336334
model = create_js_with_sep_model()
337335
n, w_vals, P, P_cumsum, β, c, α, γ = model
338-
v_star = vfi(model)
339-
w_star = get_reservation_wage(v_star, model)
336+
v_u = vfi(model)
337+
w_bar = get_reservation_wage(v_u, model)
340338
```
341339

342340
Next we compute some related quantities for plotting.
343341

344342
```{code-cell} ipython3
345343
d = 1 / (1 - β * (1 - α))
346-
accept = d * (u(w_vals, γ) + α * β * P @ v_star)
347-
h_star = u(c, γ) + β * P @ v_star
344+
v_e = d * (u(w_vals, γ) + α * β * P @ v_u)
345+
h = u(c, γ) + β * P @ v_u
348346
```
349347

350348
Let's plot our results.
351349

352350
```{code-cell} ipython3
353351
fig, ax = plt.subplots(figsize=(9, 5.2))
354-
ax.plot(w_vals, h_star, linewidth=4, ls="--", alpha=0.4,
355-
label="continuation value")
356-
ax.plot(w_vals, accept, linewidth=4, ls="--", alpha=0.4,
357-
label="stopping value")
358-
ax.plot(w_vals, v_star, "k-", alpha=0.7, label=r"$v_u^*(w)$")
352+
ax.plot(w_vals, h, 'g-', linewidth=2,
353+
label="continuation value function $h$")
354+
ax.plot(w_vals, v_e, 'b-', linewidth=2,
355+
label="employment value function $v_e$")
359356
ax.legend(frameon=False)
360357
ax.set_xlabel(r"$w$")
361358
plt.show()
@@ -372,16 +369,16 @@ Let's examine how reservation wages change with the separation rate.
372369
```{code-cell} ipython3
373370
α_vals: jnp.ndarray = jnp.linspace(0.0, 1.0, 10)
374371
375-
w_star_vec = []
372+
w_bar_vec = []
376373
for α in α_vals:
377374
model = create_js_with_sep_model(α=α)
378-
v_star = vfi(model)
379-
w_star = get_reservation_wage(v_star, model)
380-
w_star_vec.append(w_star)
375+
v_u = vfi(model)
376+
w_bar = get_reservation_wage(v_u, model)
377+
w_bar_vec.append(w_bar)
381378
382379
fig, ax = plt.subplots(figsize=(9, 5.2))
383380
ax.plot(
384-
α_vals, w_star_vec, linewidth=2, alpha=0.6, label="reservation wage"
381+
α_vals, w_bar_vec, linewidth=2, alpha=0.6, label="reservation wage"
385382
)
386383
ax.legend(frameon=False)
387384
ax.set_xlabel(r"$\alpha$")
@@ -416,7 +413,7 @@ unemployed, 1 if employed) and $w_t$ is
416413
* their current wage, if employed.
417414

418415
```{code-cell} ipython3
419-
def update_agent(key, status, wage_idx, model, w_star):
416+
def update_agent(key, status, wage_idx, model, w_bar):
420417
"""
421418
Updates an agent's employment status and current wage.
422419
@@ -425,7 +422,7 @@ def update_agent(key, status, wage_idx, model, w_star):
425422
- status: Current employment status (0 or 1)
426423
- wage_idx: Current wage, recorded as an array index
427424
- model: Model instance
428-
- w_star: Reservation wage
425+
- w_bar: Reservation wage
429426
430427
"""
431428
n, w_vals, P, P_cumsum, β, c, α, γ = model
@@ -438,7 +435,7 @@ def update_agent(key, status, wage_idx, model, w_star):
438435
)
439436
separation_occurs = jax.random.uniform(key2) < α
440437
# Accept if current wage meets or exceeds reservation wage
441-
accepts = w_vals[wage_idx] >= w_star
438+
accepts = w_vals[wage_idx] >= w_bar
442439
443440
# If employed: status = 1 if no separation, 0 if separation
444441
# If unemployed: status = 1 if accepts, 0 if rejects
@@ -464,7 +461,7 @@ Here's a function to simulate the employment path of a single agent.
464461
```{code-cell} ipython3
465462
def simulate_employment_path(
466463
model: Model, # Model details
467-
w_star: float, # Reservation wage
464+
w_bar: float, # Reservation wage
468465
T: int = 2_000, # Simulation length
469466
seed: int = 42 # Set seed for simulation
470467
):
@@ -489,7 +486,7 @@ def simulate_employment_path(
489486
490487
key, subkey = jax.random.split(key)
491488
status, wage_idx = update_agent(
492-
subkey, status, wage_idx, model, w_star
489+
subkey, status, wage_idx, model, w_bar
493490
)
494491
495492
return jnp.array(wage_path), jnp.array(status_path)
@@ -501,10 +498,10 @@ Let's create a comprehensive plot of the employment simulation:
501498
model = create_js_with_sep_model()
502499
503500
# Calculate reservation wage for plotting
504-
v_star = vfi(model)
505-
w_star = get_reservation_wage(v_star, model)
501+
v_u = vfi(model)
502+
w_bar = get_reservation_wage(v_u, model)
506503
507-
wage_path, employment_status = simulate_employment_path(model, w_star)
504+
wage_path, employment_status = simulate_employment_path(model, w_bar)
508505
509506
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(8, 6))
510507
@@ -520,8 +517,8 @@ ax1.set_ylim(-0.1, 1.1)
520517
521518
# Plot wage path with employment status coloring
522519
ax2.plot(wage_path, 'b-', alpha=0.7, linewidth=1)
523-
ax2.axhline(y=w_star, color='black', linestyle='--', alpha=0.8,
524-
label=f'Reservation wage: {w_star:.2f}')
520+
ax2.axhline(y=w_bar, color='black', linestyle='--', alpha=0.8,
521+
label=f'Reservation wage: {w_bar:.2f}')
525522
ax2.set_xlabel('time')
526523
ax2.set_ylabel('wage')
527524
ax2.set_title('Wage path (actual and offers)')
@@ -622,7 +619,7 @@ We first create a vectorized version of `update_agent` to efficiently update all
622619

623620
```{code-cell} ipython3
624621
# Create vectorized version of update_agent
625-
# The last parameter is now w_star (scalar) instead of σ (array)
622+
# The last parameter is now w_bar (scalar) instead of σ (array)
626623
update_agents_vmap = jax.vmap(
627624
update_agent, in_axes=(0, 0, 0, None, None)
628625
)
@@ -635,7 +632,7 @@ Next we define the core simulation function, which uses `lax.fori_loop` to effic
635632
def _simulate_cross_section_compiled(
636633
key: jnp.ndarray,
637634
model: Model,
638-
w_star: float,
635+
w_bar: float,
639636
n_agents: int,
640637
T: int
641638
):
@@ -655,7 +652,7 @@ def _simulate_cross_section_compiled(
655652
agent_keys = jax.random.split(subkey, n_agents)
656653
657654
status, wage_indices = update_agents_vmap(
658-
agent_keys, status, wage_indices, model, w_star
655+
agent_keys, status, wage_indices, model, w_bar
659656
)
660657
661658
return key, status, wage_indices
@@ -690,12 +687,12 @@ def simulate_cross_section(
690687
key = jax.random.PRNGKey(seed)
691688
692689
# Solve for optimal reservation wage
693-
v_star = vfi(model)
694-
w_star = get_reservation_wage(v_star, model)
690+
v_u = vfi(model)
691+
w_bar = get_reservation_wage(v_u, model)
695692
696693
# Run JIT-compiled simulation
697694
final_status = _simulate_cross_section_compiled(
698-
key, model, w_star, n_agents, T
695+
key, model, w_bar, n_agents, T
699696
)
700697
701698
# Calculate unemployment rate at final period
@@ -719,10 +716,10 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200,
719716
"""
720717
# Get final employment state directly
721718
key = jax.random.PRNGKey(42)
722-
v_star = vfi(model)
723-
w_star = get_reservation_wage(v_star, model)
719+
v_u = vfi(model)
720+
w_bar = get_reservation_wage(v_u, model)
724721
final_status = _simulate_cross_section_compiled(
725-
key, model, w_star, n_agents, t_snapshot
722+
key, model, w_bar, n_agents, t_snapshot
726723
)
727724
728725
# Calculate unemployment rate

0 commit comments

Comments
 (0)