Skip to content

Commit 257c464

Browse files
jstacclaude
andcommitted
Improve mccall_model_with_sep_markov.md: Add explanations and simplify simulation loop
Made the following improvements: - Added explanatory sentences above all code blocks that lacked context - Replaced lax.scan with lax.fori_loop in cross-sectional simulation (simpler and more appropriate since we only need final state) - Renamed body_fn to update for clarity 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 45d5a7d commit 257c464

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

lectures/mccall_model_with_sep_markov.md

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,9 @@ Often the second approach is better for our purposes, since it's easier to paral
554554

555555
## Cross-Sectional Analysis
556556

557-
Now let's simulate many agents simultaneously to examine the cross-sectional unemployment rate:
557+
Now let's simulate many agents simultaneously to examine the cross-sectional unemployment rate.
558+
559+
We first create a vectorized version of `update_agent` to efficiently update all agents in parallel:
558560

559561
```{code-cell} ipython3
560562
# Create vectorized version of update_agent
@@ -563,6 +565,8 @@ update_agents_vmap = jax.vmap(
563565
)
564566
```
565567

568+
Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time:
569+
566570
```{code-cell} ipython3
567571
@partial(jit, static_argnums=(3, 4))
568572
def _simulate_cross_section_compiled(
@@ -572,15 +576,15 @@ def _simulate_cross_section_compiled(
572576
n_agents: int,
573577
T: int
574578
):
575-
"""JIT-compiled core simulation loop using lax.scan.
579+
"""JIT-compiled core simulation loop using lax.fori_loop.
576580
Returns only the final employment state to save memory."""
577581
n, w_vals, P, P_cumsum, β, c, α = model
578582
579583
# Initialize arrays
580584
wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
581585
is_employed = jnp.zeros(n_agents, dtype=jnp.int32)
582586
583-
def scan_fn(loop_state, t):
587+
def update(t, loop_state):
584588
key, is_employed, wage_indices = loop_state
585589
586590
# Shift loop state forwards - more efficient key generation
@@ -591,16 +595,11 @@ def _simulate_cross_section_compiled(
591595
agent_keys, is_employed, wage_indices, model, σ
592596
)
593597
594-
# Pack results and return
595-
new_loop_state = key, is_employed, wage_indices
596-
return new_loop_state, None
598+
return key, is_employed, wage_indices
597599
598-
# Run simulation using scan
600+
# Run simulation using fori_loop
599601
initial_loop_state = (key, is_employed, wage_indices)
600-
601-
final_loop_state, _ = lax.scan(
602-
scan_fn, initial_loop_state, jnp.arange(T)
603-
)
602+
final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)
604603
605604
# Return only final employment state
606605
_, final_is_employed, _ = final_loop_state
@@ -641,6 +640,8 @@ def simulate_cross_section(
641640
return unemployment_rate
642641
```
643642

643+
This function generates a histogram showing the distribution of employment status across many agents:
644+
644645
```{code-cell} ipython3
645646
def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200,
646647
n_agents: int = 20_000):
@@ -680,6 +681,8 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200,
680681
plt.show()
681682
```
682683

684+
Now let's compare the time-average unemployment rate (from a single agent's long simulation) with the cross-sectional unemployment rate (from many agents at a single point in time):
685+
683686
```{code-cell} ipython3
684687
model = create_js_with_sep_model()
685688
cross_sectional_unemp = simulate_cross_section(
@@ -725,6 +728,8 @@ changes with unemployment compensation.
725728
:class: dropdown
726729
```
727730

731+
We compute the steady-state unemployment rate for different values of unemployment compensation:
732+
728733
```{code-cell} ipython3
729734
c_values = 1.0, 0.8, 0.6, 0.4, 0.2
730735
rates = []

0 commit comments

Comments
 (0)