@@ -86,13 +86,13 @@ When unemployed and receiving wage offer $w$, the agent chooses between:
8686
8787## Bellman Equations
8888
89- The unemployed worker's value function satisfies:
89+ The unemployed worker's value function satisfies the Bellman equation
9090
9191$$
9292 v_u(w) = \max\{v_e(w), c + \beta \sum_{w'} v_u(w') P(w,w')\}
9393$$
9494
95- The employed worker's value function satisfies:
95+ The employed worker's value function satisfies the Bellman equation
9696
9797$$
9898 v_e(w) =
107107
108108We use the following approach to solve this problem.
109109
110- 1 . Solve the employed value function analytically:
110+ (As usual, for a function $h$ we set $(Ph)(w) = \sum_ {w'} h(w') P(w,w')$.)
111+
112+ 1 . Use the employed worker's Bellman equation to express $v_e$ in terms of
113+ $Pv_u$:
111114
112115$$
113116 v_e(w) =
@@ -170,19 +173,12 @@ def successive_approx(
170173 return x_final
171174```
172175
173- Let's set up a ` Model ` class to store information needed to solve the model.
174176
175- We include ` P_cumsum ` , the row-wise cumulative sum of the transition matrix, to
176- optimize the simulation.
177+ Next let's set up a ` Model ` class to store information needed to solve the model.
177178
178- When simulating the Markov chain, we need to draw from the distribution in each
179- row of $P$ many times.
180-
181- Rather than computing the cumulative sum repeatedly during simulation, we
182- precompute it once and store it in the model.
179+ We include ` P_cumsum ` , the row-wise cumulative sum of the transition matrix, to
180+ optimize the simulation -- the details are explained below.
183181
184- This converts millions of O(n) cumsum operations into a single precomputation,
185- significantly speeding up large-scale simulations.
186182
187183``` {code-cell} ipython3
188184class Model(NamedTuple):
@@ -348,15 +344,19 @@ Can you provide an intuitive economic story behind the outcome that you see in t
348344
349345Now let's simulate the employment dynamics of a single agent under the optimal policy.
350346
351- The function ` update_agent ` advances the agent's state by one period.
347+ Note that, when simulating the Markov chain for wage offers, we need to draw from the distribution in each
348+ row of $P$ many times.
352349
353- To draw from the Markov chain transition probabilities , we use the inverse
350+ To do this , we use the inverse
354351transform method: draw a uniform random variable and find where it falls in the
355352cumulative distribution.
356353
357354This is implemented via ` jnp.searchsorted ` on the precomputed cumulative sum
358355` P_cumsum ` , which is much faster than recomputing the cumulative sum each time.
359356
357+ The function ` update_agent ` advances the agent's state by one period.
358+
359+
360360``` {code-cell} ipython3
361361@jit
362362def update_agent(key, is_employed, wage_idx, model, σ):
@@ -372,7 +372,9 @@ def update_agent(key, is_employed, wage_idx, model, σ):
372372
373373 key1, key2 = jax.random.split(key)
374374 # Use precomputed cumulative sum for efficient sampling
375- new_wage_idx = jnp.searchsorted(P_cumsum[wage_idx, :], jax.random.uniform(key1))
375+ new_wage_idx = jnp.searchsorted(
376+ P_cumsum[wage_idx, :], jax.random.uniform(key1)
377+ )
376378 separation_occurs = jax.random.uniform(key2) < α
377379 accepts = σ[wage_idx]
378380
@@ -487,27 +489,32 @@ plt.show()
487489
488490The simulation helps to visualize outcomes associated with this model.
489491
490- The agent follows a reservation wage strategy and there are clear cycles between unemployment and employment spells
492+ The agent follows a reservation wage strategy.
493+
494+ Often the agent loses her job and immediately takes another job at a different
495+ wage.
496+
497+ This is because she uses the wage $w$ from her last job to draw a new wage offer
498+ via $P(w, \cdot)$, and positive correlation means that a high current $w$ is
499+ often leads a high new draw.
491500
492- The model captures key features of labor market dynamics
493- with job separation, showing how workers optimally balance the trade-off between
494- accepting current offers versus waiting for better opportunities.
495501
496502
497503## The Ergodic Property
498504
499505Below we examine cross-sectional unemployment.
500506
501507In particular, we will look at the unemployment rate in a cross-sectional
502- simulation and compare it to the time-average unemployment rate, which is the fraction of time an agent spends unemployed over a long time series.
508+ simulation and compare it to the time-average unemployment rate, which is the
509+ fraction of time an agent spends unemployed over a long time series.
503510
504511We will see that these two values are approximately equal -- if fact they are
505512exactly equal in the limit.
506513
507514The reason is that the process $(s_t, w_t)$, where
508515
509- - $s_t \in \{ \text{employed}, \text{unemployed} \} $ is the employment status and
510- - $w_t \in \{ 1, 2, \ldots, n \} $ is the wage
516+ - $s_t$ is the employment status and
517+ - $w_t$ is the wage
511518
512519is Markovian, since the next pair depends only on the current pair and iid
513520randomness, and ergodic.
533540 \lim_{T \to \infty} \frac{1}{T} \sum_{t=1}^{T} \mathbb{1}\{s_t = \text{unemployed}\} = \sum_{w=1}^{n} \pi(\text{unemployed}, w)
534541$$
535542
536- This holds regardless of initial conditions—whether an agent starts employed or unemployed, they converge to the same long-run distribution.
543+ This holds regardless of initial conditions -- provided that we burn in the
544+ cross-sectional distribution (run it forward in time from a given initial cross
545+ section in order to remove the influence of that initial condition).
537546
538547As a result, we can study steady-state unemployment either by:
539548
540549- Following one agent for a long time (time average), or
541550- Observing many agents at a single point in time (cross-sectional average)
542551
543- Both approaches yield the same steady-state unemployment rate.
544-
545- Often the second approach is better for our purposes, since it's far easier to
546- parallelize.
552+ Often the second approach is better for our purposes, since it's easier to parallelize.
547553
548554
549555## Cross-Sectional Analysis
@@ -566,7 +572,8 @@ def _simulate_cross_section_compiled(
566572 n_agents: int,
567573 T: int
568574 ):
569- """JIT-compiled core simulation loop using lax.scan."""
575+ """JIT-compiled core simulation loop using lax.scan.
576+ Returns only the final employment state to save memory."""
570577 n, w_vals, P, P_cumsum, β, c, α = model
571578
572579 # Initialize arrays
@@ -576,42 +583,38 @@ def _simulate_cross_section_compiled(
576583 def scan_fn(loop_state, t):
577584 key, is_employed, wage_indices = loop_state
578585
579- # Record employment status for this time step
580- employment_status = is_employed
581-
582- # Shift loop state forwards
583- key, *agent_keys = jax.random.split(key, n_agents + 1)
584- agent_keys = jnp.array(agent_keys)
586+ # Shift loop state forwards - more efficient key generation
587+ key, subkey = jax.random.split(key)
588+ agent_keys = jax.random.split(subkey, n_agents)
585589
586590 is_employed, wage_indices = update_agents_vmap(
587591 agent_keys, is_employed, wage_indices, model, σ
588592 )
589593
590594 # Pack results and return
591595 new_loop_state = key, is_employed, wage_indices
592- return new_loop_state, employment_status
596+ return new_loop_state, None
593597
594598 # Run simulation using scan
595599 initial_loop_state = (key, is_employed, wage_indices)
596600
597- final_loop_state, employment_matrix = lax.scan(
601+ final_loop_state, _ = lax.scan(
598602 scan_fn, initial_loop_state, jnp.arange(T)
599603 )
600604
601- # Transpose to get (n_agents, T) shape
602- employment_matrix = employment_matrix.T
603-
604- return employment_matrix
605+ # Return only final employment state
606+ _, final_is_employed, _ = final_loop_state
607+ return final_is_employed
605608
606609
607610def simulate_cross_section(
608611 model: Model,
609612 n_agents: int = 100_000,
610613 T: int = 200,
611614 seed: int = 42
612- ) -> tuple[jnp.ndarray, jnp.ndarray] :
615+ ) -> float :
613616 """
614- Simulate employment paths for many agents simultaneously .
617+ Simulate employment paths for many agents and return final unemployment rate .
615618
616619 Parameters:
617620 - model: Model instance with parameters
@@ -620,67 +623,80 @@ def simulate_cross_section(
620623 - seed: Random seed for reproducibility
621624
622625 Returns:
623- - unemployment_rates: Fraction of agents unemployed at each period
624- - employment_matrix: n_agents x T matrix of employment status
626+ - unemployment_rate: Fraction of agents unemployed at time T
625627 """
626628 key = jax.random.PRNGKey(seed)
627629
628630 # Solve for optimal policy
629631 v_star, σ_star = vfi(model)
630-
632+
631633 # Run JIT-compiled simulation
632- employment_matrix = _simulate_cross_section_compiled(
634+ final_employment = _simulate_cross_section_compiled(
633635 key, model, σ_star, n_agents, T
634636 )
635-
636- # Calculate unemployment rate at each period
637- unemployment_rates = 1 - jnp.mean(employment_matrix, axis=0)
638637
639- return unemployment_rates, employment_matrix
638+ # Calculate unemployment rate at final period
639+ unemployment_rate = 1 - jnp.mean(final_employment)
640+
641+ return unemployment_rate
640642```
641643
642644``` {code-cell} ipython3
643- def plot_cross_sectional_unemployment(model: Model):
645+ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200,
646+ n_agents: int = 20_000):
644647 """
645- Generate cross-sectional unemployment rate plot for a given model .
648+ Generate histogram of cross-sectional unemployment at a specific time .
646649
647650 Parameters:
648651 - model: Model instance with parameters
652+ - t_snapshot: Time period at which to take the cross-sectional snapshot
653+ - n_agents: Number of agents to simulate
649654 """
650- unemployment_rates, employment_matrix = simulate_cross_section(model)
651-
652- fig, ax = plt.subplots(figsize=(8, 4))
653-
654- # Plot unemployment rate over time
655- ax.plot(unemployment_rates, 'b-', alpha=0.8, linewidth=1.5,
656- label=f'Cross-sectional unemployment rate (c={model.c})')
657-
658- # Add shaded region for ±1 standard deviation
659- window_size = 50
660- rolling_std = jnp.array([
661- jnp.std(unemployment_rates[max(0, t-window_size):t+1])
662- for t in range(len(unemployment_rates))
663- ])
664-
665- ax.fill_between(range(len(unemployment_rates)),
666- unemployment_rates - rolling_std,
667- unemployment_rates + rolling_std,
668- alpha=0.2, color='blue',
669- label='±1 rolling std')
670-
671- ax.set_xlabel('time')
672- ax.set_ylabel('unemployment rate')
673- ax.set_title(f'Cross-sectional unemployment rate (c={model.c})')
674- ax.grid(alpha=0.4)
675- ax.set_ylim(0, 1)
676- ax.legend()
655+ # Get final employment state directly
656+ key = jax.random.PRNGKey(42)
657+ v_star, σ_star = vfi(model)
658+ final_employment = _simulate_cross_section_compiled(
659+ key, model, σ_star, n_agents, t_snapshot
660+ )
661+
662+ # Calculate unemployment rate
663+ unemployment_rate = 1 - jnp.mean(final_employment)
664+
665+ fig, ax = plt.subplots(figsize=(8, 5))
666+
667+ # Plot histogram as density (bars sum to 1)
668+ weights = jnp.ones_like(final_employment) / len(final_employment)
669+ ax.hist(final_employment, bins=[-0.5, 0.5, 1.5],
670+ alpha=0.7, color='blue', edgecolor='black',
671+ density=True, weights=weights)
672+
673+ ax.set_xlabel('employment status (0=unemployed, 1=employed)')
674+ ax.set_ylabel('density')
675+ ax.set_title(f'Cross-sectional distribution at t={t_snapshot}, ' +
676+ f'unemployment rate = {unemployment_rate:.3f}')
677+ ax.set_xticks([0, 1])
677678
678679 plt.tight_layout()
679680 plt.show()
680681```
681682
682683``` {code-cell} ipython3
683684model = create_js_with_sep_model()
685+ cross_sectional_unemp = simulate_cross_section(
686+ model, n_agents=20_000, T=200
687+ )
688+
689+ time_avg_unemp = jnp.mean(unemployed_indicator)
690+ print(f"Time-average unemployment rate (single agent): "
691+ f"{time_avg_unemp:.4f}")
692+ print(f"Cross-sectional unemployment rate (at t=200): "
693+ f"{cross_sectional_unemp:.4f}")
694+ print(f"Difference: {abs(time_avg_unemp - cross_sectional_unemp):.4f}")
695+ ```
696+
697+ Now let's visualize the cross-sectional distribution:
698+
699+ ``` {code-cell} ipython3
684700plot_cross_sectional_unemployment(model)
685701```
686702
@@ -714,14 +730,16 @@ c_values = 1.0, 0.8, 0.6, 0.4, 0.2
714730rates = []
715731for c in c_values:
716732 model = create_js_with_sep_model(c=c)
717- unemployment_rates, employment_matrix = simulate_cross_section(model)
718- rates.append(unemployment_rates[-1] )
733+ unemployment_rate = simulate_cross_section(model)
734+ rates.append(unemployment_rate )
719735
720736fig, ax = plt.subplots()
721737ax.plot(
722738 c_values, rates, alpha=0.8,
723- linewidth=1.5, label=f'Unemployment rate at c={c} '
739+ linewidth=1.5, label='Steady-state unemployment rate '
724740)
741+ ax.set_xlabel('unemployment compensation (c)')
742+ ax.set_ylabel('unemployment rate')
725743ax.legend(frameon=False)
726744plt.show()
727745```
0 commit comments