Skip to content

Commit a77cbd3

Browse files
committed
misc
1 parent fd20ea1 commit a77cbd3

File tree

1 file changed

+76
-63
lines changed

1 file changed

+76
-63
lines changed

lectures/mccall_model_with_sep_markov.md

Lines changed: 76 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,30 @@ The wage updates are as follows:
8686
* If an unemployed agent rejects offer $w$, then their next offer is drawn from $P(w, \cdot)$
8787
* If an employed agent loses a job in which they were paid wage $w$, then their next offer is drawn from $P(w, \cdot)$
8888

89+
### The Wage Offer Process
90+
91+
To construct the wage offer process we start with an AR1 process.
92+
93+
$$
94+
X_{t+1} = \rho X_t + \nu Z_{t+1}
95+
$$
96+
97+
where $\{Z_t\}$ is IID and standard normal.
98+
99+
Informally, we set $W_t = \exp(Z_t)$.
100+
101+
In practice, we
102+
103+
* discretize the AR1 process using {ref}`Tauchen's method <fm_ex3>` and
104+
* take the exponential of the resulting wage offer values.
105+
106+
Below we will always choose $\rho \in (0, 1)$.
107+
108+
This means that the wage process will be positively correlated: the higher the current
109+
wage offer, the more likely we are to get a high offer tomorrow.
110+
111+
112+
89113
### Value Functions
90114

91115
We let
@@ -100,7 +124,10 @@ The only change is that expectations for next period are computed using the tran
100124
The unemployed worker's value function satisfies the Bellman equation
101125

