diff --git a/lectures/jv.md b/lectures/jv.md index 8312ab065..a69c9a041 100644 --- a/lectures/jv.md +++ b/lectures/jv.md @@ -27,6 +27,14 @@ kernelspec: :depth: 2 ``` +In addition to what's in Anaconda, this lecture will need the following library: + +```{code-cell} ipython3 +:tags: [hide-output] + +!pip install jax +``` + ## Overview In this section, we solve a simple on-the-job search model @@ -35,14 +43,19 @@ In this section, we solve a simple on-the-job search model Let's start with some imports: -```{code-cell} ipython +```{code-cell} ipython3 import matplotlib.pyplot as plt -import numpy as np +import jax +import jax.numpy as jnp +import jax.random as jr import scipy.stats as stats -from numba import jit, prange +from typing import NamedTuple + +# Set JAX to use CPU +jax.config.update('jax_platform_name', 'cpu') ``` -### Model Features +### Model features ```{index} single: On-the-Job Search; Model Features ``` @@ -127,7 +140,7 @@ with default parameter values The $\text{Beta}(2,2)$ distribution is supported on $(0,1)$ - it has a unimodal, symmetric density peaked at 0.5. (jvboecalc)= -### Back-of-the-Envelope Calculations +### Back-of-the-envelope calculations Before we solve the model, let's make some quick calculations that provide intuition on what the solution should look like. @@ -174,43 +187,52 @@ Now let's turn to implementation, and see if we can match our predictions. ```{index} single: On-the-Job Search; Programming Implementation ``` -We will set up a class `JVWorker` that holds the parameters of the model described above +We will set up a NamedTuple that holds the parameters of the model described above ```{code-cell} python3 -class JVWorker: - r""" - A Jovanovic-type model of employment with on-the-job search. - +class JVWorker(NamedTuple): + A: float + α: float + β: float # Discount factor + a: float # Parameter of f + b: float # Parameter of f + grid_size: int + mc_size: int + ɛ: float + x_grid: jnp.ndarray + f_rvs: jnp.ndarray + +def create_jv_worker(A=1.4, α=0.6, β=0.96, a=2, b=2, + grid_size=50, mc_size=100, ɛ=1e-4, + key=jr.PRNGKey(0)): """ - - def __init__(self, - A=1.4, - α=0.6, - β=0.96, # Discount factor - π=np.sqrt, # Search effort function - a=2, # Parameter of f - b=2, # Parameter of f - grid_size=50, - mc_size=100, - ɛ=1e-4): - - self.A, self.α, self.β, self.π = A, α, β, π - self.mc_size, self.ɛ = mc_size, ɛ - - self.g = jit(lambda x, ϕ: A * (x * ϕ)**α) # Transition function - self.f_rvs = np.random.beta(a, b, mc_size) - - # Max of grid is the max of a large quantile value for f and the - # fixed point y = g(y, 1) - ɛ = 1e-4 - grid_max = max(A**(1 / (1 - α)), stats.beta(a, b).ppf(1 - ɛ)) - - # Human capital - self.x_grid = np.linspace(ɛ, grid_max, grid_size) + Create a JVWorker instance with computed grids and random draws. + """ + # Generate random draws for Monte Carlo integration + f_rvs = jr.beta(key, a, b, (mc_size,)) + + # Max of grid is the max of a large quantile value for f and the + # fixed point y = g(y, 1) + grid_max = max(A**(1 / (1 - α)), stats.beta.ppf(1 - ɛ, a, b)) + + # Human capital grid + x_grid = jnp.linspace(ɛ, grid_max, grid_size) + + return JVWorker(A=A, α=α, β=β, a=a, b=b, grid_size=grid_size, + mc_size=mc_size, ɛ=ɛ, x_grid=x_grid, f_rvs=f_rvs) + +@jax.jit +def g(jv, x, ϕ): + """Transition function for human capital accumulation.""" + return jv.A * (x * ϕ)**jv.α + +@jax.jit +def π(s): + """Search effort function.""" + return jnp.sqrt(s) ``` -The function `operator_factory` takes an instance of this class and returns a -jitted version of the Bellman operator `T`, i.e. +Now we define the Bellman operator `T`, i.e. $$ Tv(x) @@ -227,136 +249,128 @@ w(s, \phi) \beta \pi(s) \int v[g(x, \phi) \vee u] f(du) ``` -When we represent $v$, it will be with a NumPy array `v` giving values on grid `x_grid`. +When we represent $v$, it will be with a JAX array `v` giving values on grid `x_grid`. But to evaluate the right-hand side of {eq}`defw`, we need a function, so -we replace the arrays `v` and `x_grid` with a function `v_func` that gives linear -interpolation of `v` on `x_grid`. - -Inside the `for` loop, for each `x` in the grid over the state space, we -set up the function $w(z) = w(s, \phi)$ defined in {eq}`defw`. - -The function is maximized over all feasible $(s, \phi)$ pairs. - -Another function, `get_greedy` returns the optimal choice of $s$ and $\phi$ -at each $x$, given a value function. +we use JAX interpolation of `v` on `x_grid`. ```{code-cell} python3 -def operator_factory(jv, parallel_flag=True): - +@jax.jit +def state_action_values(jv, s_phi, x, v): """ - Returns a jitted version of the Bellman operator T - - jv is an instance of JVWorker - + Computes the value of state-action pair (x, s, phi) given value function v. """ - - π, β = jv.π, jv.β - x_grid, ɛ, mc_size = jv.x_grid, jv.ɛ, jv.mc_size - f_rvs, g = jv.f_rvs, jv.g - - @jit - def state_action_values(z, x, v): - s, ϕ = z - v_func = lambda x: np.interp(x, x_grid, v) - - integral = 0 - for m in range(mc_size): - u = f_rvs[m] - integral += v_func(max(g(x, ϕ), u)) - integral = integral / mc_size - - q = π(s) * integral + (1 - π(s)) * v_func(g(x, ϕ)) - return x * (1 - ϕ - s) + β * q - - @jit(parallel=parallel_flag) - def T(v): - """ - The Bellman operator - """ - - v_new = np.empty_like(v) - for i in prange(len(x_grid)): - x = x_grid[i] - - # Search on a grid - search_grid = np.linspace(ɛ, 1, 15) - max_val = -1 - for s in search_grid: - for ϕ in search_grid: - current_val = state_action_values((s, ϕ), x, v) if s + ϕ <= 1 else -1 - if current_val > max_val: - max_val = current_val - v_new[i] = max_val - - return v_new - - @jit - def get_greedy(v): - """ - Computes the v-greedy policy of a given function v - """ - s_policy, ϕ_policy = np.empty_like(v), np.empty_like(v) - - for i in range(len(x_grid)): - x = x_grid[i] - # Search on a grid - search_grid = np.linspace(ɛ, 1, 15) - max_val = -1 - for s in search_grid: - for ϕ in search_grid: - current_val = state_action_values((s, ϕ), x, v) if s + ϕ <= 1 else -1 - if current_val > max_val: - max_val = current_val - max_s, max_ϕ = s, ϕ - s_policy[i], ϕ_policy[i] = max_s, max_ϕ - return s_policy, ϕ_policy - - return T, get_greedy + s, ϕ = s_phi + β = jv.β + x_grid, f_rvs = jv.x_grid, jv.f_rvs + + v_func = lambda x_val: jnp.interp(x_val, x_grid, v) + + # Monte Carlo integration over offers + def compute_offer_value(u): + return v_func(jnp.maximum(g(jv, x, ϕ), u)) + + integral = jnp.mean(jax.vmap(compute_offer_value)(f_rvs)) + + q = π(s) * integral + (1 - π(s)) * v_func(g(jv, x, ϕ)) + return x * (1 - ϕ - s) + β * q + +@jax.jit +def T(jv, v): + """ + The Bellman operator. + """ + x_grid, ɛ = jv.x_grid, jv.ɛ + + def maximize_at_x(x): + # Create grid for optimization + search_grid = jnp.linspace(ɛ, 1, 15) + + def objective(s_phi): + s, ϕ = s_phi + # Return negative value if constraint violated + constraint_satisfied = s + ϕ <= 1.0 + value = state_action_values(jv, s_phi, x, v) + return jnp.where(constraint_satisfied, value, -jnp.inf) + + # Grid search over feasible (s, ϕ) pairs + s_vals, phi_vals = jnp.meshgrid(search_grid, search_grid) + s_phi_pairs = jnp.stack( + [s_vals.ravel(), phi_vals.ravel()], axis=1) + + # Evaluate objective at all grid points + values = jax.vmap(objective)(s_phi_pairs) + max_idx = jnp.argmax(values) + return values[max_idx] + + return jax.vmap(maximize_at_x)(x_grid) + +@jax.jit +def get_greedy(jv, v): + """ + Computes the v-greedy policy. + """ + x_grid, ɛ = jv.x_grid, jv.ɛ + + def greedy_at_x(x): + # Create grid for optimization + search_grid = jnp.linspace(ɛ, 1, 15) + + def objective(s_phi): + s, ϕ = s_phi + # Return negative value if constraint violated + constraint_satisfied = s + ϕ <= 1.0 + value = state_action_values(jv, s_phi, x, v) + return jnp.where(constraint_satisfied, value, -jnp.inf) + + # Grid search over feasible (s, ϕ) pairs + s_vals, phi_vals = jnp.meshgrid(search_grid, search_grid) + s_phi_pairs = jnp.stack( + [s_vals.ravel(), phi_vals.ravel()], axis=1) + + # Evaluate objective at all grid points + values = jax.vmap(objective)(s_phi_pairs) + max_idx = jnp.argmax(values) + return s_phi_pairs[max_idx] + + policies = jax.vmap(greedy_at_x)(x_grid) + return policies[:, 0], policies[:, 1] # s_policy, ϕ_policy ``` To solve the model, we will write a function that uses the Bellman operator and iterates to find a fixed point. ```{code-cell} python3 -def solve_model(jv, - use_parallel=True, - tol=1e-4, - max_iter=1000, - verbose=True, - print_skip=25): - +@jax.jit +def solve_model(jv, tol=1e-4, max_iter=1000): """ Solves the model by value function iteration * jv is an instance of JVWorker - """ - - T, _ = operator_factory(jv, parallel_flag=use_parallel) - - # Set up loop - v = jv.x_grid * 0.5 # Initial condition - i = 0 - error = tol + 1 - - while i < max_iter and error > tol: - v_new = T(v) - error = np.max(np.abs(v - v_new)) - i += 1 - if verbose and i % print_skip == 0: - print(f"Error at iteration {i} is {error}.") - v = v_new - - if error > tol: - print("Failed to converge!") - elif verbose: - print(f"\nConverged in {i} iterations.") - - return v_new + + def cond_fun(state): + v, i, error = state + return jnp.logical_and(error > tol, i < max_iter) + + def body_fun(state): + v, i, error = state + v_new = T(jv, v) + error_new = jnp.max(jnp.abs(v - v_new)) + return v_new, i + 1, error_new + + # Initial state + v_init = jv.x_grid * 0.5 # Initial condition + init_state = (v_init, 0, tol + 1) + + # Run iteration + v_final, iterations, final_error = jax.lax.while_loop( + cond_fun, body_fun, init_state) + + return v_final ``` -## Solving for Policies +## Solving for policies ```{index} single: On-the-Job Search; Solving for Policies ``` @@ -365,10 +379,9 @@ Let's generate the optimal policies and see what they look like. (jv_policies)= ```{code-cell} python3 -jv = JVWorker() -T, get_greedy = operator_factory(jv) +jv = create_jv_worker() v_star = solve_model(jv) -s_star, ϕ_star = get_greedy(v_star) +s_star, ϕ_star = get_greedy(jv, v_star) ``` Here are the plots: @@ -382,7 +395,6 @@ fig, axes = plt.subplots(3, 1, figsize=(12, 12)) for ax, plot, title in zip(axes, plots, titles): ax.plot(jv.x_grid, plot) ax.set(title=title) - ax.grid() axes[-1].set_xlabel("x") plt.show() @@ -421,7 +433,7 @@ diagram, setting ```{code-block} python3 jv = JVWorker(grid_size=25, mc_size=50) plot_grid_max, plot_grid_size = 1.2, 100 -plot_grid = np.linspace(0, plot_grid_max, plot_grid_size) +plot_grid = jnp.linspace(0, plot_grid_max, plot_grid_size) fig, ax = plt.subplots() ax.set_xlim(0, plot_grid_max) ax.set_ylim(0, plot_grid_max) @@ -439,25 +451,24 @@ Argue that at the steady state, $s_t \approx 0$ and $\phi_t \approx 0.6$. :class: dropdown ``` -Here’s code to produce the 45 degree diagram +Here's code to produce the 45 degree diagram ```{code-cell} python3 -jv = JVWorker(grid_size=25, mc_size=50) -π, g, f_rvs, x_grid = jv.π, jv.g, jv.f_rvs, jv.x_grid -T, get_greedy = operator_factory(jv) -v_star = solve_model(jv, verbose=False) -s_policy, ϕ_policy = get_greedy(v_star) +jv = create_jv_worker(grid_size=25, mc_size=50) +f_rvs, x_grid = jv.f_rvs, jv.x_grid +v_star = solve_model(jv) +s_policy, ϕ_policy = get_greedy(jv, v_star) # Turn the policy function arrays into actual functions -s = lambda y: np.interp(y, x_grid, s_policy) -ϕ = lambda y: np.interp(y, x_grid, ϕ_policy) +s = lambda y: jnp.interp(y, x_grid, s_policy) +ϕ = lambda y: jnp.interp(y, x_grid, ϕ_policy) def h(x, b, u): - return (1 - b) * g(x, ϕ(x)) + b * max(g(x, ϕ(x)), u) + return (1 - b) * g(jv, x, ϕ(x)) + b * max(g(jv, x, ϕ(x)), u) plot_grid_max, plot_grid_size = 1.2, 100 -plot_grid = np.linspace(0, plot_grid_max, plot_grid_size) +plot_grid = jnp.linspace(0, plot_grid_max, plot_grid_size) fig, ax = plt.subplots(figsize=(8, 8)) ticks = (0.25, 0.5, 0.75, 1.0) ax.set(xticks=ticks, yticks=ticks, @@ -466,9 +477,13 @@ ax.set(xticks=ticks, yticks=ticks, xlabel='$x_t$', ylabel='$x_{t+1}$') ax.plot(plot_grid, plot_grid, 'k--', alpha=0.6) # 45 degree line + +# Generate random values for plotting +key = jr.PRNGKey(0) for x in plot_grid: for i in range(jv.mc_size): - b = 1 if np.random.uniform(0, 1) < π(s(x)) else 0 + key, subkey = jr.split(key) + b = 1 if jr.uniform(subkey) < π(s(x)) else 0 u = f_rvs[i] y = h(x, b, u) ax.plot(x, y, 'go', alpha=0.25) @@ -524,13 +539,13 @@ Can you give a rough interpretation for the value that you see? The figure can be produced as follows ```{code-cell} python3 -jv = JVWorker() +jv = create_jv_worker() def xbar(ϕ): A, α = jv.A, jv.α return (A * ϕ**α)**(1 / (1 - α)) -ϕ_grid = np.linspace(0, 1, 100) +ϕ_grid = jnp.linspace(0, 1, 100) fig, ax = plt.subplots(figsize=(9, 7)) ax.set(xlabel=r'$\phi$') ax.plot(ϕ_grid, [xbar(ϕ) * (1 - ϕ) for ϕ in ϕ_grid], label=r'$w^*(\phi)$')