Skip to content

Commit 90f6a6f

Browse files
authored
Merge branch 'main' into remove-pytorch-pyro
2 parents 8c0d609 + 0d564d4 commit 90f6a6f

File tree

5 files changed

+211
-181
lines changed

5 files changed

+211
-181
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.

lectures/mccall_model_with_sep_markov.md

Lines changed: 61 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ jupytext:
44
extension: .md
55
format_name: myst
66
format_version: 0.13
7-
jupytext_version: 1.17.1
7+
jupytext_version: 1.17.2
88
kernelspec:
9-
name: python3
109
display_name: Python 3 (ipykernel)
1110
language: python
11+
name: python3
1212
---
1313

1414
(mccall_with_sep_markov)=
@@ -49,7 +49,7 @@ libraries
4949
```{code-cell} ipython3
5050
:tags: [hide-output]
5151
52-
!pip install quantecon jax
52+
!pip install quantecon
5353
```
5454

5555
We use the following imports:
@@ -58,7 +58,7 @@ We use the following imports:
5858
from quantecon.markov import tauchen
5959
import jax.numpy as jnp
6060
import jax
61-
from jax import jit, lax
61+
from jax import lax
6262
from typing import NamedTuple
6363
import matplotlib.pyplot as plt
6464
from functools import partial
@@ -138,48 +138,11 @@ The optimal policy turns out to be a reservation wage strategy: accept all wages
138138

139139
## Code
140140

141-
142-
First, we implement the successive approximation algorithm.
143-
144-
This algorithm takes an operator $T$ and an initial condition and iterates until
145-
convergence.
146-
147-
We will use it for value function iteration.
148-
149-
```{code-cell} ipython3
150-
@partial(jit, static_argnums=(0,))
151-
def successive_approx(
152-
T, # Operator (callable) - marked as static
153-
x_0, # Initial condition
154-
tolerance: float = 1e-6, # Error tolerance
155-
max_iter: int = 100_000, # Max iteration bound
156-
):
157-
"""Computes the approximate fixed point of T via successive
158-
approximation using lax.while_loop."""
159-
160-
def cond_fn(carry):
161-
x, error, k = carry
162-
return (error > tolerance) & (k <= max_iter)
163-
164-
def body_fn(carry):
165-
x, error, k = carry
166-
x_new = T(x)
167-
error = jnp.max(jnp.abs(x_new - x))
168-
return (x_new, error, k + 1)
169-
170-
initial_carry = (x_0, tolerance + 1, 1)
171-
x_final, _, _ = lax.while_loop(cond_fn, body_fn, initial_carry)
172-
173-
return x_final
174-
```
175-
176-
177-
Next let's set up a `Model` class to store information needed to solve the model.
141+
Let's set up a `Model` class to store information needed to solve the model.
178142

179143
We include `P_cumsum`, the row-wise cumulative sum of the transition matrix, to
180144
optimize the simulation -- the details are explained below.
181145

182-
183146
```{code-cell} ipython3
184147
class Model(NamedTuple):
185148
n: int
@@ -215,7 +178,6 @@ def create_js_with_sep_model(
215178
Here's the Bellman operator for the unemployed worker's value function:
216179

217180
```{code-cell} ipython3
218-
@jit
219181
def T(v: jnp.ndarray, model: Model) -> jnp.ndarray:
220182
"""The Bellman operator for the value of being unemployed."""
221183
n, w_vals, P, P_cumsum, β, c, α = model
@@ -229,7 +191,6 @@ The next function computes the optimal policy under the assumption that $v$ is
229191
the value function:
230192

