124124
125125Let's use Python code from {doc}`this quantecon lecture <mccall_model>`.
126126
127- We use a Python method called `VFI ` to compute the optimal value function using value function iterations.
127+ We use a Python method called `vfi ` to compute the optimal value function using value function iterations.
128128
129129We construct an assumed distribution of wages and plot it with the following Python code
130130
@@ -170,32 +170,36 @@ def state_action_values(model, i, v):
170170 return jnp.array([accept, reject])
171171
172172@jax.jit
173- def VFI(model, eps=1e-5, max_iter=500):
174- """Find the optimal value function."""
175- n = len(model.w)
176- v_init = model.w / (1 - model.β)
173+ def update(model, v):
174+ n = model.w.shape[0]
177175
178- def body_fun(state ):
179- v , i, error = state
180- v_next = jnp.empty_like(v )
176+ def v_at_state(i ):
177+ sa = state_action_values(model , i, v)
178+ return jnp.max(sa )
181179
182- # Update all elements of v_next
183- for j in range(n):
184- v_next = v_next.at[j].set(jnp.max(state_action_values(model, j, v)))
180+ indices = jnp.arange(n)
181+ v_new = jax.vmap(v_at_state)(indices)
182+ return v_new
185183
186- error = jnp.max(jnp.abs(v_next - v))
187- return v_next, i + 1, error
184+ @jax.jit
185+ def vfi(model, tol=1e-5, max_iter=500):
186+
187+ v0 = model.w / (1.0 - model.β)
188188
189- def cond_fun(state):
190- v, i, error = state
191- return (error > eps) & (i < max_iter)
189+ def body_fun(state):
190+ v, i, err = state
191+ v_new = update(model, v)
192+ err_new = jnp.max(jnp.abs(v_new - v))
193+ return v_new, i + 1, err_new
192194
193- # Initial state: (v, iteration, error)
194- init_state = (v_init, 0, eps + 1)
195- final_v, final_i, final_error = jax.lax.while_loop(cond_fun, body_fun, init_state )
195+ def cond_fun(state):
196+ _, i, err = state
197+ return (err > tol) & (i < max_iter )
196198
197- flag = jnp.where(final_error <= eps, 1, 0)
198- return final_v, flag
199+ init_state = (v0, 0, tol + 1.0)
200+ v_final, iters, err = jax.lax.while_loop(cond_fun, body_fun, init_state)
201+ converged = jnp.where(err <= tol, 1, 0)
202+ return v_final, converged
199203
200204def plot_value_function_seq(mcm, ax, num_plots=8):
201205 """
@@ -220,7 +224,7 @@ def plot_value_function_seq(mcm, ax, num_plots=8):
220224
221225```{code-cell} ipython3
222226mcm = create_mccall_model()
223- valfunc_VFI, flag = VFI (mcm)
227+ valfunc_VFI, converged = vfi (mcm)
224228
225229fig, ax = plt.subplots(figsize=(10,6))
226230ax.set_xlabel('wage')
@@ -687,7 +691,7 @@ plt.show()
687691```{code-cell} ipython3
688692# VFI
689693mcm = create_mccall_model(w=w_new, q=q_new)
690- valfunc_VFI, flag = VFI (mcm)
694+ valfunc_VFI, converged = vfi (mcm)
691695valfunc_VFI
692696```
693697
0 commit comments