
Commit 6573f2a

jstac and claude committed
Fix JAX compatibility issues in Job Search III lecture
Updated mccall_model_with_sep_markov.md to fix several JAX-related issues:

- Refactored vfi() to return only v_final instead of tuple, making it more consistent with VFI pattern
- Removed separate successive_approx() function and integrated iteration logic directly into vfi()
- Fixed JAX decorators: changed @jit to @jax.jit and @partial(jit, ...) to @partial(jax.jit, ...)
- Rewrote get_reservation_wage() to use jnp.argmax() instead of jnp.where() to avoid JAX concretization errors in JIT compilation
- Updated all vfi() call sites to explicitly compute policy with get_greedy(v_star, model)
- Removed @jit decorators from T() and get_greedy() functions (not needed)

Also improved wording in mccall_model_with_separation.md for clarity.

Tested: Converted to Python and ran successfully without errors.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent b61b30e commit 6573f2a

2 files changed: 48 additions, 71 deletions

lectures/mccall_model_with_sep_markov.md

Lines changed: 47 additions & 68 deletions
@@ -4,11 +4,11 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.17.1
+    jupytext_version: 1.17.2
 kernelspec:
-  name: python3
   display_name: Python 3 (ipykernel)
   language: python
+  name: python3
 ---
 
 (mccall_with_sep_markov)=
@@ -58,7 +58,7 @@ We use the following imports:
 from quantecon.markov import tauchen
 import jax.numpy as jnp
 import jax
-from jax import jit, lax
+from jax import lax
 from typing import NamedTuple
 import matplotlib.pyplot as plt
 from functools import partial
@@ -138,48 +138,11 @@ The optimal policy turns out to be a reservation wage strategy: accept all wages
 
 ## Code
 
-
-First, we implement the successive approximation algorithm.
-
-This algorithm takes an operator $T$ and an initial condition and iterates until
-convergence.
-
-We will use it for value function iteration.
-
-```{code-cell} ipython3
-@partial(jit, static_argnums=(0,))
-def successive_approx(
-        T,                        # Operator (callable) - marked as static
-        x_0,                      # Initial condition
-        tolerance: float = 1e-6,  # Error tolerance
-        max_iter: int = 100_000,  # Max iteration bound
-    ):
-    """Computes the approximate fixed point of T via successive
-    approximation using lax.while_loop."""
-
-    def cond_fn(carry):
-        x, error, k = carry
-        return (error > tolerance) & (k <= max_iter)
-
-    def body_fn(carry):
-        x, error, k = carry
-        x_new = T(x)
-        error = jnp.max(jnp.abs(x_new - x))
-        return (x_new, error, k + 1)
-
-    initial_carry = (x_0, tolerance + 1, 1)
-    x_final, _, _ = lax.while_loop(cond_fn, body_fn, initial_carry)
-
-    return x_final
-```
-
-
-Next let's set up a `Model` class to store information needed to solve the model.
+Let's set up a `Model` class to store information needed to solve the model.
 
 We include `P_cumsum`, the row-wise cumulative sum of the transition matrix, to
 optimize the simulation -- the details are explained below.
 
-
 ```{code-cell} ipython3
 class Model(NamedTuple):
     n: int
@@ -215,7 +178,6 @@ def create_js_with_sep_model(
 Here's the Bellman operator for the unemployed worker's value function:
 
 ```{code-cell} ipython3
-@jit
 def T(v: jnp.ndarray, model: Model) -> jnp.ndarray:
     """The Bellman operator for the value of being unemployed."""
     n, w_vals, P, P_cumsum, β, c, α = model
@@ -229,7 +191,6 @@ The next function computes the optimal policy under the assumption that $v$ is
 the value function:
 
 ```{code-cell} ipython3
-@jit
 def get_greedy(v: jnp.ndarray, model: Model) -> jnp.ndarray:
     """Get a v-greedy policy."""
     n, w_vals, P, P_cumsum, β, c, α = model
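A note on the dropped decorators: `T` and `get_greedy` no longer carry their own `@jit`, presumably because compilation is not needed at this level -- a plain function called from inside a `jax.jit`-wrapped function (as `T` now is inside `vfi`) is traced and compiled as part of the outer function anyway. A minimal sketch of that behaviour, using illustrative function names of my own:

```python
import jax
import jax.numpy as jnp

def square(x):                 # no decorator needed here
    return x * x

@jax.jit
def sum_of_squares(x):         # `square` is traced as part of this compilation
    return jnp.sum(square(x))

print(sum_of_squares(jnp.arange(3.0)))   # 5.0
```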
@@ -247,14 +208,34 @@ The second routine requires a policy function, which we will typically obtain by
 applying the `vfi` function.
 
 ```{code-cell} ipython3
-def vfi(model: Model):
-    """Solve by VFI."""
+@jax.jit
+def vfi(
+        model: Model,
+        tolerance: float = 1e-6,   # Error tolerance
+        max_iter: int = 100_000,   # Max iteration bound
+    ):
+
     v_init = jnp.zeros(model.w_vals.shape)
-    v_star = successive_approx(lambda v: T(v, model), v_init)
-    σ_star = get_greedy(v_star, model)
-    return v_star, σ_star
+
+    def cond(loop_state):
+        v, error, i = loop_state
+        return (error > tolerance) & (i <= max_iter)
+
+    def update(loop_state):
+        v, error, i = loop_state
+        v_new = T(v, model)
+        error = jnp.max(jnp.abs(v_new - v))
+        new_loop_state = v_new, error, i + 1
+        return new_loop_state
+
+    initial_state = (v_init, tolerance + 1, 1)
+    final_loop_state = lax.while_loop(cond, update, initial_state)
+    v_final, error, i = final_loop_state
 
+    return v_final
 
+
+@jax.jit
 def get_reservation_wage(σ: jnp.ndarray, model: Model) -> float:
     """
     Calculate the reservation wage from a given policy.
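The iteration logic that previously lived in `successive_approx` now sits directly inside `vfi` as a `lax.while_loop`. For readers new to the pattern, here is a minimal self-contained sketch of the same fixed-point loop applied to a toy contraction (the operator and names below are illustrative, not part of the lecture):

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def fixed_point(x0):
    # Toy contraction: T(x) = 0.5 * x + 1.0, whose fixed point is 2.0
    tol, max_iter = 1e-8, 1_000

    def T(x):
        return 0.5 * x + 1.0

    def cond(state):
        x, error, i = state
        return (error > tol) & (i <= max_iter)

    def update(state):
        x, error, i = state
        x_new = T(x)
        return x_new, jnp.abs(x_new - x), i + 1

    init = (x0, jnp.asarray(tol + 1.0), 1)
    x_final, _, _ = lax.while_loop(cond, update, init)
    return x_final

print(fixed_point(jnp.asarray(0.0)))   # ≈ 2.0
```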
@@ -268,25 +249,24 @@ def get_reservation_wage(σ: jnp.ndarray, model: Model) -> float:
     """
     n, w_vals, P, P_cumsum, β, c, α = model
 
-    # Find all wage indices where policy indicates acceptance
-    accept_indices = jnp.where(σ == 1)[0]
-
-    if len(accept_indices) == 0:
-        return jnp.inf  # Agent never accepts any wage
+    # Find the first index where policy indicates acceptance
+    # σ is a boolean array, argmax returns the first True value
+    first_accept_idx = jnp.argmax(σ)
 
-    # Return the lowest wage that is accepted
-    return w_vals[accept_indices[0]]
+    # If no acceptance (all False), return infinity
+    # Otherwise return the wage at the first acceptance index
+    return jnp.where(jnp.any(σ), w_vals[first_accept_idx], jnp.inf)
 ```
 
-
 ## Computing the Solution
 
 Let's solve the model:
 
 ```{code-cell} ipython3
 model = create_js_with_sep_model()
 n, w_vals, P, P_cumsum, β, c, α = model
-v_star, σ_star = vfi(model)
+v_star = vfi(model)
+σ_star = get_greedy(v_star, model)
 ```
 
 Next we compute some related quantities, including the reservation wage.
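The switch to `jnp.argmax` is what the commit message calls avoiding "JAX concretization errors": under `jax.jit`, `jnp.where(σ == 1)[0]` would produce an array whose length depends on the data, and `len(...) == 0` would force a traced value into a concrete Python bool -- both fail inside a jitted function, whereas the new combination keeps every shape static. A small sketch of the pattern in isolation (the wage grid and policies below are made up):

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_accepted_wage(σ, w_vals):
    # argmax returns the index of the first maximal entry,
    # i.e. the first accepted wage when σ has any nonzero entry
    first_idx = jnp.argmax(σ)
    # jnp.where keeps shapes static, so this traces cleanly under jit
    return jnp.where(jnp.any(σ), w_vals[first_idx], jnp.inf)

w_vals = jnp.linspace(10.0, 20.0, 5)
print(first_accepted_wage(jnp.array([0, 0, 1, 1, 1]), w_vals))   # 15.0
print(first_accepted_wage(jnp.zeros(5, dtype=int), w_vals))      # inf
```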
@@ -312,19 +292,18 @@ ax.set_xlabel(r"$w$")
 plt.show()
 ```
 
-
 ## Sensitivity Analysis
 
 Let's examine how reservation wages change with the separation rate.
 
-
 ```{code-cell} ipython3
 α_vals: jnp.ndarray = jnp.linspace(0.0, 1.0, 10)
 
 w_star_vec = jnp.empty_like(α_vals)
 for (i_α, α) in enumerate(α_vals):
     model = create_js_with_sep_model(α=α)
-    v_star, σ_star = vfi(model)
+    v_star = vfi(model)
+    σ_star = get_greedy(v_star, model)
     w_star = get_reservation_wage(σ_star, model)
     w_star_vec = w_star_vec.at[i_α].set(w_star)
 
@@ -356,9 +335,8 @@ This is implemented via `jnp.searchsorted` on the precomputed cumulative sum
 
 The function `update_agent` advances the agent's state by one period.
 
-
 ```{code-cell} ipython3
-@jit
+@jax.jit
 def update_agent(key, is_employed, wage_idx, model, σ):
     """
     Updates an agent by one period. Updates their employment status and their
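The hunk context above mentions that wage transitions are simulated via `jnp.searchsorted` on the precomputed cumulative sum `P_cumsum`. As background, here is a minimal sketch of that inverse-transform sampling idea (the transition matrix is made up for illustration):

```python
import jax
import jax.numpy as jnp

# Hypothetical 2-state transition matrix and its row-wise cumulative sum
P = jnp.array([[0.9, 0.1],
               [0.2, 0.8]])
P_cumsum = jnp.cumsum(P, axis=1)

key = jax.random.PRNGKey(0)
u = jax.random.uniform(key)          # uniform draw on [0, 1)
current_state = 0
# searchsorted finds the first cell whose cumulative probability exceeds u,
# which samples the next state from row `current_state` of P
next_state = jnp.searchsorted(P_cumsum[current_state], u)
print(next_state)
```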
@@ -439,7 +417,8 @@ Let's create a comprehensive plot of the employment simulation:
 model = create_js_with_sep_model()
 
 # Calculate reservation wage for plotting
-v_star, σ_star = vfi(model)
+v_star = vfi(model)
+σ_star = get_greedy(v_star, model)
 w_star = get_reservation_wage(σ_star, model)
 
 wage_path, employment_status = simulate_employment_path(model, σ_star)
@@ -486,7 +465,6 @@ plt.tight_layout()
 plt.show()
 ```
 
-
 The simulation helps to visualize outcomes associated with this model.
 
 The agent follows a reservation wage strategy.
@@ -568,7 +546,7 @@ update_agents_vmap = jax.vmap(
 Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time:
 
 ```{code-cell} ipython3
-@partial(jit, static_argnums=(3, 4))
+@partial(jax.jit, static_argnums=(3, 4))
 def _simulate_cross_section_compiled(
         key: jnp.ndarray,
         model: Model,
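In `@partial(jax.jit, static_argnums=(3, 4))`, arguments 3 and 4 -- `n_agents` and `t_snapshot` at the call sites below -- are marked static because they fix array shapes at compile time. A minimal sketch of the decorator pattern with a hypothetical function:

```python
from functools import partial
import jax

# n determines the output shape, so it must be static;
# each distinct n triggers a separate compilation
@partial(jax.jit, static_argnums=(1,))
def draw_uniform(key, n):
    return jax.random.uniform(key, (n,))

key = jax.random.PRNGKey(0)
print(draw_uniform(key, 5).shape)   # (5,)
```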
@@ -627,7 +605,8 @@ def simulate_cross_section(
     key = jax.random.PRNGKey(seed)
 
     # Solve for optimal policy
-    v_star, σ_star = vfi(model)
+    v_star = vfi(model)
+    σ_star = get_greedy(v_star, model)
 
     # Run JIT-compiled simulation
     final_employment = _simulate_cross_section_compiled(
@@ -655,7 +634,8 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200,
     """
     # Get final employment state directly
    key = jax.random.PRNGKey(42)
-    v_star, σ_star = vfi(model)
+    v_star = vfi(model)
+    σ_star = get_greedy(v_star, model)
     final_employment = _simulate_cross_section_compiled(
         key, model, σ_star, n_agents, t_snapshot
     )
@@ -751,4 +731,3 @@ plt.show()
 
 ```{solution-end}
 ```
-

lectures/mccall_model_with_separation.md

Lines changed: 1 addition & 3 deletions
@@ -135,9 +135,7 @@ Our first aim is to obtain these functions.
 
 ### The Bellman Equations
 
-Suppose for now that the worker can calculate the functions $v_e$ and $v_u$ and use them in his decision making.
-
-Then $v_e$ and $v_u$ should satisfy
+The functions $v_e$ and $v_u$ must satisfy
 
 ```{math}
 :label: bell1_mccall

0 commit comments