231193
```{code-cell} ipython3
232-
@jit
233194
def get_greedy(v: jnp.ndarray, model: Model) -> jnp.ndarray:
234195
"""Get a v-greedy policy."""
235196
n, w_vals, P, P_cumsum, β, c, α = model
@@ -247,14 +208,34 @@ The second routine requires a policy function, which we will typically obtain by
247208
applying the `vfi` function.
248209

249210
```{code-cell} ipython3
250-
def vfi(model: Model):
251-
"""Solve by VFI."""
211+
@jax.jit
212+
def vfi(
213+
model: Model,
214+
tolerance: float = 1e-6, # Error tolerance
215+
max_iter: int = 100_000, # Max iteration bound
216+
):
217+
252218
v_init = jnp.zeros(model.w_vals.shape)
253-
v_star = successive_approx(lambda v: T(v, model), v_init)
254-
σ_star = get_greedy(v_star, model)
255-
return v_star, σ_star
219+
220+
def cond(loop_state):
221+
v, error, i = loop_state
222+
return (error > tolerance) & (i <= max_iter)
223+
224+
def update(loop_state):
225+
v, error, i = loop_state
226+
v_new = T(v, model)
227+
error = jnp.max(jnp.abs(v_new - v))
228+
new_loop_state = v_new, error, i + 1
229+
return new_loop_state
230+
231+
initial_state = (v_init, tolerance + 1, 1)
232+
final_loop_state = lax.while_loop(cond, update, initial_state)
233+
v_final, error, i = final_loop_state
256234
235+
return v_final
257236
237+
238+
@jax.jit
258239
def get_reservation_wage(σ: jnp.ndarray, model: Model) -> float:
259240
"""
260241
Calculate the reservation wage from a given policy.
@@ -268,25 +249,24 @@ def get_reservation_wage(σ: jnp.ndarray, model: Model) -> float:
268249
"""
269250
n, w_vals, P, P_cumsum, β, c, α = model
270251
271-
# Find all wage indices where policy indicates acceptance
272-
accept_indices = jnp.where(σ == 1)[0]
273-
274-
if len(accept_indices) == 0:
275-
return jnp.inf # Agent never accepts any wage
252+
# Find the first index where policy indicates acceptance
253+
# σ is a boolean array, argmax returns the first True value
254+
first_accept_idx = jnp.argmax(σ)
276255
277-
# Return the lowest wage that is accepted
278-
return w_vals[accept_indices[0]]
256+
# If no acceptance (all False), return infinity
257+
# Otherwise return the wage at the first acceptance index
258+
return jnp.where(jnp.any(σ), w_vals[first_accept_idx], jnp.inf)
279259
```
280260

281-
282261
## Computing the Solution
283262

284263
Let's solve the model:
285264

286265
```{code-cell} ipython3
287266
model = create_js_with_sep_model()
288267
n, w_vals, P, P_cumsum, β, c, α = model
289-
v_star, σ_star = vfi(model)
268+
v_star = vfi(model)
269+
σ_star = get_greedy(v_star, model)
290270
```
291271

292272
Next we compute some related quantities, including the reservation wage.
@@ -312,19 +292,18 @@ ax.set_xlabel(r"$w$")
312292
plt.show()
313293
```
314294

315-
316295
## Sensitivity Analysis
317296

318297
Let's examine how reservation wages change with the separation rate.
319298

