Skip to content

Commit 61004bc

Browse files
committed
convert mccall model from numba to jax
1 parent edd7608 commit 61004bc

File tree

1 file changed

+68
-71
lines changed

1 file changed

+68
-71
lines changed

lectures/mccall_q.md

Lines changed: 68 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ n, a, b = 10, 200, 100 # default parameters
133133
q_default = BetaBinomial(n, a, b).pdf() # default choice of q
134134
135135
w_min, w_max = 10, 60
136-
w_default = np.linspace(w_min, w_max, n+1)
136+
w_default = jnp.linspace(w_min, w_max, n+1)
137137
138138
# plot distribution of wage offer
139139
fig, ax = plt.subplots(figsize=(10,6))
@@ -149,81 +149,78 @@ Next we'll compute the worker's optimal value function by iterating to convergen
149149
Then we'll plot various iterates on the Bellman operator.
150150
151151
```{code-cell} ipython3
152-
mccall_data = [
153-
('c', float64), # unemployment compensation
154-
('β', float64), # discount factor
155-
('w', float64[::1]), # array of wage values, w[i] = wage at state i
156-
('q', float64[::1]) # array of probabilities
157-
]
158-
159-
160-
@jitclass(mccall_data)
161-
class McCallModel:
162-
163-
def __init__(self, c=25, β=0.99, w=w_default, q=q_default):
164-
165-
self.c, self.β = c, β
166-
self.w, self.q = w, q
167-
168-
def state_action_values(self, i, v):
169-
"""
170-
The values of state-action pairs.
171-
"""
172-
# Simplify names
173-
c, β, w, q = self.c, self.β, self.w, self.q
174-
# Evaluate value for each state-action pair
175-
# Consider action = accept or reject the current offer
176-
accept = w[i] / (1 - β)
177-
reject = c + β * (v @ q)
178-
179-
return np.array([accept, reject])
180-
181-
def VFI(self, eps=1e-5, max_iter=500):
182-
"""
183-
Find the optimal value function.
184-
"""
185-
186-
n = len(self.w)
187-
v = self.w / (1 - self.β)
188-
v_next = np.empty_like(v)
189-
flag=0
190-
191-
for i in range(max_iter):
192-
for j in range(n):
193-
v_next[j] = np.max(self.state_action_values(j, v))
194-
195-
if np.max(np.abs(v_next - v))<=eps:
196-
flag=1
197-
break
198-
v[:] = v_next
199-
200-
return v, flag
152+
class McCallModel(NamedTuple):
    """
    Immutable container for the parameters of the McCall model.

    Being a NamedTuple, an instance can be unpacked positionally in the
    order (c, β, w, q) or accessed by field name.
    """
    # Unemployment compensation
    c: float
    # Discount factor
    β: float
    # Array of wage values, w[i] = wage at state i
    w: jnp.ndarray
    # Array of probabilities
    q: jnp.ndarray
157+
158+
def create_mccall_model(c=25, β=0.99, w=w_default, q=q_default):
    """
    Build a McCallModel, falling back to the default parameters.

    * c is unemployment compensation
    * β is the discount factor
    * w is the array of wage values
    * q is the array of probabilities
    """
    params = {"c": c, "β": β, "w": w, "q": q}
    return McCallModel(**params)
160+
161+
@jax.jit
def state_action_values(model, i, v):
    """
    Evaluate both actions (accept, reject) at wage state i.

    * model supplies the fields c, β, w and q
    * i indexes the current wage offer in model.w
    * v is the current guess of the value function

    Returns a length-2 array ordered as [accept, reject].
    """
    # Value of accepting the current offer: w[i] / (1 - β)
    accept_value = model.w[i] / (1 - model.β)
    # Value of rejecting: c plus the discounted q-weighted sum of v
    reject_value = model.c + model.β * jnp.dot(v, model.q)
    return jnp.array([accept_value, reject_value])
171+
172+
@jax.jit
def VFI(model, eps=1e-5, max_iter=500):
    """
    Find the optimal value function by value function iteration.

    * model supplies the fields c, β, w and q
    * eps is the sup-norm tolerance between successive iterates
    * max_iter caps the number of Bellman updates

    Returns (v, flag): v is the last iterate computed and flag is 1 if the
    final error is <= eps, otherwise 0.
    """
    c, β, q = model.c, model.β, model.q

    # The value of accepting each offer does not depend on v, so compute it
    # once outside the loop.  It is also the initial guess.
    accept = model.w / (1 - β)
    v_init = accept

    def body_fun(state):
        v, i, _ = state
        # The continuation value is the same scalar for every state, so the
        # whole Bellman update is one elementwise max (no per-state loop).
        reject = c + β * (v @ q)
        v_next = jnp.maximum(accept, reject)
        error = jnp.max(jnp.abs(v_next - v))
        return v_next, i + 1, error

    def cond_fun(state):
        _, i, error = state
        return (error > eps) & (i < max_iter)

    # Initial state: (v, iteration, error); the error starts above eps so
    # that at least one update is attempted.
    init_state = (v_init, 0, eps + 1)
    final_v, _, final_error = jax.lax.while_loop(cond_fun, body_fun, init_state)

    flag = jnp.where(final_error <= eps, 1, 0)
    return final_v, flag
201199
202200
def plot_value_function_seq(mcm, ax, num_plots=8):
    """
    Plot a sequence of value functions.

    * mcm is an instance of McCallModel
    * ax is an axes object that implements a plot method.
    * num_plots is the number of successive iterates to draw

    The first curve is the initial guess w / (1 - β); each later curve is
    obtained by applying one Bellman update to the previous one.
    """
    # Accepting offer i is worth w[i] / (1 - β) regardless of v, so compute
    # it once; it also serves as the initial guess.
    accept = mcm.w / (1 - mcm.β)
    v = accept
    for i in range(num_plots):
        ax.plot(mcm.w, v, '-', alpha=0.4, label=f"iterate {i}")
        # Update guess: the continuation value is one scalar shared by all
        # states, so the update is a single elementwise max.
        reject = mcm.c + mcm.β * (v @ mcm.q)
        v = jnp.maximum(accept, reject)

    ax.legend(loc='lower right')
222219
```
223220
224221
```{code-cell} ipython3
225-
mcm = McCallModel()
226-
valfunc_VFI, flag = mcm.VFI()
222+
mcm = create_mccall_model()
223+
valfunc_VFI, flag = VFI(mcm)
227224
228225
fig, ax = plt.subplots(figsize=(10,6))
229226
ax.set_xlabel('wage')

0 commit comments

Comments
 (0)