@@ -3,10 +3,12 @@ jupytext:
33 text_representation :
44 extension : .md
55 format_name : myst
6+ format_version : 0.13
7+ jupytext_version : 1.17.1
68kernelspec :
7- display_name : Python 3
8- language : python
99 name : python3
10+ display_name : Python 3 (ipykernel)
11+ language : python
1012---
1113
1214(mccall_with_sep)=
@@ -29,10 +31,9 @@ kernelspec:
2931
3032In addition to what's in Anaconda, this lecture will need the following libraries:
3133
32- ``` {code-cell} ipython
33- ---
34- tags: [hide-output]
35- ---
34+ ``` {code-cell} ipython3
35+ :tags: [hide-output]
36+
3637!pip install quantecon
3738```
3839
@@ -54,12 +55,12 @@ worker preferences slightly more sophisticated.
5455
5556We'll need the following imports
5657
57- ``` {code-cell} ipython
58+ ``` {code-cell} ipython3
5859import matplotlib.pyplot as plt
59- plt.rcParams["figure.figsize"] = (11, 5) #set default figure size
6060import numpy as np
61- from numba import jit, float64
62- from numba.experimental import jitclass
61+ import jax
62+ import jax.numpy as jnp
63+ from typing import NamedTuple
6364from quantecon.distributions import BetaBinomial
6465```
6566
@@ -306,87 +307,74 @@ This helps to tidy up the code and provides an object that's easy to pass to fun
306307
307308The default utility function is a CRRA utility function
308309
309- ``` {code-cell} python3
310- @jit
310+ ``` {code-cell} ipython3
311+ @jax. jit
311312def u(c, σ=2.0):
312313 return (c**(1 - σ) - 1) / (1 - σ)
313314```
314315
315316Also, here's a default wage distribution, based around the BetaBinomial
316317distribution:
317318
318- ``` {code-cell} python3
319+ ``` {code-cell} ipython3
319320n = 60 # n possible outcomes for w
320- w_default = np .linspace(10, 20, n) # wages between 10 and 20
321+ w_default = jnp .linspace(10, 20, n) # wages between 10 and 20
321322a, b = 600, 400 # shape parameters
322- dist = BetaBinomial(n-1, a, b)
323- q_default = dist.pdf()
323+ dist = BetaBinomial(n-1, a, b) # distribution
324+ q_default = jnp.array( dist.pdf()) # probabilities as a JAX array
324325```
325326
326327Here's our jitted class for the McCall model with separation.
327328
328- ``` {code-cell} python3
329- mccall_data = [
330- ('α', float64), # job separation rate
331- ('β', float64), # discount factor
332- ('c', float64), # unemployment compensation
333- ('w', float64[:]), # list of wage values
334- ('q', float64[:]) # pmf of random variable w
335- ]
336-
337- @jitclass(mccall_data)
338- class McCallModel:
339- """
340- Stores the parameters and functions associated with a given model.
341- """
342-
343- def __init__(self, α=0.2, β=0.98, c=6.0, w=w_default, q=q_default):
344-
345- self.α, self.β, self.c, self.w, self.q = α, β, c, w, q
346-
347-
348- def update(self, v, d):
349-
350- α, β, c, w, q = self.α, self.β, self.c, self.w, self.q
351-
352- v_new = np.empty_like(v)
353-
354- for i in range(len(w)):
355- v_new[i] = u(w[i]) + β * ((1 - α) * v[i] + α * d)
356-
357- d_new = np.sum(np.maximum(v, u(c) + β * d) * q)
358-
359- return v_new, d_new
329+ ``` {code-cell} ipython3
330+ class Model(NamedTuple):
331+ α: float = 0.2 # job separation rate
332+ β: float = 0.98 # discount factor
333+ c: float = 6.0 # unemployment compensation
334+ w: jnp.ndarray = w_default # wage outcome space
335+ q: jnp.ndarray = q_default # probabilities over wage offers
360336```
361337
362338Now we iterate until successive realizations are closer together than some small tolerance level.
363339
364340We then return the current iterate as an approximate solution.
365341
366- ``` {code-cell} python3
367- @jit
368- def solve_model(mcm, tol=1e-5, max_iter=2000):
369- """
370- Iterates to convergence on the Bellman equations
371-
372- * mcm is an instance of McCallModel
373- """
374-
375- v = np.ones_like(mcm.w) # Initial guess of v
376- d = 1 # Initial guess of d
377- i = 0
378- error = tol + 1
379-
380- while error > tol and i < max_iter:
381- v_new, d_new = mcm.update(v, d)
382- error_1 = np.max(np.abs(v_new - v))
383- error_2 = np.abs(d_new - d)
384- error = max(error_1, error_2)
385- v = v_new
386- d = d_new
387- i += 1
388-
389- return v, d
342+ ``` {code-cell} ipython3
343+ @jax.jit
344+ def update(model, v, d):
345+ " One update on the Bellman equations. "
346+ α, β, c, w, q = model.α, model.β, model.c, model.w, model.q
347+ v_new = u(w) + β * ((1 - α) * v + α * d)
348+ d_new = jnp.sum(jnp.maximum(v, u(c) + β * d) * q)
349+ return v_new, d_new
350+
351+ @jax.jit
352+ def solve_model(model, tol=1e-5, max_iter=2000):
353+ " Iterates to convergence on the Bellman equations. "
354+
355+ def cond_fun(state):
356+ v, d, i, error = state
357+ return jnp.logical_and(error > tol, i < max_iter)
358+
359+ def body_fun(state):
360+ v, d, i, error = state
361+ v_new, d_new = update(model, v, d)
362+ error_1 = jnp.max(jnp.abs(v_new - v))
363+ error_2 = jnp.abs(d_new - d)
364+ error_new = jnp.maximum(error_1, error_2)
365+ return v_new, d_new, i + 1, error_new
366+
367+ # Initial state: (v, d, i, error)
368+ v_init = jnp.ones_like(model.w)
369+ d_init = 1.0
370+ i_init = 0
371+ error_init = tol + 1
372+
373+ init_state = (v_init, d_init, i_init, error_init)
374+ final_state = jax.lax.while_loop(cond_fun, body_fun, init_state)
375+ v_final, d_final, _, _ = final_state
376+
377+ return v_final, d_final
390378```
391379
392380### The Reservation Wage: First Pass
@@ -401,45 +389,40 @@ Let's compare $v$ and $h$ to see what they look like.
401389
402390We'll use the default parameterizations found in the code above.
403391
404- ``` {code-cell} python3
405- mcm = McCallModel ()
406- v, d = solve_model(mcm )
407- h = u(mcm .c) + mcm .β * d
392+ ``` {code-cell} ipython3
393+ model = Model ()
394+ v, d = solve_model(model )
395+ h = u(model .c) + model .β * d
408396
409397fig, ax = plt.subplots()
410-
411- ax.plot(mcm.w, v, 'b-', lw=2, alpha=0.7, label='$v$')
412- ax.plot(mcm.w, [h] * len(mcm.w),
398+ ax.plot(model.w, v, 'b-', lw=2, alpha=0.7, label='$v$')
399+ ax.plot(model.w, [h] * len(model.w),
413400 'g-', lw=2, alpha=0.7, label='$h$')
414- ax.set_xlim(min(mcm .w), max(mcm .w))
401+ ax.set_xlim(min(model .w), max(model .w))
415402ax.legend()
416-
417403plt.show()
418404```
419405
420406The value $v$ is increasing because higher $w$ generates a higher wage flow conditional on staying employed.
421407
422408### The Reservation Wage: Computation
423409
424- Here's a function ` compute_reservation_wage ` that takes an instance of ` McCallModel `
410+ Here's a function ` compute_reservation_wage ` that takes an instance of ` Model `
425411and returns the associated reservation wage.
426412
427- ``` {code-cell} python3
428- @jit
429- def compute_reservation_wage(mcm ):
413+ ``` {code-cell} ipython3
414+ @jax. jit
415+ def compute_reservation_wage(model ):
430416 """
431417 Computes the reservation wage of an instance of the McCall model
432- by finding the smallest w such that v(w) >= h.
433-
434- If no such w exists, then w_bar is set to np.inf.
418+ by finding the smallest w such that v(w) >= h. If no such w exists, then
419+ w_bar is set to np.inf.
435420 """
436-
437- v, d = solve_model(mcm)
438- h = u(mcm.c) + mcm.β * d
439-
440- i = np.searchsorted(v, h, side='right')
441- w_bar = mcm.w[i]
442-
421+
422+ v, d = solve_model(model)
423+ h = u(model.c) + model.β * d
424+ i = jnp.searchsorted(v, h, side='left')
425+ w_bar = jnp.where(i >= len(model.w), jnp.inf, model.w[i])
443426 return w_bar
444427```
445428
@@ -453,7 +436,7 @@ In each instance below, we'll show you a figure and then ask you to reproduce it
453436
454437First, let's look at how $\bar w$ varies with unemployment compensation.
455438
456- In the figure below, we use the default parameters in the ` McCallModel ` class, apart from
439+ In the figure below, we use the default parameters in the ` Model ` class, apart from
457440c (which takes the values given on the horizontal axis)
458441
459442``` {figure} /_static/lecture_specific/mccall_model_with_separation/mccall_resw_c.png
@@ -503,11 +486,11 @@ Reproduce all the reservation wage figures shown above.
503486
504487Regarding the values on the horizontal axis, use
505488
506- ``` {code-cell} python3
489+ ``` {code-cell} ipython3
507490grid_size = 25
508- c_vals = np .linspace(2, 12, grid_size) # unemployment compensation
509- beta_vals = np .linspace(0.8, 0.99, grid_size) # discount factors
510- alpha_vals = np .linspace(0.05, 0.5, grid_size) # separation rate
491+ c_vals = jnp .linspace(2, 12, grid_size) # unemployment compensation
492+ β_vals = jnp .linspace(0.8, 0.99, grid_size) # discount factors
493+ α_vals = jnp .linspace(0.05, 0.5, grid_size) # separation rate
511494```
512495
513496``` {exercise-end}
@@ -519,57 +502,52 @@ alpha_vals = np.linspace(0.05, 0.5, grid_size) # separation rate
519502
520503Here's the first figure.
521504
522- ``` {code-cell} python3
523- mcm = McCallModel()
505+ ``` {code-cell} ipython3
506+ def compute_res_wage_given_c(c):
507+ model = Model(c=c)
508+ w_bar = compute_reservation_wage(model)
509+ return w_bar
524510
525- w_bar_vals = np.empty_like (c_vals)
511+ w_bar_vals = jax.vmap(compute_res_wage_given_c) (c_vals)
526512
527513fig, ax = plt.subplots()
528-
529- for i, c in enumerate(c_vals):
530- mcm.c = c
531- w_bar = compute_reservation_wage(mcm)
532- w_bar_vals[i] = w_bar
533-
534- ax.set(xlabel='unemployment compensation',
535- ylabel='reservation wage')
514+ ax.set(xlabel='unemployment compensation', ylabel='reservation wage')
536515ax.plot(c_vals, w_bar_vals, label=r'$\bar w$ as a function of $c$')
537516ax.legend()
538-
539517plt.show()
540518```
541519
542520Here's the second one.
543521
544- ``` {code-cell} python3
545- fig, ax = plt.subplots()
522+ ``` {code-cell} ipython3
523+ def compute_res_wage_given_beta(β):
524+ model = Model(β=β)
525+ w_bar = compute_reservation_wage(model)
526+ return w_bar
546527
547- for i, β in enumerate(beta_vals):
548- mcm.β = β
549- w_bar = compute_reservation_wage(mcm)
550- w_bar_vals[i] = w_bar
528+ w_bar_vals = jax.vmap(compute_res_wage_given_beta)(β_vals)
551529
530+ fig, ax = plt.subplots()
552531ax.set(xlabel='discount factor', ylabel='reservation wage')
553- ax.plot(beta_vals , w_bar_vals, label=r'$\bar w$ as a function of $\beta$')
532+ ax.plot(β_vals , w_bar_vals, label=r'$\bar w$ as a function of $\beta$')
554533ax.legend()
555-
556534plt.show()
557535```
558536
559537Here's the third.
560538
561- ``` {code-cell} python3
562- fig, ax = plt.subplots()
539+ ``` {code-cell} ipython3
540+ def compute_res_wage_given_alpha(α):
541+ model = Model(α=α)
542+ w_bar = compute_reservation_wage(model)
543+ return w_bar
563544
564- for i, α in enumerate(alpha_vals):
565- mcm.α = α
566- w_bar = compute_reservation_wage(mcm)
567- w_bar_vals[i] = w_bar
545+ w_bar_vals = jax.vmap(compute_res_wage_given_alpha)(α_vals)
568546
547+ fig, ax = plt.subplots()
569548ax.set(xlabel='separation rate', ylabel='reservation wage')
570- ax.plot(alpha_vals , w_bar_vals, label=r'$\bar w$ as a function of $\alpha$')
549+ ax.plot(α_vals , w_bar_vals, label=r'$\bar w$ as a function of $\alpha$')
571550ax.legend()
572-
573551plt.show()
574552```
575553
0 commit comments