320-
321299
```{code-cell} ipython3
322300
α_vals: jnp.ndarray = jnp.linspace(0.0, 1.0, 10)
323301
324302
w_star_vec = jnp.empty_like(α_vals)
325303
for (i_α, α) in enumerate(α_vals):
326304
model = create_js_with_sep_model(α=α)
327-
v_star, σ_star = vfi(model)
305+
v_star = vfi(model)
306+
σ_star = get_greedy(v_star, model)
328307
w_star = get_reservation_wage(σ_star, model)
329308
w_star_vec = w_star_vec.at[i_α].set(w_star)
330309
@@ -356,9 +335,8 @@ This is implemented via `jnp.searchsorted` on the precomputed cumulative sum
356335

357336
The function `update_agent` advances the agent's state by one period.
358337

359-
360338
```{code-cell} ipython3
361-
@jit
339+
@jax.jit
362340
def update_agent(key, is_employed, wage_idx, model, σ):
363341
"""
364342
Updates an agent by one period. Updates their employment status and their
@@ -439,7 +417,8 @@ Let's create a comprehensive plot of the employment simulation:
439417
model = create_js_with_sep_model()
440418
441419
# Calculate reservation wage for plotting
442-
v_star, σ_star = vfi(model)
420+
v_star = vfi(model)
421+
σ_star = get_greedy(v_star, model)
443422
w_star = get_reservation_wage(σ_star, model)
444423
445424
wage_path, employment_status = simulate_employment_path(model, σ_star)
@@ -486,7 +465,6 @@ plt.tight_layout()
486465
plt.show()
487466
```
488467

489-
490468
The simulation helps to visualize outcomes associated with this model.
491469

492470
The agent follows a reservation wage strategy.
@@ -531,7 +509,7 @@ This holds because:
531509

532510
These properties ensure the chain is ergodic with a unique stationary distribution $\pi$ over states $(s, w)$.
533511

534-
For an ergodic Markov chain, the ergodic theorem guarantees that time averages = ensemble averages.
512+
For an ergodic Markov chain, the ergodic theorem guarantees that time averages = cross-sectional averages.
535513

536514
In particular, the fraction of time a single agent spends unemployed (across all
537515
wage states) converges to the cross-sectional unemployment rate:
@@ -568,7 +546,7 @@ update_agents_vmap = jax.vmap(
568546
Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time:
569547

570548
```{code-cell} ipython3
571-
@partial(jit, static_argnums=(3, 4))
549+
@partial(jax.jit, static_argnums=(3, 4))
572550
def _simulate_cross_section_compiled(
573551
key: jnp.ndarray,
574552
model: Model,
@@ -627,7 +605,8 @@ def simulate_cross_section(
627605
key = jax.random.PRNGKey(seed)
628606
629607
# Solve for optimal policy
630-
v_star, σ_star = vfi(model)
608+
v_star = vfi(model)
609+
σ_star = get_greedy(v_star, model)
631610
632611
# Run JIT-compiled simulation
633612
final_employment = _simulate_cross_section_compiled(
@@ -655,7 +634,8 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200,
655634
"""
656635
# Get final employment state directly
657636
key = jax.random.PRNGKey(42)
658-
v_star, σ_star = vfi(model)
637+
v_star = vfi(model)
638+
σ_star = get_greedy(v_star, model)
659639
final_employment = _simulate_cross_section_compiled(
660640
key, model, σ_star, n_agents, t_snapshot
661641
)
@@ -681,7 +661,12 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200,
681661
plt.show()
682662
```
683663

684-
Now let's compare the time-average unemployment rate (from a single agent's long simulation) with the cross-sectional unemployment rate (from many agents at a single point in time):
664+
Now let's compare the time-average unemployment rate (from a single agent's long simulation) with the cross-sectional unemployment rate (from many agents at a single point in time).
665+
666+
We claimed above that these numbers will be approximately equal in large
667+
samples, due to ergodicity.
668+
669+
Let's see if that's true.
685670

686671
```{code-cell} ipython3
687672
model = create_js_with_sep_model()
@@ -697,28 +682,31 @@ print(f"Cross-sectional unemployment rate (at t=200): "
697682
print(f"Difference: {abs(time_avg_unemp - cross_sectional_unemp):.4f}")
698683
```
699684

685+
Indeed, they are very close.
686+
700687
Now let's visualize the cross-sectional distribution:
701688

702689
```{code-cell} ipython3
703690
plot_cross_sectional_unemployment(model)
704691
```
705692

706-
## Cross-Sectional Analysis with Lower Unemployment Compensation (c=0.5)
693+
## Lower Unemployment Compensation (c=0.5)
707694

708-
Let's examine how the cross-sectional unemployment rate changes with lower unemployment compensation:
695+
What happens to the cross-sectional unemployment rate with lower unemployment compensation?
709696

710697
```{code-cell} ipython3
711698
model_low_c = create_js_with_sep_model(c=0.5)
712699
plot_cross_sectional_unemployment(model_low_c)
713700
```
714701

702+
715703
## Exercises
716704

717705
```{exercise-start}
718706
:label: mmwsm_ex1
719707
```
720708

721-
Create a plot that shows how the steady state cross-sectional unemployment rate
709+
Create a plot that investigates more carefully how the steady state cross-sectional unemployment rate
722710
changes with unemployment compensation.
723711

724712
```{exercise-end}
@@ -751,4 +739,3 @@ plt.show()
751739

752740
```{solution-end}
753741
```
754-

0 commit comments

Comments
 (0)