@@ -4,11 +4,11 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.17.1
+    jupytext_version: 1.17.2
 kernelspec:
-  name: python3
   display_name: Python 3 (ipykernel)
   language: python
+  name: python3
 ---
 
 (mccall_with_sep_markov)=
@@ -58,7 +58,7 @@ We use the following imports:
 from quantecon.markov import tauchen
 import jax.numpy as jnp
 import jax
-from jax import jit, lax
+from jax import lax
 from typing import NamedTuple
 import matplotlib.pyplot as plt
 from functools import partial
@@ -138,48 +138,11 @@ The optimal policy turns out to be a reservation wage strategy: accept all wages
 
 ## Code
 
-
-First, we implement the successive approximation algorithm.
-
-This algorithm takes an operator $T$ and an initial condition and iterates until
-convergence.
-
-We will use it for value function iteration.
-
-```{code-cell} ipython3
-@partial(jit, static_argnums=(0,))
-def successive_approx(
-        T,                        # Operator (callable) - marked as static
-        x_0,                      # Initial condition
-        tolerance: float = 1e-6,  # Error tolerance
-        max_iter: int = 100_000,  # Max iteration bound
-    ):
-    """Computes the approximate fixed point of T via successive
-    approximation using lax.while_loop."""
-
-    def cond_fn(carry):
-        x, error, k = carry
-        return (error > tolerance) & (k <= max_iter)
-
-    def body_fn(carry):
-        x, error, k = carry
-        x_new = T(x)
-        error = jnp.max(jnp.abs(x_new - x))
-        return (x_new, error, k + 1)
-
-    initial_carry = (x_0, tolerance + 1, 1)
-    x_final, _, _ = lax.while_loop(cond_fn, body_fn, initial_carry)
-
-    return x_final
-```
-
-
-Next let's set up a `Model` class to store information needed to solve the model.
+Let's set up a `Model` class to store information needed to solve the model.
 
 We include `P_cumsum`, the row-wise cumulative sum of the transition matrix, to
 optimize the simulation -- the details are explained below.
 
-
 ```{code-cell} ipython3
 class Model(NamedTuple):
     n: int
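
The hunk cuts the class body off here, but the unpacking `n, w_vals, P, P_cumsum, β, c, α = model` used throughout the file pins down the fields and their order. A sketch of the full definition (the type annotations and comments are assumptions, not the file's verbatim contents):

```python
class Model(NamedTuple):
    n: int                   # Number of points on the wage grid
    w_vals: jnp.ndarray      # Wage values
    P: jnp.ndarray           # Markov transition matrix over wages
    P_cumsum: jnp.ndarray    # Row-wise cumulative sums of P
    β: float                 # Discount factor
    c: float                 # Unemployment compensation
    α: float                 # Separation rate
```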
@@ -215,7 +178,6 @@ def create_js_with_sep_model(
 Here's the Bellman operator for the unemployed worker's value function:
 
 ```{code-cell} ipython3
-@jit
 def T(v: jnp.ndarray, model: Model) -> jnp.ndarray:
     """The Bellman operator for the value of being unemployed."""
     n, w_vals, P, P_cumsum, β, c, α = model
@@ -229,7 +191,6 @@ The next function computes the optimal policy under the assumption that $v$ is
 the value function:
 
 ```{code-cell} ipython3
-@jit
 def get_greedy(v: jnp.ndarray, model: Model) -> jnp.ndarray:
     """Get a v-greedy policy."""
     n, w_vals, P, P_cumsum, β, c, α = model
@@ -247,14 +208,34 @@ The second routine requires a policy function, which we will typically obtain by
 applying the `vfi` function.
 
 ```{code-cell} ipython3
-def vfi(model: Model):
-    """Solve by VFI."""
+@jax.jit
+def vfi(
+        model: Model,
+        tolerance: float = 1e-6,  # Error tolerance
+        max_iter: int = 100_000,  # Max iteration bound
+    ):
+    """Solve for the value function by value function iteration."""
     v_init = jnp.zeros(model.w_vals.shape)
-    v_star = successive_approx(lambda v: T(v, model), v_init)
-    σ_star = get_greedy(v_star, model)
-    return v_star, σ_star
+
+    def cond(loop_state):
+        v, error, i = loop_state
+        return (error > tolerance) & (i <= max_iter)
+
+    def update(loop_state):
+        v, error, i = loop_state
+        v_new = T(v, model)
+        error = jnp.max(jnp.abs(v_new - v))
+        new_loop_state = v_new, error, i + 1
+        return new_loop_state
+
+    initial_state = (v_init, tolerance + 1, 1)
+    final_loop_state = lax.while_loop(cond, update, initial_state)
+    v_final, error, i = final_loop_state
 
+    return v_final
 
+
+@jax.jit
 def get_reservation_wage(σ: jnp.ndarray, model: Model) -> float:
     """
     Calculate the reservation wage from a given policy.
@@ -268,25 +249,24 @@ def get_reservation_wage(σ: jnp.ndarray, model: Model) -> float:
     """
     n, w_vals, P, P_cumsum, β, c, α = model
 
-    # Find all wage indices where policy indicates acceptance
-    accept_indices = jnp.where(σ == 1)[0]
-
-    if len(accept_indices) == 0:
-        return jnp.inf  # Agent never accepts any wage
+    # Find the first index where policy indicates acceptance
+    # (σ is a boolean array, so argmax returns the first True value)
+    first_accept_idx = jnp.argmax(σ)
 
-    # Return the lowest wage that is accepted
-    return w_vals[accept_indices[0]]
+    # If no wage is ever accepted (σ is all False), return infinity;
+    # otherwise return the wage at the first acceptance index
+    return jnp.where(jnp.any(σ), w_vals[first_accept_idx], jnp.inf)
 ```
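
The replacement logic merits a quick check, since `jnp.argmax` on an all-`False` array returns `0` rather than signalling failure; that is exactly what the `jnp.any(σ)` guard handles. A toy verification (not code from the file):

```python
σ_demo = jnp.array([False, False, True, True])
print(jnp.argmax(σ_demo))       # 2: index of the first True entry

σ_never = jnp.zeros(4, dtype=bool)
print(jnp.argmax(σ_never))      # 0, even though nothing is accepted
print(jnp.where(jnp.any(σ_never), 1.0, jnp.inf))  # the guard yields inf
```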
 
-
 ## Computing the Solution
 
 Let's solve the model:
 
 ```{code-cell} ipython3
 model = create_js_with_sep_model()
 n, w_vals, P, P_cumsum, β, c, α = model
-v_star, σ_star = vfi(model)
+v_star = vfi(model)
+σ_star = get_greedy(v_star, model)
 ```
 
 Next we compute some related quantities, including the reservation wage.
@@ -312,19 +292,18 @@ ax.set_xlabel(r"$w$")
 plt.show()
 ```
 
-
 ## Sensitivity Analysis
 
 Let's examine how reservation wages change with the separation rate.
 
-
 ```{code-cell} ipython3
 α_vals: jnp.ndarray = jnp.linspace(0.0, 1.0, 10)
 
 w_star_vec = jnp.empty_like(α_vals)
 for (i_α, α) in enumerate(α_vals):
     model = create_js_with_sep_model(α=α)
-    v_star, σ_star = vfi(model)
+    v_star = vfi(model)
+    σ_star = get_greedy(v_star, model)
     w_star = get_reservation_wage(σ_star, model)
     w_star_vec = w_star_vec.at[i_α].set(w_star)
 
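
As an aside, the `w_star_vec.at[i_α].set(w_star)` call above is JAX's functional-update idiom: JAX arrays are immutable, so `.at[...].set(...)` returns a new array instead of writing in place. A one-line illustration (not from the file):

```python
x = jnp.zeros(3)
y = x.at[1].set(7.0)   # new array [0., 7., 0.]; x itself is unchanged
```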
@@ -356,9 +335,8 @@ This is implemented via `jnp.searchsorted` on the precomputed cumulative sum
 
 The function `update_agent` advances the agent's state by one period.
 
-
 ```{code-cell} ipython3
-@jit
+@jax.jit
 def update_agent(key, is_employed, wage_idx, model, σ):
     """
     Updates an agent by one period. Updates their employment status and their
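
The hunk context above mentions that draws from the wage process go through `jnp.searchsorted` on the precomputed cumulative sums. Here is a toy sketch of that technique, inverse-transform sampling, with names (`P_demo`, `current_state`) invented for illustration rather than taken from the file:

```python
# Each row of the cumulative sum is the CDF over next states
P_demo = jnp.array([[0.9, 0.1],
                    [0.2, 0.8]])
P_demo_cumsum = jnp.cumsum(P_demo, axis=1)

key = jax.random.PRNGKey(0)
u = jax.random.uniform(key)          # u ~ Uniform(0, 1)
current_state = 0

# Index of the first cumulative-sum entry exceeding u:
# a draw from the distribution in row `current_state`
next_state = jnp.searchsorted(P_demo_cumsum[current_state, :], u)
```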
@@ -439,7 +417,8 @@ Let's create a comprehensive plot of the employment simulation:
 model = create_js_with_sep_model()
 
 # Calculate reservation wage for plotting
-v_star, σ_star = vfi(model)
+v_star = vfi(model)
+σ_star = get_greedy(v_star, model)
 w_star = get_reservation_wage(σ_star, model)
 
 wage_path, employment_status = simulate_employment_path(model, σ_star)
@@ -486,7 +465,6 @@ plt.tight_layout()
 plt.show()
 ```
 
-
 The simulation helps to visualize outcomes associated with this model.
 
 The agent follows a reservation wage strategy.
@@ -568,7 +546,7 @@ update_agents_vmap = jax.vmap(
 Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time:
 
 ```{code-cell} ipython3
-@partial(jit, static_argnums=(3, 4))
+@partial(jax.jit, static_argnums=(3, 4))
 def _simulate_cross_section_compiled(
     key: jnp.ndarray,
     model: Model,
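
The body is cut off here, but the pattern the lead-in names (`lax.fori_loop` stepping a whole cross-section of agents forward one period at a time) can be sketched as follows. Everything in this sketch -- `demo_simulate`, the dummy shock-accumulation update -- is a hypothetical stand-in for the real vmapped agent update, not the file's code:

```python
def demo_simulate(key, n_agents, n_periods):
    states_init = jnp.zeros(n_agents)

    def body(t, carry):
        key, states = carry
        key, subkey = jax.random.split(key)
        shocks = jax.random.normal(subkey, states.shape)
        # The lecture's loop body would call update_agents_vmap here instead
        return key, states + shocks

    _, states = lax.fori_loop(0, n_periods, body, (key, states_init))
    return states

final_states = demo_simulate(jax.random.PRNGKey(0), n_agents=5, n_periods=10)
```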
@@ -627,7 +605,8 @@ def simulate_cross_section(
     key = jax.random.PRNGKey(seed)
 
     # Solve for optimal policy
-    v_star, σ_star = vfi(model)
+    v_star = vfi(model)
+    σ_star = get_greedy(v_star, model)
 
     # Run JIT-compiled simulation
     final_employment = _simulate_cross_section_compiled(
@@ -655,7 +634,8 @@ def plot_cross_sectional_unemployment(model: Model, t_snapshot: int = 200,
     """
     # Get final employment state directly
     key = jax.random.PRNGKey(42)
-    v_star, σ_star = vfi(model)
+    v_star = vfi(model)
+    σ_star = get_greedy(v_star, model)
     final_employment = _simulate_cross_section_compiled(
         key, model, σ_star, n_agents, t_snapshot
     )
@@ -751,4 +731,3 @@ plt.show()
 
 ```{solution-end}
 ```
-