
Commit c6ae618

longye-tian, Humphrey Yang, Copilot, and mmcky authored
[mccall_fitted_vfi] JAX conversion (#567)
* Update mccall_fitted_vfi.md
* minor fixes
* Update lectures/mccall_fitted_vfi.md
  Co-authored-by: Humphrey Yang <[email protected]>
* Update lectures/mccall_fitted_vfi.md
  Co-authored-by: Humphrey Yang <[email protected]>
* Update lectures/mccall_fitted_vfi.md
  Co-authored-by: Humphrey Yang <[email protected]>
* Update lectures/mccall_fitted_vfi.md
  Co-authored-by: Humphrey Yang <[email protected]>
* Update mccall_fitted_vfi.md
* Update lectures/mccall_fitted_vfi.md
  Co-authored-by: Copilot <[email protected]>
* testing cpu run time
* Update lectures/mccall_fitted_vfi.md
* Apply suggestions from code review (pep8 compliance)

---------

Co-authored-by: Humphrey Yang <[email protected]>
Co-authored-by: Humphrey Yang <[email protected]>
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Matt McKay <[email protected]>
1 parent 307c7f1 commit c6ae618

File tree

1 file changed: +134 / -124 lines changed


lectures/mccall_fitted_vfi.md

Lines changed: 134 additions & 124 deletions
@@ -50,14 +50,18 @@ We will use the following imports:

 ```{code-cell} ipython3
 import matplotlib.pyplot as plt
-import numpy as np
-from numba import jit, float64
-from numba.experimental import jitclass
+import jax
+import jax.numpy as jnp
+from typing import NamedTuple
+import quantecon as qe
+
+# Set JAX to use CPU
+jax.config.update('jax_platform_name', 'cpu')
 ```

-## The Algorithm
+## The algorithm

-The model is the same as the McCall model with job separation we {doc}`studied before <mccall_model_with_separation>`, except that the wage offer distribution is continuous.
+The model is the same as the McCall model with job separation that we {doc}`studied before <mccall_model_with_separation>`, except that the wage offer distribution is continuous.

 We are going to start with the two Bellman equations we obtained for the model with job separation after {ref}`a simplifying transformation <ast_mcm>`.

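As a quick sanity check on the CPU pin added above, one can list the devices JAX sees after the imports run. A minimal sketch (not part of the lecture source):

```python
import jax

# Expect something like [CpuDevice(id=0)] once the platform is pinned to CPU
print(jax.devices())
```
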
@@ -82,16 +86,16 @@ v(w) = u(w) + \beta

 The unknowns here are the function $v$ and the scalar $d$.

-The difference between these and the pair of Bellman equations we previously worked on are
+The differences between these and the pair of Bellman equations we previously worked on are

-1. in {eq}`bell1mcmc`, what used to be a sum over a finite number of wage values is an integral over an infinite set.
+1. In {eq}`bell1mcmc`, what used to be a sum over a finite number of wage values is an integral over an infinite set.
 1. The function $v$ in {eq}`bell2mcmc` is defined over all $w \in \mathbb R_+$.

 The function $q$ in {eq}`bell1mcmc` is the density of the wage offer distribution.

 Its support is taken as equal to $\mathbb R_+$.

-### Value Function Iteration
+### Value function iteration

 In theory, we should now proceed as follows:

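For readers skimming the hunk above, the two equations being referenced (labels `bell1mcmc` and `bell2mcmc` in the lecture) can be reconstructed from the `update` code further down; up to notation they are

$$
d = \int \max \left\{ v(w'), \, u(c) + \beta d \right\} q(w') \, dw'
$$

$$
v(w) = u(w) + \beta \left[ (1 - \alpha) v(w) + \alpha d \right]
$$

where $c$ is unemployment compensation, $\alpha$ is the separation rate and $\beta$ is the discount factor.
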
@@ -106,12 +110,12 @@ The iterates of the value function can neither be calculated exactly nor stored

 To see the issue, consider {eq}`bell2mcmc`.

-Even if $v$ is a known function, the only way to store its update $v'$
+Even if $v$ is a known function, the only way to store its update $v'$
 is to record its value $v'(w)$ for every $w \in \mathbb R_+$.

 Clearly, this is impossible.

-### Fitted Value Function Iteration
+### Fitted value function iteration

 What we will do instead is use **fitted value function iteration**.

@@ -141,25 +145,25 @@ One good choice from both respects is continuous piecewise linear interpolation.

 This method

-1. combines well with value function iteration (see., e.g.,
+1. combines well with value function iteration (see, e.g.,
    {cite}`gordon1995stable` or {cite}`stachurski2008continuous`) and
 1. preserves useful shape properties such as monotonicity and concavity/convexity.

-Linear interpolation will be implemented using [numpy.interp](https://numpy.org/doc/stable/reference/generated/numpy.interp.html).
+Linear interpolation will be implemented using JAX's interpolation function `jnp.interp`.

 The next figure illustrates piecewise linear interpolation of an arbitrary
 function on grid points $0, 0.2, 0.4, 0.6, 0.8, 1$.

 ```{code-cell} python3
 def f(x):
-    y1 = 2 * np.cos(6 * x) + np.sin(14 * x)
+    y1 = 2 * jnp.cos(6 * x) + jnp.sin(14 * x)
     return y1 + 2.5

-c_grid = np.linspace(0, 1, 6)
-f_grid = np.linspace(0, 1, 150)
+c_grid = jnp.linspace(0, 1, 6)
+f_grid = jnp.linspace(0, 1, 150)

 def Af(x):
-    return np.interp(x, c_grid, f(c_grid))
+    return jnp.interp(x, c_grid, f(c_grid))

 fig, ax = plt.subplots()

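Since `jnp.interp` does the heavy lifting here, a tiny standalone check of its semantics may help (the values below are hypothetical, not from the lecture): it interpolates piecewise linearly through the supplied points and clamps queries outside the grid to the endpoint values, mirroring `numpy.interp`.

```python
import jax.numpy as jnp

xp = jnp.array([0.0, 0.5, 1.0])    # grid points
fp = jnp.array([0.0, 1.0, 0.0])    # function values on the grid
x = jnp.array([0.25, 0.75, 2.0])   # query points (the last one is off-grid)
print(jnp.interp(x, xp, fp))       # approx. [0.5, 0.5, 0.0]
```
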
@@ -175,123 +179,128 @@ plt.show()

 ## Implementation

-The first step is to build a jitted class for the McCall model with separation and
-a continuous wage offer distribution.
+The first step is to build a JAX-compatible structure for the McCall model with separation and a continuous wage offer distribution.

 We will take the utility function to be the log function for this application, with $u(c) = \ln c$.

 We will adopt the lognormal distribution for wages, with $w = \exp(\mu + \sigma z)$
 when $z$ is standard normal and $\mu, \sigma$ are parameters.

 ```{code-cell} python3
-@jit
 def lognormal_draws(n=1000, μ=2.5, σ=0.5, seed=1234):
-    np.random.seed(seed)
-    z = np.random.randn(n)
-    w_draws = np.exp(μ + σ * z)
+    key = jax.random.PRNGKey(seed)
+    z = jax.random.normal(key, (n,))
+    w_draws = jnp.exp(μ + σ * z)
     return w_draws
 ```

-Here's our class.
+Here's our model structure using a NamedTuple.

 ```{code-cell} python3
-mccall_data_continuous = [
-    ('c', float64),          # unemployment compensation
-    ('α', float64),          # job separation rate
-    ('β', float64),          # discount factor
-    ('w_grid', float64[:]),  # grid of points for fitted VFI
-    ('w_draws', float64[:])  # draws of wages for Monte Carlo
-]
-
-@jitclass(mccall_data_continuous)
-class McCallModelContinuous:
-
-    def __init__(self,
-                 c=1,
-                 α=0.1,
-                 β=0.96,
-                 grid_min=1e-10,
-                 grid_max=5,
-                 grid_size=100,
-                 w_draws=lognormal_draws()):
-
-        self.c, self.α, self.β = c, α, β
-
-        self.w_grid = np.linspace(grid_min, grid_max, grid_size)
-        self.w_draws = w_draws
-
-    def update(self, v, d):
-
-        # Simplify names
-        c, α, β = self.c, self.α, self.β
-        w = self.w_grid
-        u = lambda x: np.log(x)
-
-        # Interpolate array represented value function
-        vf = lambda x: np.interp(x, w, v)
-
-        # Update d using Monte Carlo to evaluate integral
-        d_new = np.mean(np.maximum(vf(self.w_draws), u(c) + β * d))
-
-        # Update v
-        v_new = u(w) + β * ((1 - α) * v + α * d)
-
-        return v_new, d_new
+class McCallModelContinuous(NamedTuple):
+    c: float              # unemployment compensation
+    α: float              # job separation rate
+    β: float              # discount factor
+    w_grid: jnp.ndarray   # grid of points for fitted VFI
+    w_draws: jnp.ndarray  # draws of wages for Monte Carlo
+
+def create_mccall_model(c=1,
+                        α=0.1,
+                        β=0.96,
+                        grid_min=1e-10,
+                        grid_max=5,
+                        grid_size=100,
+                        μ=2.5,
+                        σ=0.5,
+                        mc_size=1000,
+                        seed=1234,
+                        w_draws=None):
+    """Factory function to create a McCall model instance."""
+    if w_draws is None:
+
+        # Generate wage draws if not provided
+        w_draws = lognormal_draws(n=mc_size, μ=μ, σ=σ, seed=seed)
+
+    w_grid = jnp.linspace(grid_min, grid_max, grid_size)
+    return McCallModelContinuous(c=c, α=α, β=β, w_grid=w_grid, w_draws=w_draws)
+
+@jax.jit
+def update(model, v, d):
+    """Update value function and continuation value."""
+
+    # Unpack model parameters
+    c, α, β, w_grid, w_draws = model
+    u = jnp.log
+
+    # Interpolate array represented value function
+    vf = lambda x: jnp.interp(x, w_grid, v)
+
+    # Update d using Monte Carlo to evaluate integral
+    d_new = jnp.mean(jnp.maximum(vf(w_draws), u(c) + β * d))
+
+    # Update v
+    v_new = u(w_grid) + β * ((1 - α) * v + α * d)
+
+    return v_new, d_new
 ```

 We then return the current iterate as an approximate solution.

 ```{code-cell} python3
-@jit
-def solve_model(mcm, tol=1e-5, max_iter=2000):
+@jax.jit
+def solve_model(model, tol=1e-5, max_iter=2000):
     """
     Iterates to convergence on the Bellman equations

-    * mcm is an instance of McCallModel
+    * model is an instance of McCallModelContinuous
     """
-
-    v = np.ones_like(mcm.w_grid)   # Initial guess of v
-    d = 1                          # Initial guess of d
-    i = 0
-    error = tol + 1
-
-    while error > tol and i < max_iter:
-        v_new, d_new = mcm.update(v, d)
-        error_1 = np.max(np.abs(v_new - v))
-        error_2 = np.abs(d_new - d)
-        error = max(error_1, error_2)
-        v = v_new
-        d = d_new
-        i += 1
-
-    return v, d
+
+    # Initial guesses
+    v = jnp.ones_like(model.w_grid)
+    d = 1.0
+
+    def body_fun(state):
+        v, d, i, error = state
+        v_new, d_new = update(model, v, d)
+        error_1 = jnp.max(jnp.abs(v_new - v))
+        error_2 = jnp.abs(d_new - d)
+        error = jnp.maximum(error_1, error_2)
+        return v_new, d_new, i + 1, error
+
+    def cond_fun(state):
+        _, _, i, error = state
+        return (error > tol) & (i < max_iter)
+
+    initial_state = (v, d, 0, tol + 1)
+    v_final, d_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state)
+
+    return v_final, d_final
 ```

 Here's a function `compute_reservation_wage` that takes an instance of `McCallModelContinuous`
 and returns the associated reservation wage.

-If $v(w) < h$ for all $w$, then the function returns np.inf
+If $v(w) < h$ for all $w$, then the function returns `jnp.inf`

 ```{code-cell} python3
-@jit
-def compute_reservation_wage(mcm):
+@jax.jit
+def compute_reservation_wage(model):
     """
     Computes the reservation wage of an instance of the McCall model
     by finding the smallest w such that v(w) >= h.

-    If no such w exists, then w_bar is set to np.inf.
+    If no such w exists, then w_bar is set to inf.
     """
-    u = lambda x: np.log(x)
-
-    v, d = solve_model(mcm)
-    h = u(mcm.c) + mcm.β * d
-
-    w_bar = np.inf
-    for i, wage in enumerate(mcm.w_grid):
-        if v[i] > h:
-            w_bar = wage
-            break
-
+    c, α, β, w_grid, w_draws = model
+    u = jnp.log
+
+    v, d = solve_model(model)
+    h = u(c) + β * d
+
+    # Find the first wage where v(w) >= h
+    indices = jnp.where(v >= h, size=1, fill_value=-1)
+    w_bar = jnp.where(indices[0] >= 0, w_grid[indices[0]], jnp.inf)
+
     return w_bar
 ```

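The refactored pieces above compose as follows; this is only a minimal usage sketch (not part of the diff), with timing via quantecon's `qe.tic`/`qe.toc`, which seems consistent with the `import quantecon as qe` added at the top and the "testing cpu run time" commit note:

```python
model = create_mccall_model()          # default parameters, lognormal wage draws

qe.tic()
w_bar = compute_reservation_wage(model)
w_bar.block_until_ready()              # wait for the asynchronous JAX computation
qe.toc()

print(f"Reservation wage: {float(w_bar):.4f}")
```

A second call with the same model should be noticeably faster, since `solve_model` and `compute_reservation_wage` are jitted and the compiled code is reused.
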
@@ -305,7 +314,7 @@ The exercises ask you to explore the solution and how it changes with parameters
 Use the code above to explore what happens to the reservation wage when the wage parameter $\mu$
 changes.

-Use the default parameters and $\mu$ in `mu_vals = np.linspace(0.0, 2.0, 15)`.
+Use the default parameters and $\mu$ in `μ_vals = jnp.linspace(0.0, 2.0, 15)`.

 Is the impact on the reservation wage as you expected?
 ```
@@ -317,21 +326,18 @@ Is the impact on the reservation wage as you expected?
 Here is one solution

 ```{code-cell} python3
-mcm = McCallModelContinuous()
-mu_vals = np.linspace(0.0, 2.0, 15)
-w_bar_vals = np.empty_like(mu_vals)
-
-fig, ax = plt.subplots()
+def compute_res_wage_given_μ(μ):
+    model = create_mccall_model(μ=μ)
+    w_bar = compute_reservation_wage(model)
+    return w_bar

-for i, m in enumerate(mu_vals):
-    mcm.w_draws = lognormal_draws(μ=m)
-    w_bar = compute_reservation_wage(mcm)
-    w_bar_vals[i] = w_bar
+μ_vals = jnp.linspace(0.0, 2.0, 15)
+w_bar_vals = jax.vmap(compute_res_wage_given_μ)(μ_vals)

+fig, ax = plt.subplots()
 ax.set(xlabel='mean', ylabel='reservation wage')
-ax.plot(mu_vals, w_bar_vals, label=r'$\bar w$ as a function of $\mu$')
+ax.plot(μ_vals, w_bar_vals, label=r'$\bar w$ as a function of $\mu$')
 ax.legend()
-
 plt.show()
 ```

@@ -354,11 +360,11 @@ support.

 (This is a form of *mean-preserving spread*.)

-Use `s_vals = np.linspace(1.0, 2.0, 15)` and `m = 2.0`.
+Use `s_vals = jnp.linspace(1.0, 2.0, 15)` and `m = 2.0`.

 State how you expect the reservation wage to vary with $s$.

-Now compute it. Is this as you expected?
+Now compute it - is this as you expected?
 ```

 ```{solution-start} mfv_ex2
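A side note on the exercise statement above: for $w$ uniform on $[m - s, m + s]$ the mean stays fixed at $m$ while the variance grows with $s$, which is why varying $s$ is a mean-preserving spread:

$$
\mathbb{E}[w] = m,
\qquad
\operatorname{Var}(w) = \frac{(2s)^2}{12} = \frac{s^2}{3}.
$$
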
@@ -368,23 +374,27 @@ Now compute it. Is this as you expected?
 Here is one solution

 ```{code-cell} python3
-mcm = McCallModelContinuous()
-s_vals = np.linspace(1.0, 2.0, 15)
-m = 2.0
-w_bar_vals = np.empty_like(s_vals)
+def compute_res_wage_given_s(s, m=2.0, seed=1234):
+    a, b = m - s, m + s
+    key = jax.random.PRNGKey(seed)
+    uniform_draws = jax.random.uniform(key, shape=(10_000,), minval=a, maxval=b)

-fig, ax = plt.subplots()
+    # Create model with default parameters but replace wage draws
+    model = create_mccall_model(w_draws=uniform_draws)
+    w_bar = compute_reservation_wage(model)
+    return w_bar

-for i, s in enumerate(s_vals):
-    a, b = m - s, m + s
-    mcm.w_draws = np.random.uniform(low=a, high=b, size=10_000)
-    w_bar = compute_reservation_wage(mcm)
-    w_bar_vals[i] = w_bar
+s_vals = jnp.linspace(1.0, 2.0, 15)

+# Use vmap with different seeds for each s value
+seeds = jnp.arange(len(s_vals))
+compute_vectorized = jax.vmap(compute_res_wage_given_s, in_axes=(0, None, 0))
+w_bar_vals = compute_vectorized(s_vals, 2.0, seeds)
+
+fig, ax = plt.subplots()
 ax.set(xlabel='volatility', ylabel='reservation wage')
 ax.plot(s_vals, w_bar_vals, label=r'$\bar w$ as a function of wage volatility')
 ax.legend()
-
 plt.show()
 ```

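On the `in_axes=(0, None, 0)` call in the solution above: it maps `jax.vmap` over the first and third arguments while broadcasting the scalar `m` unchanged across the batch. A minimal standalone illustration, where `g` is just a stand-in function and not part of the lecture:

```python
import jax
import jax.numpy as jnp

def g(s, m, seed):
    # stand-in for compute_res_wage_given_s
    return s + m + seed

batched = jax.vmap(g, in_axes=(0, None, 0))
print(batched(jnp.array([1.0, 2.0]), 10.0, jnp.array([0.0, 1.0])))  # [11. 13.]
```
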