diff --git a/lectures/mccall_model.md b/lectures/mccall_model.md index 202b9d591..239d18610 100644 --- a/lectures/mccall_model.md +++ b/lectures/mccall_model.md @@ -34,11 +34,11 @@ and the pros and cons as they themselves see them." -- Robert E. Lucas, Jr. In addition to what's in Anaconda, this lecture will need the following libraries: -```{code-cell} ipython +```{code-cell} ipython3 --- tags: [hide-output] --- -!pip install quantecon +!pip install quantecon jax ``` ## Overview @@ -62,8 +62,9 @@ Let's start with some imports: ```{code-cell} ipython import matplotlib.pyplot as plt import numpy as np -from numba import jit, float64 -from numba.experimental import jitclass +import jax +import jax.numpy as jnp +from typing import NamedTuple import quantecon as qe from quantecon.distributions import BetaBinomial ``` @@ -91,9 +92,11 @@ At time $t$, our agent has two choices: The agent is infinitely lived and aims to maximize the expected discounted sum of earnings -$$ -\mathbb{E} \sum_{t=0}^{\infty} \beta^t y_t -$$ +```{math} +:label: obj_model + +{\mathbb E} \sum_{t=0}^\infty \beta^t u(y_t) +``` The constant $\beta$ lies in $(0, 1)$ and is called a **discount factor**. @@ -112,7 +115,7 @@ The worker faces a trade-off: * Waiting too long for a good offer is costly, since the future is discounted. * Accepting too early is costly, since better offers might arrive in the future. -To decide optimally in the face of this trade-off, we use dynamic programming. +To decide optimally in the face of this trade-off, we use [dynamic programming](https://dp.quantecon.org/). Dynamic programming can be thought of as a two-step procedure that @@ -135,10 +138,10 @@ To this end, let $v^*(w)$ be the total lifetime *value* accruing to an unemployed worker who enters the current period unemployed when the wage is $w \in \mathbb{W}$. -In particular, the agent has wage offer $w$ in hand. +(In particular, the agent has wage offer $w$ in hand and can accept or reject it.) More precisely, $v^*(w)$ denotes the value of the objective function -{eq}`objective` when an agent in this situation makes *optimal* decisions now +{eq}`obj_model` when an agent in this situation makes *optimal* decisions now and at all future points in time. Of course $v^*(w)$ is not trivial to calculate because we don't yet know @@ -163,7 +166,7 @@ v^*(w) for every possible $w$ in $\mathbb{W}$. -This important equation is a version of the **Bellman equation**, which is +This is a version of the **Bellman equation**, which is ubiquitous in economic dynamics and other fields involving planning over time. The intuition behind it is as follows: @@ -174,9 +177,12 @@ $$ \frac{w}{1 - \beta} = w + \beta w + \beta^2 w + \cdots $$ -* the second term inside the max operation is the **continuation value**, which is the lifetime payoff from rejecting the current offer and then behaving optimally in all subsequent periods +* the second term inside the max operation is the continuation value, which is + the lifetime payoff from rejecting the current offer and then behaving + optimally in all subsequent periods -If we optimize and pick the best of these two options, we obtain maximal lifetime value from today, given current offer $w$. +If we optimize and pick the best of these two options, we obtain maximal +lifetime value from today, given current offer $w$. But this is precisely $v^*(w)$, which is the left-hand side of {eq}`odu_pv`. 
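+
+To make the two terms inside the max concrete, here is a small numerical
+sketch (purely illustrative, with made-up values for $w$, $c$, $\beta$ and a
+constant guess for the continuation values):
+
+```{code-cell} python3
+β, c, w = 0.99, 25.0, 40.0
+
+# Value of accepting: the geometric sum w + βw + β²w + ... = w / (1 - β)
+accept = w / (1 - β)
+print(accept, sum(β**t * w for t in range(5_000)))  # the two agree closely
+
+# Value of rejecting, under the (arbitrary) guess v(w') = 3000 for every w',
+# so that the expectation over q is just 3000
+reject = c + β * 3000
+
+print(max(accept, reject))  # right-hand side of the Bellman equation at w
+```
+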
@@ -193,7 +199,7 @@ All we have to do is select the maximal choice on the right-hand side of {eq}`od The optimal action is best thought of as a **policy**, which is, in general, a map from states to actions. -Given *any* $w$, we can read off the corresponding best choice (accept or +Given any $w$, we can read off the corresponding best choice (accept or reject) by picking the max on the right-hand side of {eq}`odu_pv`. Thus, we have a map from $\mathbb W$ to $\{0, 1\}$, with 1 meaning accept and 0 meaning reject. @@ -224,7 +230,7 @@ where \bar w := (1 - \beta) \left\{ c + \beta \sum_{w'} v^*(w') q (w') \right\} ``` -Here $\bar w$ (called the *reservation wage*) is a constant depending on +Here $\bar w$ (called the **reservation wage**) is a constant depending on $\beta, c$ and the wage distribution. The agent should accept if and only if the current wage offer exceeds the reservation wage. @@ -234,8 +240,7 @@ In view of {eq}`reswage`, we can compute this reservation wage if we can compute ## Computing the Optimal Policy: Take 1 -To put the above ideas into action, we need to compute the value function at -each possible state $w \in \mathbb W$. +To put the above ideas into action, we need to compute the value function at each $w \in \mathbb W$. To simplify notation, let's set @@ -245,8 +250,7 @@ $$ v^*(i) := v^*(w_i) $$ -The value function is then represented by the vector -$v^* = (v^*(i))_{i=1}^n$. +The value function is then represented by the vector $v^* = (v^*(i))_{i=1}^n$. In view of {eq}`odu_pv`, this vector satisfies the nonlinear system of equations @@ -298,8 +302,7 @@ The theory below elaborates on this point. What's the mathematics behind these ideas? -First, one defines a mapping $T$ from $\mathbb R^n$ to -itself via +First, one defines a mapping $T$ from $\mathbb R^n$ to itself via ```{math} :label: odu_pv3 @@ -316,11 +319,9 @@ itself via (A new vector $Tv$ is obtained from given vector $v$ by evaluating the r.h.s. at each $i$.) -The element $v_k$ in the sequence $\{v_k\}$ of successive -approximations corresponds to $T^k v$. +The element $v_k$ in the sequence $\{v_k\}$ of successive approximations corresponds to $T^k v$. -* This is $T$ applied $k$ times, starting at the initial guess - $v$ +* This is $T$ applied $k$ times, starting at the initial guess $v$ One can show that the conditions of the [Banach fixed point theorem](https://en.wikipedia.org/wiki/Banach_fixed-point_theorem) are satisfied by $T$ on $\mathbb R^n$. @@ -329,12 +330,11 @@ One implication is that $T$ has a unique fixed point in $\mathbb R^n$. * That is, a unique vector $\bar v$ such that $T \bar v = \bar v$. -Moreover, it's immediate from the definition of $T$ that this fixed -point is $v^*$. +Moreover, it's immediate from the definition of $T$ that this fixed point is $v^*$. A second implication of the Banach contraction mapping theorem is that -$\{ T^k v \}$ converges to the fixed point $v^*$ regardless of -$v$. +$\{ T^k v \}$ converges to the fixed point $v^*$ regardless of $v$. 
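+
+In code, the map $T$ in {eq}`odu_pv3` is one vectorized line, and iterating it
+from an arbitrary guess illustrates the convergence just described.  The
+sketch below uses an illustrative wage grid and a uniform $q$ rather than the
+lecture's defaults:
+
+```{code-cell} python3
+import jax.numpy as jnp
+
+β, c = 0.96, 25.0
+w = jnp.linspace(10, 60, 51)    # illustrative wage grid
+q = jnp.ones(51) / 51           # illustrative uniform offer distribution
+
+def T(v):
+    # (Tv)(i) = max{ w[i] / (1 - β),  c + β Σ_j v(j) q(j) }
+    return jnp.maximum(w / (1 - β), c + β * (v @ q))
+
+v = w / (1 - β)                 # any initial guess will do
+for k in range(500):
+    v_new = T(v)
+    error = jnp.max(jnp.abs(v_new - v))
+    v = v_new
+
+print(error)  # tiny: successive approximations contract at rate β
+```
+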
+ ### Implementation @@ -343,14 +343,14 @@ Our default for $q$, the distribution of the state process, will be ```{code-cell} python3 n, a, b = 50, 200, 100 # default parameters -q_default = BetaBinomial(n, a, b).pdf() # default choice of q +q_default = jnp.array(BetaBinomial(n, a, b).pdf()) ``` Our default set of values for wages will be ```{code-cell} python3 w_min, w_max = 10, 60 -w_default = np.linspace(w_min, w_max, n+1) +w_default = jnp.linspace(w_min, w_max, n+1) ``` Here's a plot of the probabilities of different wage outcomes: @@ -364,60 +364,37 @@ ax.set_ylabel('probabilities') plt.show() ``` -We are going to use Numba to accelerate our code. +We will use [JAX](https://python-programming.quantecon.org/jax_intro.html) to write our code. -* See, in particular, the discussion of `@jitclass` in [our lecture on Numba](https://python-programming.quantecon.org/numba.html). +We'll use `NamedTuple` for our model class to maintain immutability, which works well with JAX's functional programming paradigm. -The following helps Numba by providing some type specifications. +Here's a class that stores the model parameters with default values. ```{code-cell} python3 -mccall_data = [ - ('c', float64), # unemployment compensation - ('β', float64), # discount factor - ('w', float64[::1]), # array of wage values, w[i] = wage at state i - ('q', float64[::1]) # array of probabilities -] +class McCallModel(NamedTuple): + c: float = 25 # unemployment compensation + β: float = 0.99 # discount factor + w: jnp.ndarray = w_default # array of wage values, w[i] = wage at state i + q: jnp.ndarray = q_default # array of probabilities ``` -```{note} -Note the use of `[::1]` in the array type declarations above. - -This notation specifies that the arrays should be C-contiguous. - -This is important for performance, especially when using the `@` operator for matrix multiplication (e.g., `v @ q`). - -Without this specification, Numba might need to handle non-contiguous arrays, which can significantly slow down these operations. - -Try to replace `[::1]` with `[:]` and see what happens. -``` - -Here's a class that stores the data and computes the values of state-action pairs, -i.e. the value in the maximum bracket on the right hand side of the Bellman equation {eq}`odu_pv2p`, -given the current state and an arbitrary feasible action. - -Default parameter values are embedded in the class. +Here is a function that computes the +value in the maximum bracket on the right hand side of the Bellman equation {eq}`odu_pv2p`. ```{code-cell} python3 -@jitclass(mccall_data) -class McCallModel: - - def __init__(self, c=25, β=0.99, w=w_default, q=q_default): - - self.c, self.β = c, β - self.w, self.q = w_default, q_default - - def state_action_values(self, i, v): - """ - The values of state-action pairs. - """ - # Simplify names - c, β, w, q = self.c, self.β, self.w, self.q - # Evaluate value for each state-action pair - # Consider action = accept or reject the current offer - accept = w[i] / (1 - β) - reject = c + β * (v @ q) - - return np.array([accept, reject]) +@jax.jit +def state_action_values(model, i, v): + """ + The values of state-action pairs. 
+ """ + # Simplify names + c, β, w, q = model.c, model.β, model.w, model.q + # Evaluate value for each state-action pair + # Consider action = accept or reject the current offer + accept = w[i] / (1 - β) + reject = c + β * (v @ q) + + return jnp.array([accept, reject]) ``` Based on these defaults, let's try plotting the first few approximate value functions @@ -439,13 +416,13 @@ def plot_value_function_seq(mcm, ax, num_plots=6): n = len(mcm.w) v = mcm.w / (1 - mcm.β) - v_next = np.empty_like(v) for i in range(num_plots): ax.plot(mcm.w, v, '-', alpha=0.4, label=f"iterate {i}") # Update guess + v_next = jnp.zeros_like(v) for j in range(n): - v_next[j] = np.max(mcm.state_action_values(j, v)) - v[:] = v_next # copy contents into v + v_next = v_next.at[j].set(jnp.max(state_action_values(mcm, j, v))) + v = v_next # update v ax.legend(loc='lower right') ``` @@ -469,37 +446,35 @@ Here's a more serious iteration effort to compute the limit, which continues unt Once we obtain a good approximation to the limit, we will use it to calculate the reservation wage. -We'll be using JIT compilation via Numba to turbocharge our loops. +We'll be using JIT compilation via JAX to accelerate our loops. ```{code-cell} python3 -@jit -def compute_reservation_wage(mcm, - max_iter=500, - tol=1e-6): - +@jax.jit +def compute_reservation_wage(mcm, max_iter=500, tol=1e-6): # Simplify names c, β, w, q = mcm.c, mcm.β, mcm.w, mcm.q - - # == First compute the value function == # - + + # First compute the value function n = len(w) - v = w / (1 - β) # initial guess - v_next = np.empty_like(v) - j = 0 - error = tol + 1 - while j < max_iter and error > tol: - + v = w / (1 - β) # initial guess + + def body_fun(state): + v, i, error = state + v_next = jnp.zeros_like(v) for j in range(n): - v_next[j] = np.max(mcm.state_action_values(j, v)) - - error = np.max(np.abs(v_next - v)) - j += 1 - - v[:] = v_next # copy contents into v - - # == Now compute the reservation wage == # - - return (1 - β) * (c + β * (v @ q)) + v_next = v_next.at[j].set(jnp.max(state_action_values(mcm, j, v))) + error = jnp.max(jnp.abs(v_next - v)) + return v_next, i + 1, error + + def cond_fun(state): + v, i, error = state + return jnp.logical_and(i < max_iter, error > tol) + + initial_state = (v, 0, tol + 1) + v_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state) + + # Now compute the reservation wage + return (1 - β) * (c + β * (v_final @ q)) ``` The next line computes the reservation wage at default parameters @@ -518,15 +493,17 @@ $c$. 
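+
+The next code cell maps `compute_reservation_wage` over every $(c, \beta)$
+pair with two nested `jax.vmap` calls.  As a warm-up, here is the same
+pattern applied to a hypothetical scalar function `f` (illustrative only):
+the nested `vmap` produces the full 2-D array of outputs directly from the
+two one-dimensional parameter grids.
+
+```{code-cell} python3
+import jax
+import jax.numpy as jnp
+
+def f(c, β):
+    # any scalar function of the two parameters (illustrative)
+    return c / (1 - β)
+
+c_test = jnp.linspace(10.0, 30.0, 4)
+β_test = jnp.linspace(0.90, 0.99, 3)
+
+# inner vmap varies β for fixed c, outer vmap varies c,
+# so the output has shape (len(c_test), len(β_test))
+f_grid = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))
+print(f_grid(c_test, β_test).shape)
+```
+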
```{code-cell} python3 grid_size = 25 -R = np.empty((grid_size, grid_size)) +c_vals = jnp.linspace(10.0, 30.0, grid_size) +β_vals = jnp.linspace(0.9, 0.99, grid_size) -c_vals = np.linspace(10.0, 30.0, grid_size) -β_vals = np.linspace(0.9, 0.99, grid_size) +def compute_R_element(c, β): + mcm = McCallModel(c=c, β=β) + return compute_reservation_wage(mcm) -for i, c in enumerate(c_vals): - for j, β in enumerate(β_vals): - mcm = McCallModel(c=c, β=β) - R[i, j] = compute_reservation_wage(mcm) +# Create meshgrid and vectorize computation +c_grid, β_grid = jnp.meshgrid(c_vals, β_vals, indexing='ij') +compute_R_vectorized = jax.vmap(jax.vmap(compute_R_element, in_axes=(None, 0)), in_axes=(0, None)) +R = compute_R_vectorized(c_vals, β_vals) ``` ```{code-cell} python3 @@ -623,32 +600,30 @@ The big difference here, however, is that we're iterating on a scalar $h$, rathe Here's an implementation: ```{code-cell} python3 -@jit -def compute_reservation_wage_two(mcm, - max_iter=500, - tol=1e-5): - +@jax.jit +def compute_reservation_wage_two(mcm, max_iter=500, tol=1e-5): # Simplify names c, β, w, q = mcm.c, mcm.β, mcm.w, mcm.q - - # == First compute h == # - + + # First compute h h = (w @ q) / (1 - β) - i = 0 - error = tol + 1 - while i < max_iter and error > tol: - - s = np.maximum(w / (1 - β), h) + + def body_fun(state): + h, i, error = state + s = jnp.maximum(w / (1 - β), h) h_next = c + β * (s @ q) - - error = np.abs(h_next - h) - i += 1 - - h = h_next - - # == Now compute the reservation wage == # - - return (1 - β) * h + error = jnp.abs(h_next - h) + return h_next, i + 1, error + + def cond_fun(state): + h, i, error = state + return jnp.logical_and(i < max_iter, error > tol) + + initial_state = (h, 0, tol + 1) + h_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state) + + # Now compute the reservation wage + return (1 - β) * h_final ``` You can use this code to solve the exercise below. @@ -678,37 +653,42 @@ Plot mean unemployment duration as a function of $c$ in `c_vals`. 
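+
+The simulation below draws wages repeatedly until one exceeds the reservation
+wage and averages the resulting stopping times over many replications.  One
+JAX detail worth keeping in mind when writing such code: any argument that
+determines an array shape, such as the number of keys requested from
+`jax.random.split`, must be a concrete (static) value at trace time.  Here is
+a minimal sketch of that pattern, using a hypothetical helper `mean_of_draws`
+and replication count `num_reps` (both illustrative, not part of the lecture
+code):
+
+```{code-cell} python3
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+
+@partial(jax.jit, static_argnames='num_reps')
+def mean_of_draws(key, num_reps=1_000):
+    # num_reps is static, so it can fix the number of keys (and hence draws)
+    keys = jax.random.split(key, num_reps)
+    draws = jax.vmap(jax.random.uniform)(keys)
+    return jnp.mean(draws)
+
+print(mean_of_draws(jax.random.PRNGKey(0)))
+```
+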
Here's one solution ```{code-cell} python3 -cdf = np.cumsum(q_default) - -@jit -def compute_stopping_time(w_bar, seed=1234): - - np.random.seed(seed) - t = 1 - while True: - # Generate a wage draw - w = w_default[qe.random.draw(cdf)] - # Stop when the draw is above the reservation wage - if w >= w_bar: - stopping_time = t - break - else: - t += 1 - return stopping_time - -@jit -def compute_mean_stopping_time(w_bar, num_reps=100000): - obs = np.empty(num_reps) - for i in range(num_reps): - obs[i] = compute_stopping_time(w_bar, seed=i) - return obs.mean() - -c_vals = np.linspace(10, 40, 25) -stop_times = np.empty_like(c_vals) -for i, c in enumerate(c_vals): +cdf = jnp.cumsum(q_default) + +@jax.jit +def compute_stopping_time(w_bar, key): + def body_fun(state): + t, key, done = state + key, subkey = jax.random.split(key) + u = jax.random.uniform(subkey) + w = w_default[jnp.searchsorted(cdf, u)] + done = w >= w_bar + t = jnp.where(done, t, t + 1) + return t, key, done + + def cond_fun(state): + t, _, done = state + return jnp.logical_not(done) + + initial_state = (1, key, False) + t_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state) + return t_final + +@jax.jit +def compute_mean_stopping_time(w_bar, num_reps=100000, seed=1234): + key = jax.random.PRNGKey(seed) + keys = jax.random.split(key, num_reps) + obs = jax.vmap(compute_stopping_time, in_axes=(None, 0))(w_bar, keys) + return jnp.mean(obs) + +c_vals = jnp.linspace(10, 40, 25) + +def compute_stop_time_for_c(c): mcm = McCallModel(c=c) w_bar = compute_reservation_wage_two(mcm) - stop_times[i] = compute_mean_stopping_time(w_bar) + return compute_mean_stopping_time(w_bar) + +stop_times = jax.vmap(compute_stop_time_for_c)(c_vals) fig, ax = plt.subplots() @@ -789,48 +769,41 @@ Once your code is working, investigate how the reservation wage changes with $c$ Here is one solution: ```{code-cell} python3 -mccall_data_continuous = [ - ('c', float64), # unemployment compensation - ('β', float64), # discount factor - ('σ', float64), # scale parameter in lognormal distribution - ('μ', float64), # location parameter in lognormal distribution - ('w_draws', float64[:]) # draws of wages for Monte Carlo -] - -@jitclass(mccall_data_continuous) -class McCallModelContinuous: - - def __init__(self, c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000): - - self.c, self.β, self.σ, self.μ = c, β, σ, μ - - # Draw and store shocks - np.random.seed(1234) - s = np.random.randn(mc_size) - self.w_draws = np.exp(μ+ σ * s) - - -@jit +class McCallModelContinuous(NamedTuple): + c: float # unemployment compensation + β: float # discount factor + σ: float # scale parameter in lognormal distribution + μ: float # location parameter in lognormal distribution + w_draws: jnp.ndarray # draws of wages for Monte Carlo + +def create_mccall_continuous(c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000, seed=1234): + key = jax.random.PRNGKey(seed) + s = jax.random.normal(key, (mc_size,)) + w_draws = jnp.exp(μ + σ * s) + return McCallModelContinuous(c=c, β=β, σ=σ, μ=μ, w_draws=w_draws) + +@jax.jit def compute_reservation_wage_continuous(mcmc, max_iter=500, tol=1e-5): - c, β, σ, μ, w_draws = mcmc.c, mcmc.β, mcmc.σ, mcmc.μ, mcmc.w_draws - - h = np.mean(w_draws) / (1 - β) # initial guess - i = 0 - error = tol + 1 - while i < max_iter and error > tol: - - integral = np.mean(np.maximum(w_draws / (1 - β), h)) + + h = jnp.mean(w_draws) / (1 - β) # initial guess + + def body_fun(state): + h, i, error = state + integral = jnp.mean(jnp.maximum(w_draws / (1 - β), h)) h_next = c + β * integral - - error = 
np.abs(h_next - h) - i += 1 - - h = h_next - - # == Now compute the reservation wage == # - - return (1 - β) * h + error = jnp.abs(h_next - h) + return h_next, i + 1, error + + def cond_fun(state): + h, i, error = state + return jnp.logical_and(i < max_iter, error > tol) + + initial_state = (h, 0, tol + 1) + h_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state) + + # Now compute the reservation wage + return (1 - β) * h_final ``` Now we investigate how the reservation wage changes with $c$ and @@ -840,15 +813,20 @@ We will do this using a contour plot. ```{code-cell} python3 grid_size = 25 -R = np.empty((grid_size, grid_size)) - -c_vals = np.linspace(10.0, 30.0, grid_size) -β_vals = np.linspace(0.9, 0.99, grid_size) - -for i, c in enumerate(c_vals): - for j, β in enumerate(β_vals): - mcmc = McCallModelContinuous(c=c, β=β) - R[i, j] = compute_reservation_wage_continuous(mcmc) +c_vals = jnp.linspace(10.0, 30.0, grid_size) +β_vals = jnp.linspace(0.9, 0.99, grid_size) + +def compute_R_element(c, β): + mcmc = create_mccall_continuous(c=c, β=β) + return compute_reservation_wage_continuous(mcmc) + +# Create meshgrid and vectorize computation +c_grid, β_grid = jnp.meshgrid(c_vals, β_vals, indexing='ij') +compute_R_vectorized = jax.vmap( + jax.vmap(compute_R_element, + in_axes=(None, 0)), + in_axes=(0, None)) +R = compute_R_vectorized(c_vals, β_vals) ``` ```{code-cell} python3
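+# Orientation of R: the outer vmap runs over c_vals and the inner over β_vals,
+# so R[i, j] is the reservation wage at c_vals[i] and β_vals[j].
+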