@@ -86,6 +86,30 @@ The wage updates are as follows:
8686* If an unemployed agent rejects offer $w$, then their next offer is drawn from $P(w, \cdot)$
8787* If an employed agent loses a job in which they were paid wage $w$, then their next offer is drawn from $P(w, \cdot)$
8888
89+ ### The Wage Offer Process
90+
91+ To construct the wage offer process we start with an AR1 process.
92+
93+ $$
94+ X_{t+1} = \rho X_t + \nu Z_{t+1}
95+ $$
96+
97+ where $\{ Z_t\} $ is IID and standard normal.
98+
99+ Informally, we set $W_t = \exp(Z_t)$.
100+
101+ In practice, we
102+
103+ * discretize the AR1 process using {ref}` Tauchen's method <fm_ex3> ` and
104+ * take the exponential of the resulting wage offer values.
105+
106+ Below we will always choose $\rho \in (0, 1)$.
107+
108+ This means that the wage process will be positively correlated: the higher the current
109+ wage offer, the more likely we are to get a high offer tomorrow.
110+
111+
112+
89113### Value Functions
90114
91115We let
@@ -100,7 +124,10 @@ The only change is that expectations for next period are computed using the tran
100124The unemployed worker's value function satisfies the Bellman equation
101125
102126$$
103- v_u(w) = \max\{v_e(w), u(c) + \beta \sum_{w'} v_u(w') P(w,w')\}
127+ v_u(w) = \max
128+ \left\{
129+ v_e(w), u(c) + \beta \sum_{w'} v_u(w') P(w,w')
130+ \right\}
104131$$
105132
106133The employed worker's value function satisfies the Bellman equation
119146 (Ph)(w) = \sum_{w'} h(w') P(w,w')
120147$$
121148
122- (To understand this expression, think of $P$ as a matrix and $h$ as a column vector.)
149+ (To understand this expression, think of $P$ as a matrix, $h$ as a column vector, and $w$ as a row index .)
123150
124151With this notation, the Bellman equations become
125152
139166
140167+++
141168
142- ### The Wage Process
143-
144- To construct the wage offer process we start with an AR1 process.
145-
146- $$
147- X_{t+1} = \rho X_t + \nu Z_{t+1}
148- $$
149-
150- where $\{ Z_t\} $ is IID and standard normal.
151-
152- Informally, we set $W_t = \exp(Z_t)$.
153-
154- In practice, we
155-
156- * discretize the AR1 process using {ref}` Tauchen's method <fm_ex3> ` and
157- * take the exponential of the resulting wage offer values.
158-
159- Below we will always choose $\rho \in (0, 1)$.
160-
161- This means that the wage process will be positively correlated: the higher the current
162- wage offer, the more likely we are to get a high offer tomorrow.
163-
164169
165170## Computational Approach
166171
167172To solve this problem, we use the employed worker's Bellman equation to express
168- $v_e$ in terms of $Pv_u$:
173+ $v_e$ in terms of $Pv_u$
169174
170175$$
171176 v_e(w) =
@@ -354,23 +359,28 @@ ax.set_xlabel(r"$w$")
354359plt.show()
355360```
356361
362+ The reservation wage is at the intersection of the stopping value function, which is
363+ equal to $v_e$, and the continuation value function, which is the value of
364+ rejecting
365+
357366## Sensitivity Analysis
358367
359368Let's examine how reservation wages change with the separation rate.
360369
361370``` {code-cell} ipython3
362371α_vals: jnp.ndarray = jnp.linspace(0.0, 1.0, 10)
363372
364- w_star_vec = jnp.empty_like(α_vals)
365- for (i_α, α) in enumerate( α_vals) :
373+ w_star_vec = []
374+ for α in α_vals:
366375 model = create_js_with_sep_model(α=α)
367376 v_star = vfi(model)
368377 w_star = get_reservation_wage(v_star, model)
369- w_star_vec = w_star_vec.at[i_α].set (w_star)
378+ w_star_vec.append (w_star)
370379
371380fig, ax = plt.subplots(figsize=(9, 5.2))
372- ax.plot(α_vals, w_star_vec, linewidth=2, alpha=0.6,
373- label="reservation wage")
381+ ax.plot(
382+ α_vals, w_star_vec, linewidth=2, alpha=0.6, label="reservation wage"
383+ )
374384ax.legend(frameon=False)
375385ax.set_xlabel(r"$\alpha$")
376386ax.set_ylabel(r"$w$")
@@ -397,20 +407,22 @@ This is implemented via `jnp.searchsorted` on the precomputed cumulative sum
397407
398408The function ` update_agent ` advances the agent's state by one period.
399409
410+ The agent's state is a pair $(s_t, w_t)$, where $s_t$ is employment status (0 if
411+ unemployed, 1 if employed) and $w_t$ is
412+
413+ * their current wage offer, if unemployed, or
414+ * their current wage, if employed.
415+
400416``` {code-cell} ipython3
401417@jax.jit
402- def update_agent(key, is_employed , wage_idx, model, w_star):
418+ def update_agent(key, status , wage_idx, model, w_star):
403419 """
404- Updates an agent by one period. Updates their employment status and their
405- current wage (stored by index).
406-
407- Agents who lose their job that pays wage w receive a new draw in the next
408- period via the probabilites in P(w, .)
420+ Updates an agent's employment status and current wage.
409421
410422 Parameters:
411423 - key: JAX random key
412- - is_employed : Current employment status (0 or 1)
413- - wage_idx: Current wage index
424+ - status : Current employment status (0 or 1)
425+ - wage_idx: Current wage, recorded as an array index
414426 - model: Model instance
415427 - w_star: Reservation wage
416428
@@ -419,6 +431,7 @@ def update_agent(key, is_employed, wage_idx, model, w_star):
419431
420432 key1, key2 = jax.random.split(key)
421433 # Use precomputed cumulative sum for efficient sampling
434+ # via the inverse transform method.
422435 new_wage_idx = jnp.searchsorted(
423436 P_cumsum[wage_idx, :], jax.random.uniform(key1)
424437 )
@@ -428,21 +441,21 @@ def update_agent(key, is_employed, wage_idx, model, w_star):
428441
429442 # If employed: status = 1 if no separation, 0 if separation
430443 # If unemployed: status = 1 if accepts, 0 if rejects
431- final_employment = jnp.where(
432- is_employed ,
444+ next_status = jnp.where(
445+ status ,
433446 1 - separation_occurs.astype(jnp.int32), # employed path
434447 accepts.astype(jnp.int32) # unemployed path
435448 )
436449
437450 # If employed: wage = current if no separation, new if separation
438451 # If unemployed: wage = current if accepts, new if rejects
439- final_wage = jnp.where(
440- is_employed ,
452+ next_wage = jnp.where(
453+ status ,
441454 jnp.where(separation_occurs, new_wage_idx, wage_idx), # employed path
442455 jnp.where(accepts, wage_idx, new_wage_idx) # unemployed path
443456 )
444457
445- return final_employment, final_wage
458+ return next_status, next_wage
446459```
447460
448461Here's a function to simulate the employment path of a single agent.
@@ -463,22 +476,22 @@ def simulate_employment_path(
463476 n, w_vals, P, P_cumsum, β, c, α, γ = model
464477
465478 # Initial conditions
466- is_employed = 0
479+ status = 0
467480 wage_idx = 0
468481
469- wage_path_list = []
470- employment_status_list = []
482+ wage_path = []
483+ status_path = []
471484
472485 for t in range(T):
473- wage_path_list .append(w_vals[wage_idx])
474- employment_status_list .append(is_employed )
486+ wage_path .append(w_vals[wage_idx])
487+ status_path .append(status )
475488
476489 key, subkey = jax.random.split(key)
477- is_employed , wage_idx = update_agent(
478- subkey, is_employed , wage_idx, model, w_star
490+ status , wage_idx = update_agent(
491+ subkey, status , wage_idx, model, w_star
479492 )
480493
481- return jnp.array(wage_path_list ), jnp.array(employment_status_list )
494+ return jnp.array(wage_path ), jnp.array(status_path )
482495```
483496
484497Let's create a comprehensive plot of the employment simulation:
@@ -631,23 +644,23 @@ def _simulate_cross_section_compiled(
631644
632645 # Initialize arrays
633646 wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
634- is_employed = jnp.zeros(n_agents, dtype=jnp.int32)
647+ status = jnp.zeros(n_agents, dtype=jnp.int32)
635648
636649 def update(t, loop_state):
637- key, is_employed , wage_indices = loop_state
650+ key, status , wage_indices = loop_state
638651
639652 # Shift loop state forwards
640653 key, subkey = jax.random.split(key)
641654 agent_keys = jax.random.split(subkey, n_agents)
642655
643- is_employed , wage_indices = update_agents_vmap(
644- agent_keys, is_employed , wage_indices, model, w_star
656+ status , wage_indices = update_agents_vmap(
657+ agent_keys, status , wage_indices, model, w_star
645658 )
646659
647- return key, is_employed , wage_indices
660+ return key, status , wage_indices
648661
649662 # Run simulation using fori_loop
650- initial_loop_state = (key, is_employed , wage_indices)
663+ initial_loop_state = (key, status , wage_indices)
651664 final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)
652665
653666 # Return only final employment state
@@ -680,12 +693,12 @@ def simulate_cross_section(
680693 w_star = get_reservation_wage(v_star, model)
681694
682695 # Run JIT-compiled simulation
683- final_employment = _simulate_cross_section_compiled(
696+ final_status = _simulate_cross_section_compiled(
684697 key, model, w_star, n_agents, T
685698 )
686699
687700 # Calculate unemployment rate at final period
688- unemployment_rate = 1 - jnp.mean(final_employment )
701+ unemployment_rate = 1 - jnp.mean(final_status )
689702
690703 return unemployment_rate
691704```
@@ -707,18 +720,18 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200,
707720 key = jax.random.PRNGKey(42)
708721 v_star = vfi(model)
709722 w_star = get_reservation_wage(v_star, model)
710- final_employment = _simulate_cross_section_compiled(
723+ final_status = _simulate_cross_section_compiled(
711724 key, model, w_star, n_agents, t_snapshot
712725 )
713726
714727 # Calculate unemployment rate
715- unemployment_rate = 1 - jnp.mean(final_employment )
728+ unemployment_rate = 1 - jnp.mean(final_status )
716729
717730 fig, ax = plt.subplots(figsize=(8, 5))
718731
719732 # Plot histogram as density (bars sum to 1)
720- weights = jnp.ones_like(final_employment ) / len(final_employment )
721- ax.hist(final_employment , bins=[-0.5, 0.5, 1.5],
733+ weights = jnp.ones_like(final_status ) / len(final_status )
734+ ax.hist(final_status , bins=[-0.5, 0.5, 1.5],
722735 alpha=0.7, color='blue', edgecolor='black',
723736 density=True, weights=weights)
724737
0 commit comments