Skip to content

Commit a48f019

Browse files
jstacHumphreyYang
andauthored
[mccall_model_with_separation] Conversion to JAX (#501)
* misc * misc * minor updates to static figure and variable / searchsort logic. --------- Co-authored-by: Humphrey Yang <[email protected]>
1 parent 5edb261 commit a48f019

File tree

4 files changed

+106
-128
lines changed

4 files changed

+106
-128
lines changed
9.3 KB
Loading
14.7 KB
Loading
11 KB
Loading

lectures/mccall_model_with_separation.md

Lines changed: 106 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ jupytext:
33
text_representation:
44
extension: .md
55
format_name: myst
6+
format_version: 0.13
7+
jupytext_version: 1.17.1
68
kernelspec:
7-
display_name: Python 3
8-
language: python
99
name: python3
10+
display_name: Python 3 (ipykernel)
11+
language: python
1012
---
1113

1214
(mccall_with_sep)=
@@ -29,10 +31,9 @@ kernelspec:
2931

3032
In addition to what's in Anaconda, this lecture will need the following libraries:
3133

32-
```{code-cell} ipython
33-
---
34-
tags: [hide-output]
35-
---
34+
```{code-cell} ipython3
35+
:tags: [hide-output]
36+
3637
!pip install quantecon
3738
```
3839

@@ -54,12 +55,12 @@ worker preferences slightly more sophisticated.
5455

5556
We'll need the following imports
5657

57-
```{code-cell} ipython
58+
```{code-cell} ipython3
5859
import matplotlib.pyplot as plt
59-
plt.rcParams["figure.figsize"] = (11, 5) #set default figure size
6060
import numpy as np
61-
from numba import jit, float64
62-
from numba.experimental import jitclass
61+
import jax
62+
import jax.numpy as jnp
63+
from typing import NamedTuple
6364
from quantecon.distributions import BetaBinomial
6465
```
6566

@@ -306,87 +307,74 @@ This helps to tidy up the code and provides an object that's easy to pass to fun
306307

307308
The default utility function is a CRRA utility function
308309

309-
```{code-cell} python3
310-
@jit
310+
```{code-cell} ipython3
311+
@jax.jit
311312
def u(c, σ=2.0):
312313
return (c**(1 - σ) - 1) / (1 - σ)
313314
```
314315

315316
Also, here's a default wage distribution, based around the BetaBinomial
316317
distribution:
317318

318-
```{code-cell} python3
319+
```{code-cell} ipython3
319320
n = 60 # n possible outcomes for w
320-
w_default = np.linspace(10, 20, n) # wages between 10 and 20
321+
w_default = jnp.linspace(10, 20, n) # wages between 10 and 20
321322
a, b = 600, 400 # shape parameters
322-
dist = BetaBinomial(n-1, a, b)
323-
q_default = dist.pdf()
323+
dist = BetaBinomial(n-1, a, b) # distribution
324+
q_default = jnp.array(dist.pdf()) # probabilities as a JAX array
324325
```
325326

326327
Here's our jitted class for the McCall model with separation.
327328

328-
```{code-cell} python3
329-
mccall_data = [
330-
('α', float64), # job separation rate
331-
('β', float64), # discount factor
332-
('c', float64), # unemployment compensation
333-
('w', float64[:]), # list of wage values
334-
('q', float64[:]) # pmf of random variable w
335-
]
336-
337-
@jitclass(mccall_data)
338-
class McCallModel:
339-
"""
340-
Stores the parameters and functions associated with a given model.
341-
"""
342-
343-
def __init__(self, α=0.2, β=0.98, c=6.0, w=w_default, q=q_default):
344-
345-
self.α, self.β, self.c, self.w, self.q = α, β, c, w, q
346-
347-
348-
def update(self, v, d):
349-
350-
α, β, c, w, q = self.α, self.β, self.c, self.w, self.q
351-
352-
v_new = np.empty_like(v)
353-
354-
for i in range(len(w)):
355-
v_new[i] = u(w[i]) + β * ((1 - α) * v[i] + α * d)
356-
357-
d_new = np.sum(np.maximum(v, u(c) + β * d) * q)
358-
359-
return v_new, d_new
329+
```{code-cell} ipython3
330+
class Model(NamedTuple):
331+
α: float = 0.2 # job separation rate
332+
β: float = 0.98 # discount factor
333+
c: float = 6.0 # unemployment compensation
334+
w: jnp.ndarray = w_default # wage outcome space
335+
q: jnp.ndarray = q_default # probabilities over wage offers
360336
```
361337

362338
Now we iterate until successive realizations are closer together than some small tolerance level.
363339

364340
We then return the current iterate as an approximate solution.
365341

366-
```{code-cell} python3
367-
@jit
368-
def solve_model(mcm, tol=1e-5, max_iter=2000):
369-
"""
370-
Iterates to convergence on the Bellman equations
371-
372-
* mcm is an instance of McCallModel
373-
"""
374-
375-
v = np.ones_like(mcm.w) # Initial guess of v
376-
d = 1 # Initial guess of d
377-
i = 0
378-
error = tol + 1
379-
380-
while error > tol and i < max_iter:
381-
v_new, d_new = mcm.update(v, d)
382-
error_1 = np.max(np.abs(v_new - v))
383-
error_2 = np.abs(d_new - d)
384-
error = max(error_1, error_2)
385-
v = v_new
386-
d = d_new
387-
i += 1
388-
389-
return v, d
342+
```{code-cell} ipython3
343+
@jax.jit
344+
def update(model, v, d):
345+
" One update on the Bellman equations. "
346+
α, β, c, w, q = model.α, model.β, model.c, model.w, model.q
347+
v_new = u(w) + β * ((1 - α) * v + α * d)
348+
d_new = jnp.sum(jnp.maximum(v, u(c) + β * d) * q)
349+
return v_new, d_new
350+
351+
@jax.jit
352+
def solve_model(model, tol=1e-5, max_iter=2000):
353+
" Iterates to convergence on the Bellman equations. "
354+
355+
def cond_fun(state):
356+
v, d, i, error = state
357+
return jnp.logical_and(error > tol, i < max_iter)
358+
359+
def body_fun(state):
360+
v, d, i, error = state
361+
v_new, d_new = update(model, v, d)
362+
error_1 = jnp.max(jnp.abs(v_new - v))
363+
error_2 = jnp.abs(d_new - d)
364+
error_new = jnp.maximum(error_1, error_2)
365+
return v_new, d_new, i + 1, error_new
366+
367+
# Initial state: (v, d, i, error)
368+
v_init = jnp.ones_like(model.w)
369+
d_init = 1.0
370+
i_init = 0
371+
error_init = tol + 1
372+
373+
init_state = (v_init, d_init, i_init, error_init)
374+
final_state = jax.lax.while_loop(cond_fun, body_fun, init_state)
375+
v_final, d_final, _, _ = final_state
376+
377+
return v_final, d_final
390378
```
391379

392380
### The Reservation Wage: First Pass
@@ -401,45 +389,40 @@ Let's compare $v$ and $h$ to see what they look like.
401389

402390
We'll use the default parameterizations found in the code above.
403391

404-
```{code-cell} python3
405-
mcm = McCallModel()
406-
v, d = solve_model(mcm)
407-
h = u(mcm.c) + mcm.β * d
392+
```{code-cell} ipython3
393+
model = Model()
394+
v, d = solve_model(model)
395+
h = u(model.c) + model.β * d
408396
409397
fig, ax = plt.subplots()
410-
411-
ax.plot(mcm.w, v, 'b-', lw=2, alpha=0.7, label='$v$')
412-
ax.plot(mcm.w, [h] * len(mcm.w),
398+
ax.plot(model.w, v, 'b-', lw=2, alpha=0.7, label='$v$')
399+
ax.plot(model.w, [h] * len(model.w),
413400
'g-', lw=2, alpha=0.7, label='$h$')
414-
ax.set_xlim(min(mcm.w), max(mcm.w))
401+
ax.set_xlim(min(model.w), max(model.w))
415402
ax.legend()
416-
417403
plt.show()
418404
```
419405

420406
The value $v$ is increasing because higher $w$ generates a higher wage flow conditional on staying employed.
421407

422408
### The Reservation Wage: Computation
423409

424-
Here's a function `compute_reservation_wage` that takes an instance of `McCallModel`
410+
Here's a function `compute_reservation_wage` that takes an instance of `Model`
425411
and returns the associated reservation wage.
426412

427-
```{code-cell} python3
428-
@jit
429-
def compute_reservation_wage(mcm):
413+
```{code-cell} ipython3
414+
@jax.jit
415+
def compute_reservation_wage(model):
430416
"""
431417
Computes the reservation wage of an instance of the McCall model
432-
by finding the smallest w such that v(w) >= h.
433-
434-
If no such w exists, then w_bar is set to np.inf.
418+
by finding the smallest w such that v(w) >= h. If no such w exists, then
419+
w_bar is set to np.inf.
435420
"""
436-
437-
v, d = solve_model(mcm)
438-
h = u(mcm.c) + mcm.β * d
439-
440-
i = np.searchsorted(v, h, side='right')
441-
w_bar = mcm.w[i]
442-
421+
422+
v, d = solve_model(model)
423+
h = u(model.c) + model.β * d
424+
i = jnp.searchsorted(v, h, side='left')
425+
w_bar = jnp.where(i >= len(model.w), jnp.inf, model.w[i])
443426
return w_bar
444427
```
445428

@@ -453,7 +436,7 @@ In each instance below, we'll show you a figure and then ask you to reproduce it
453436

454437
First, let's look at how $\bar w$ varies with unemployment compensation.
455438

456-
In the figure below, we use the default parameters in the `McCallModel` class, apart from
439+
In the figure below, we use the default parameters in the `Model` class, apart from
457440
c (which takes the values given on the horizontal axis)
458441

459442
```{figure} /_static/lecture_specific/mccall_model_with_separation/mccall_resw_c.png
@@ -503,11 +486,11 @@ Reproduce all the reservation wage figures shown above.
503486

504487
Regarding the values on the horizontal axis, use
505488

506-
```{code-cell} python3
489+
```{code-cell} ipython3
507490
grid_size = 25
508-
c_vals = np.linspace(2, 12, grid_size) # unemployment compensation
509-
beta_vals = np.linspace(0.8, 0.99, grid_size) # discount factors
510-
alpha_vals = np.linspace(0.05, 0.5, grid_size) # separation rate
491+
c_vals = jnp.linspace(2, 12, grid_size) # unemployment compensation
492+
β_vals = jnp.linspace(0.8, 0.99, grid_size) # discount factors
493+
α_vals = jnp.linspace(0.05, 0.5, grid_size) # separation rate
511494
```
512495

513496
```{exercise-end}
@@ -519,57 +502,52 @@ alpha_vals = np.linspace(0.05, 0.5, grid_size) # separation rate
519502

520503
Here's the first figure.
521504

522-
```{code-cell} python3
523-
mcm = McCallModel()
505+
```{code-cell} ipython3
506+
def compute_res_wage_given_c(c):
507+
model = Model(c=c)
508+
w_bar = compute_reservation_wage(model)
509+
return w_bar
524510
525-
w_bar_vals = np.empty_like(c_vals)
511+
w_bar_vals = jax.vmap(compute_res_wage_given_c)(c_vals)
526512
527513
fig, ax = plt.subplots()
528-
529-
for i, c in enumerate(c_vals):
530-
mcm.c = c
531-
w_bar = compute_reservation_wage(mcm)
532-
w_bar_vals[i] = w_bar
533-
534-
ax.set(xlabel='unemployment compensation',
535-
ylabel='reservation wage')
514+
ax.set(xlabel='unemployment compensation', ylabel='reservation wage')
536515
ax.plot(c_vals, w_bar_vals, label=r'$\bar w$ as a function of $c$')
537516
ax.legend()
538-
539517
plt.show()
540518
```
541519

542520
Here's the second one.
543521

544-
```{code-cell} python3
545-
fig, ax = plt.subplots()
522+
```{code-cell} ipython3
523+
def compute_res_wage_given_beta(β):
524+
model = Model(β=β)
525+
w_bar = compute_reservation_wage(model)
526+
return w_bar
546527
547-
for i, β in enumerate(beta_vals):
548-
mcm.β = β
549-
w_bar = compute_reservation_wage(mcm)
550-
w_bar_vals[i] = w_bar
528+
w_bar_vals = jax.vmap(compute_res_wage_given_beta)(β_vals)
551529
530+
fig, ax = plt.subplots()
552531
ax.set(xlabel='discount factor', ylabel='reservation wage')
553-
ax.plot(beta_vals, w_bar_vals, label=r'$\bar w$ as a function of $\beta$')
532+
ax.plot(β_vals, w_bar_vals, label=r'$\bar w$ as a function of $\beta$')
554533
ax.legend()
555-
556534
plt.show()
557535
```
558536

559537
Here's the third.
560538

561-
```{code-cell} python3
562-
fig, ax = plt.subplots()
539+
```{code-cell} ipython3
540+
def compute_res_wage_given_alpha(α):
541+
model = Model(α=α)
542+
w_bar = compute_reservation_wage(model)
543+
return w_bar
563544
564-
for i, α in enumerate(alpha_vals):
565-
mcm.α = α
566-
w_bar = compute_reservation_wage(mcm)
567-
w_bar_vals[i] = w_bar
545+
w_bar_vals = jax.vmap(compute_res_wage_given_alpha)(α_vals)
568546
547+
fig, ax = plt.subplots()
569548
ax.set(xlabel='separation rate', ylabel='reservation wage')
570-
ax.plot(alpha_vals, w_bar_vals, label=r'$\bar w$ as a function of $\alpha$')
549+
ax.plot(α_vals, w_bar_vals, label=r'$\bar w$ as a function of $\alpha$')
571550
ax.legend()
572-
573551
plt.show()
574552
```
575553

0 commit comments

Comments
 (0)