@@ -133,7 +133,7 @@ n, a, b = 10, 200, 100 # default parameters
133133q_default = BetaBinomial(n, a, b).pdf() # default choice of q
134134
135135w_min, w_max = 10, 60
136- w_default = np .linspace(w_min, w_max, n+1)
136+ w_default = jnp .linspace(w_min, w_max, n+1)
137137
138138# plot distribution of wage offer
139139fig, ax = plt.subplots(figsize=(10,6))
@@ -149,81 +149,78 @@ Next we'll compute the worker's optimal value function by iterating to convergen
149149Then we'll plot various iterates on the Bellman operator.
150150
151151```{code-cell} ipython3
152- mccall_data = [
153- ('c', float64), # unemployment compensation
154- ('β', float64), # discount factor
155- ('w', float64[::1]), # array of wage values, w[i] = wage at state i
156- ('q', float64[::1]) # array of probabilities
157- ]
158-
159-
160- @jitclass(mccall_data)
161- class McCallModel:
162-
163- def __init__(self, c=25, β=0.99, w=w_default, q=q_default):
164-
165- self.c, self.β = c, β
166- self.w, self.q = w, q
167-
168- def state_action_values(self, i, v):
169- """
170- The values of state-action pairs.
171- """
172- # Simplify names
173- c, β, w, q = self.c, self.β, self.w, self.q
174- # Evaluate value for each state-action pair
175- # Consider action = accept or reject the current offer
176- accept = w[i] / (1 - β)
177- reject = c + β * (v @ q)
178-
179- return np.array([accept, reject])
180-
181- def VFI(self, eps=1e-5, max_iter=500):
182- """
183- Find the optimal value function.
184- """
185-
186- n = len(self.w)
187- v = self.w / (1 - self.β)
188- v_next = np.empty_like(v)
189- flag=0
190-
191- for i in range(max_iter):
192- for j in range(n):
193- v_next[j] = np.max(self.state_action_values(j, v))
194-
195- if np.max(np.abs(v_next - v))<=eps:
196- flag=1
197- break
198- v[:] = v_next
199-
200- return v, flag
152+ class McCallModel(NamedTuple):
153+ c: float # unemployment compensation
154+ β: float # discount factor
155+ w: jnp.ndarray # array of wage values, w[i] = wage at state i
156+ q: jnp.ndarray # array of probabilities
157+
158+ def create_mccall_model(c=25, β=0.99, w=w_default, q=q_default):
159+ return McCallModel(c=c, β=β, w=w, q=q)
160+
161+ @jax.jit
162+ def state_action_values(model, i, v):
163+ """The values of state-action pairs."""
164+ # Unpack model parameters
165+ c, β, w, q = model.c, model.β, model.w, model.q
166+ # Evaluate value for each state-action pair
167+ # Consider action = accept or reject the current offer
168+ accept = w[i] / (1 - β)
169+ reject = c + β * (v @ q)
170+ return jnp.array([accept, reject])
171+
172+ @jax.jit
173+ def VFI(model, eps=1e-5, max_iter=500):
174+ """Find the optimal value function."""
175+ n = len(model.w)
176+ v_init = model.w / (1 - model.β)
177+
178+ def body_fun(state):
179+ v, i, error = state
180+ v_next = jnp.empty_like(v)
181+
182+ # Update all elements of v_next
183+ for j in range(n):
184+ v_next = v_next.at[j].set(jnp.max(state_action_values(model, j, v)))
185+
186+ error = jnp.max(jnp.abs(v_next - v))
187+ return v_next, i + 1, error
188+
189+ def cond_fun(state):
190+ v, i, error = state
191+ return (error > eps) & (i < max_iter)
192+
193+ # Initial state: (v, iteration, error)
194+ init_state = (v_init, 0, eps + 1)
195+ final_v, final_i, final_error = jax.lax.while_loop(cond_fun, body_fun, init_state)
196+
197+ flag = jnp.where(final_error <= eps, 1, 0)
198+ return final_v, flag
201199
202200def plot_value_function_seq(mcm, ax, num_plots=8):
203- """
204- Plot a sequence of value functions.
205-
206- * mcm is an instance of McCallModel
207- * ax is an axes object that implements a plot method.
208-
209- """
210-
211- n = len(mcm.w)
212- v = mcm.w / (1 - mcm.β)
213- v_next = np.empty_like(v)
214- for i in range(num_plots):
215- ax.plot(mcm.w, v, '-', alpha=0.4, label=f"iterate {i}")
216- # Update guess
217- for i in range(n):
218- v_next[i] = np.max(mcm.state_action_values(i, v))
219- v[:] = v_next # copy contents into v
220-
221- ax.legend(loc='lower right')
201+ """
202+ Plot a sequence of value functions.
203+
204+ * mcm is an instance of McCallModel
205+ * ax is an axes object that implements a plot method.
206+
207+ """
208+ n = len(mcm.w)
209+ v = mcm.w / (1 - mcm.β)
210+ v_next = jnp.empty_like(v)
211+ for i in range(num_plots):
212+ ax.plot(mcm.w, v, '-', alpha=0.4, label=f"iterate {i}")
213+ # Update guess
214+ for j in range(n): # changed variable name to avoid conflict
215+ v_next = v_next.at[j].set(jnp.max(state_action_values(mcm, j, v)))
216+ v = v_next # handling immutability
217+
218+ ax.legend(loc='lower right')
222219```
223220
224221```{code-cell} ipython3
225- mcm = McCallModel ()
226- valfunc_VFI, flag = mcm. VFI()
222+ mcm = create_mccall_model ()
223+ valfunc_VFI, flag = VFI(mcm )
227224
228225fig, ax = plt.subplots(figsize=(10,6))
229226ax.set_xlabel('wage')
0 commit comments