102126
$$
103-
v_u(w) = \max\{v_e(w), u(c) + \beta \sum_{w'} v_u(w') P(w,w')\}
127+
v_u(w) = \max
128+
\left\{
129+
v_e(w), u(c) + \beta \sum_{w'} v_u(w') P(w,w')
130+
\right\}
104131
$$
105132

106133
The employed worker's value function satisfies the Bellman equation
@@ -119,7 +146,7 @@ $$
119146
(Ph)(w) = \sum_{w'} h(w') P(w,w')
120147
$$
121148

122-
(To understand this expression, think of $P$ as a matrix and $h$ as a column vector.)
149+
(To understand this expression, think of $P$ as a matrix, $h$ as a column vector, and $w$ as a row index.)
123150

124151
With this notation, the Bellman equations become
125152

@@ -139,33 +166,11 @@ $$
139166

140167
+++
141168

142-
### The Wage Process
143-
144-
To construct the wage offer process we start with an AR1 process.
145-
146-
$$
147-
X_{t+1} = \rho X_t + \nu Z_{t+1}
148-
$$
149-
150-
where $\{Z_t\}$ is IID and standard normal.
151-
152-
Informally, we set $W_t = \exp(Z_t)$.
153-
154-
In practice, we
155-
156-
* discretize the AR1 process using {ref}`Tauchen's method <fm_ex3>` and
157-
* take the exponential of the resulting wage offer values.
158-
159-
Below we will always choose $\rho \in (0, 1)$.
160-
161-
This means that the wage process will be positively correlated: the higher the current
162-
wage offer, the more likely we are to get a high offer tomorrow.
163-
164169

165170
## Computational Approach
166171

167172
To solve this problem, we use the employed worker's Bellman equation to express
168-
$v_e$ in terms of $Pv_u$:
173+
$v_e$ in terms of $Pv_u$
169174

170175
$$
171176
v_e(w) =
@@ -354,23 +359,28 @@ ax.set_xlabel(r"$w$")
354359
plt.show()
355360
```
356361

362+
The reservation wage is at the intersection of the stopping value function, which is
363+
equal to $v_e$, and the continuation value function, which is the value of
364+
rejecting
365+
357366
## Sensitivity Analysis
358367

359368
Let's examine how reservation wages change with the separation rate.
360369

361370
```{code-cell} ipython3
362371
α_vals: jnp.ndarray = jnp.linspace(0.0, 1.0, 10)
363372
364-
w_star_vec = jnp.empty_like(α_vals)
365-
for (i_α, α) in enumerate(α_vals):
373+
w_star_vec = []
374+
for α in α_vals:
366375
model = create_js_with_sep_model(α=α)
367376
v_star = vfi(model)
368377
w_star = get_reservation_wage(v_star, model)
369-
w_star_vec = w_star_vec.at[i_α].set(w_star)
378+
w_star_vec.append(w_star)
370379
371380
fig, ax = plt.subplots(figsize=(9, 5.2))
372-
ax.plot(α_vals, w_star_vec, linewidth=2, alpha=0.6,
373-
label="reservation wage")
381+
ax.plot(
382+
α_vals, w_star_vec, linewidth=2, alpha=0.6, label="reservation wage"
383+
)
374384
ax.legend(frameon=False)
375385
ax.set_xlabel(r"$\alpha$")
376386
ax.set_ylabel(r"$w$")
@@ -397,20 +407,22 @@ This is implemented via `jnp.searchsorted` on the precomputed cumulative sum
397407

398408
The function `update_agent` advances the agent's state by one period.
399409

410+
The agent's state is a pair $(s_t, w_t)$, where $s_t$ is employment status (0 if
411+
unemployed, 1 if employed) and $w_t$ is
412+
413+
* their current wage offer, if unemployed, or
414+
* their current wage, if employed.
415+
400416
```{code-cell} ipython3
401417
@jax.jit
402-
def update_agent(key, is_employed, wage_idx, model, w_star):
418+
def update_agent(key, status, wage_idx, model, w_star):
403419
"""
404-
Updates an agent by one period. Updates their employment status and their
405-
current wage (stored by index).
406-
407-
Agents who lose their job that pays wage w receive a new draw in the next
408-
period via the probabilites in P(w, .)
420+
Updates an agent's employment status and current wage.
409421
410422
Parameters:
411423
- key: JAX random key
412-
- is_employed: Current employment status (0 or 1)
413-
- wage_idx: Current wage index
424+
- status: Current employment status (0 or 1)
425+
- wage_idx: Current wage, recorded as an array index
414426
- model: Model instance
415427
- w_star: Reservation wage
416428
@@ -419,6 +431,7 @@ def update_agent(key, is_employed, wage_idx, model, w_star):
419431
420432
key1, key2 = jax.random.split(key)
421433
# Use precomputed cumulative sum for efficient sampling
434+
# via the inverse transform method.
422435
new_wage_idx = jnp.searchsorted(
423436
P_cumsum[wage_idx, :], jax.random.uniform(key1)
424437
)
@@ -428,21 +441,21 @@ def update_agent(key, is_employed, wage_idx, model, w_star):
428441
429442
# If employed: status = 1 if no separation, 0 if separation
430443
# If unemployed: status = 1 if accepts, 0 if rejects
431-
final_employment = jnp.where(
432-
is_employed,
444+
next_status = jnp.where(
445+
status,
433446
1 - separation_occurs.astype(jnp.int32), # employed path
434447
accepts.astype(jnp.int32) # unemployed path
435448
)
436449
437450
# If employed: wage = current if no separation, new if separation
438451
# If unemployed: wage = current if accepts, new if rejects
439-
final_wage = jnp.where(
440-
is_employed,
452+
next_wage = jnp.where(
453+
status,
441454
jnp.where(separation_occurs, new_wage_idx, wage_idx), # employed path
442455
jnp.where(accepts, wage_idx, new_wage_idx) # unemployed path
443456
)
444457
445-
return final_employment, final_wage
458+
return next_status, next_wage
446459
```
447460

448461
Here's a function to simulate the employment path of a single agent.
@@ -463,22 +476,22 @@ def simulate_employment_path(
463476
n, w_vals, P, P_cumsum, β, c, α, γ = model
464477
465478
# Initial conditions
466-
is_employed = 0
479+
status = 0
467480
wage_idx = 0
468481
469-
wage_path_list = []
470-
employment_status_list = []
482+
wage_path = []
483+
status_path = []
471484
472485
for t in range(T):
473-
wage_path_list.append(w_vals[wage_idx])
474-
employment_status_list.append(is_employed)
486+
wage_path.append(w_vals[wage_idx])
487+
status_path.append(status)
475488
476489
key, subkey = jax.random.split(key)
477-
is_employed, wage_idx = update_agent(
478-
subkey, is_employed, wage_idx, model, w_star
490+
status, wage_idx = update_agent(
491+
subkey, status, wage_idx, model, w_star
479492
)
480493
481-
return jnp.array(wage_path_list), jnp.array(employment_status_list)
494+
return jnp.array(wage_path), jnp.array(status_path)
482495
```
483496

484497
Let's create a comprehensive plot of the employment simulation:
@@ -631,23 +644,23 @@ def _simulate_cross_section_compiled(
631644
632645
# Initialize arrays
633646
wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
634-
is_employed = jnp.zeros(n_agents, dtype=jnp.int32)
647+
status = jnp.zeros(n_agents, dtype=jnp.int32)
635648
636649
def update(t, loop_state):
637-
key, is_employed, wage_indices = loop_state
650+
key, status, wage_indices = loop_state
638651
639652
# Shift loop state forwards
640653
key, subkey = jax.random.split(key)
641654
agent_keys = jax.random.split(subkey, n_agents)
642655
643-
is_employed, wage_indices = update_agents_vmap(
644-
agent_keys, is_employed, wage_indices, model, w_star
656+
status, wage_indices = update_agents_vmap(
657+
agent_keys, status, wage_indices, model, w_star
645658
)
646659
647-
return key, is_employed, wage_indices
660+
return key, status, wage_indices
648661
649662
# Run simulation using fori_loop
650-
initial_loop_state = (key, is_employed, wage_indices)
663+
initial_loop_state = (key, status, wage_indices)
651664
final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)
652665
653666
# Return only final employment state
@@ -680,12 +693,12 @@ def simulate_cross_section(
680693
w_star = get_reservation_wage(v_star, model)
681694
682695
# Run JIT-compiled simulation
683-
final_employment = _simulate_cross_section_compiled(
696+
final_status = _simulate_cross_section_compiled(
684697
key, model, w_star, n_agents, T
685698
)
686699
687700
# Calculate unemployment rate at final period
688-
unemployment_rate = 1 - jnp.mean(final_employment)
701+
unemployment_rate = 1 - jnp.mean(final_status)
689702
690703
return unemployment_rate
691704
```
@@ -707,18 +720,18 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200,
707720
key = jax.random.PRNGKey(42)
708721
v_star = vfi(model)
709722
w_star = get_reservation_wage(v_star, model)
710-
final_employment = _simulate_cross_section_compiled(
723+
final_status = _simulate_cross_section_compiled(
711724
key, model, w_star, n_agents, t_snapshot
712725
)
713726
714727
# Calculate unemployment rate
715-
unemployment_rate = 1 - jnp.mean(final_employment)
728+
unemployment_rate = 1 - jnp.mean(final_status)
716729
717730
fig, ax = plt.subplots(figsize=(8, 5))
718731
719732
# Plot histogram as density (bars sum to 1)
720-
weights = jnp.ones_like(final_employment) / len(final_employment)
721-
ax.hist(final_employment, bins=[-0.5, 0.5, 1.5],
733+
weights = jnp.ones_like(final_status) / len(final_status)
734+
ax.hist(final_status, bins=[-0.5, 0.5, 1.5],
722735
alpha=0.7, color='blue', edgecolor='black',
723736
density=True, weights=weights)
724737

0 commit comments

Comments
 (0)