diff --git a/lectures/career.md b/lectures/career.md
index b826ae9dc..605c9a6a8 100644
--- a/lectures/career.md
+++ b/lectures/career.md
@@ -3,8 +3,10 @@ jupytext:
   text_representation:
     extension: .md
     format_name: myst
+    format_version: 0.13
+    jupytext_version: 1.16.6
 kernelspec:
-  display_name: Python 3
+  display_name: Python 3 (ipykernel)
   language: python
   name: python3
 ---
@@ -29,11 +31,10 @@ kernelspec:

 In addition to what's in Anaconda, this lecture will need the following libraries:

-```{code-cell} ipython
----
-tags: [hide-output]
----
-!pip install quantecon
+```{code-cell} ipython3
+:tags: [hide-output]
+
+!pip install quantecon jax
 ```

 ## Overview
@@ -46,18 +47,22 @@ This exposition draws on the presentation in {cite}`Ljungqvist2012`, section 6.5
 We begin with some imports:

-```{code-cell} ipython
+```{code-cell} ipython3
 import matplotlib.pyplot as plt
-import numpy as np
-import quantecon as qe
-from numba import jit, prange
+import jax.numpy as jnp
+import jax
+import jax.random as jr
+from typing import NamedTuple
 from quantecon.distributions import BetaBinomial
 from scipy.special import binom, beta
 from mpl_toolkits.mplot3d.axes3d import Axes3D
 from matplotlib import cm
+
+# Set JAX to use CPU
+jax.config.update('jax_platform_name', 'cpu')
 ```

-### Model Features
+### Model features

 * Career and job within career both chosen to maximize expected discounted wage flow.
 * Infinite horizon dynamic programming with two state variables.

@@ -69,7 +74,7 @@ In what follows we distinguish between a career and a job, where
 * a *career* is understood to be a general field encompassing many possible jobs, and
 * a *job* is understood to be a position with a particular firm

-For workers, wages can be decomposed into the contribution of job and career
+For workers, wages can be decomposed into the contributions of job and career

 * $w_t = \theta_t + \epsilon_t$, where
     * $\theta_t$ is the contribution of career at time $t$
@@ -132,14 +137,14 @@ As in {cite}`Ljungqvist2012`, section 6.5, we will focus on a discrete version
 of the model, parameterized as follows:

 * both $\theta$ and $\epsilon$ take values in the set
-  `np.linspace(0, B, grid_size)` --- an even grid of points between
+  `jnp.linspace(0, B, grid_size)` --- an even grid of points between
   $0$ and $B$ inclusive
 * `grid_size = 50`
 * `B = 5`
 * `β = 0.95`

 The distributions $F$ and $G$ are discrete distributions
-generating draws from the grid points `np.linspace(0, B, grid_size)`.
+generating draws from the grid points `jnp.linspace(0, B, grid_size)`.

 A very useful family of discrete distributions is the Beta-binomial family,
 with probability mass function
@@ -163,11 +168,11 @@ Nice properties:
 Here's a figure showing the effect on the pmf of different shape parameters
 when $n=50$.

-```{code-cell} python3
+```{code-cell} ipython3
 def gen_probs(n, a, b):
-    probs = np.zeros(n+1)
-    for k in range(n+1):
-        probs[k] = binom(n, k) * beta(k + a, n - k + b) / beta(a, b)
+    # Beta-binomial pmf evaluated at each k = 0, ..., n
+    probs = jnp.array([binom(n, k) * beta(k + a, n - k + b) / beta(a, b)
+                       for k in range(n+1)])
     return probs

 n = 50
 a_vals = [0.5, 1, 100]
 b_vals = [0.5, 1, 100]
 fig, ax = plt.subplots(figsize=(10, 6))
 for a, b in zip(a_vals, b_vals):
     ab_label = f'$a = {a}, b = {b}$'
     ax.plot(list(range(0, n+1)), gen_probs(n, a, b), '-o', label=ab_label)
 ax.legend()
 plt.show()
 ```
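+
+As a quick check, the pmf coded in `gen_probs` agrees with the `BetaBinomial`
+distribution imported above, which is what the implementation below uses to
+build $F$ and $G$:
+
+```{code-cell} ipython3
+# Sanity check: gen_probs matches quantecon's BetaBinomial pmf
+qe_probs = jnp.array(BetaBinomial(n, 0.5, 0.5).pdf())
+print(jnp.allclose(qe_probs, gen_probs(n, 0.5, 0.5)))
+```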
@@ -183,161 +188,156 @@
 ## Implementation

-We will first create a class `CareerWorkerProblem` which will hold the
-default parameterizations of the model and an initial guess for the value function.
-
-```{code-cell} python3
-class CareerWorkerProblem:
-
-    def __init__(self,
-                 B=5.0,          # Upper bound
-                 β=0.95,         # Discount factor
-                 grid_size=50,   # Grid size
-                 F_a=1,
-                 F_b=1,
-                 G_a=1,
-                 G_b=1):
-
-        self.β, self.grid_size, self.B = β, grid_size, B
-
-        self.θ = np.linspace(0, B, grid_size)     # Set of θ values
-        self.ϵ = np.linspace(0, B, grid_size)     # Set of ϵ values
-
-        self.F_probs = BetaBinomial(grid_size - 1, F_a, F_b).pdf()
-        self.G_probs = BetaBinomial(grid_size - 1, G_a, G_b).pdf()
-        self.F_mean = self.θ @ self.F_probs
-        self.G_mean = self.ϵ @ self.G_probs
-
-        # Store these parameters for str and repr methods
-        self._F_a, self._F_b = F_a, F_b
-        self._G_a, self._G_b = G_a, G_b
+We will first create a JAX-compatible model structure using `NamedTuple` to store
+the model parameters and computed distributions.
+
+```{code-cell} ipython3
+class CareerWorkerProblem(NamedTuple):
+    β: float                # Discount factor
+    B: float                # Upper bound
+    grid_size: int          # Grid size
+    θ: jnp.ndarray          # Set of θ values
+    ε: jnp.ndarray          # Set of ε values
+    F_probs: jnp.ndarray    # Probabilities for F distribution
+    G_probs: jnp.ndarray    # Probabilities for G distribution
+    F_mean: float           # Mean of F distribution
+    G_mean: float           # Mean of G distribution
+
+def create_career_worker_problem(B=5.0, β=0.95, grid_size=50,
+                                 F_a=1, F_b=1, G_a=1, G_b=1):
+    """
+    Factory function to create a CareerWorkerProblem instance.
+    """
+    θ = jnp.linspace(0, B, grid_size)    # Set of θ values
+    ε = jnp.linspace(0, B, grid_size)    # Set of ε values
+
+    F_probs = jnp.array(BetaBinomial(grid_size - 1, F_a, F_b).pdf())
+    G_probs = jnp.array(BetaBinomial(grid_size - 1, G_a, G_b).pdf())
+    F_mean = θ @ F_probs
+    G_mean = ε @ G_probs
+
+    return CareerWorkerProblem(
+        β=β, B=B, grid_size=grid_size,
+        θ=θ, ε=ε,
+        F_probs=F_probs, G_probs=G_probs,
+        F_mean=F_mean, G_mean=G_mean
+    )
 ```

-The following function takes an instance of `CareerWorkerProblem` and returns
-the corresponding Bellman operator $T$ and the greedy policy function.
+The following functions implement the Bellman operator $T$ and the greedy policy function
+using JAX.

 In this model, $T$ is defined by $Tv(\theta, \epsilon) = \max\{I, II, III\}$, where
 $I$, $II$ and $III$ are as given in {eq}`eyes`.

-```{code-cell} python3
-def operator_factory(cw, parallel_flag=True):
+```{code-cell} ipython3
+@jax.jit
+def Q(θ_grid, ε_grid, β, v, F_probs, G_probs, F_mean, G_mean):
+    # Option 1: Stay put
+    v1 = θ_grid + ε_grid + β * v

-    """
-    Returns jitted versions of the Bellman operator and the
-    greedy policy function
+    # Option 2: New job (keep θ, new ε)
+    ev_new_job = jnp.dot(v, G_probs)    # Expected value for each θ
+    v2 = θ_grid + G_mean + β * ev_new_job[:, jnp.newaxis]

-    cw is an instance of ``CareerWorkerProblem``
-    """
+    # Option 3: New life (new θ and new ε)
+    ev_new_life = jnp.dot(F_probs, jnp.dot(v, G_probs))
+    v3 = jnp.full_like(v, G_mean + F_mean + β * ev_new_life)

-    θ, ϵ, β = cw.θ, cw.ϵ, cw.β
-    F_probs, G_probs = cw.F_probs, cw.G_probs
-    F_mean, G_mean = cw.F_mean, cw.G_mean
+    return v1, v2, v3

-    @jit(parallel=parallel_flag)
-    def T(v):
-        "The Bellman operator"
-
-        v_new = np.empty_like(v)
+@jax.jit
+def bellman_operator(model, v):
+    """
+    The Bellman operator for the career choice model.
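+
+    Maps a value function array v, defined on the (θ, ε) grid, to Tv by taking
+    the pointwise maximum over the three options in {eq}`eyes`.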
+ """ + θ, ε, β = model.θ, model.ε, model.β + F_probs, G_probs = model.F_probs, model.G_probs + F_mean, G_mean = model.F_mean, model.G_mean - for i in prange(len(v)): - for j in prange(len(v)): - v1 = θ[i] + ϵ[j] + β * v[i, j] # Stay put - v2 = θ[i] + G_mean + β * v[i, :] @ G_probs # New job - v3 = G_mean + F_mean + β * F_probs @ v @ G_probs # New life - v_new[i, j] = max(v1, v2, v3) + v1, v2, v3 = Q( + *jnp.meshgrid(θ, ε, indexing='ij'), + β, v, F_probs, G_probs, F_mean, G_mean + ) - return v_new + return jnp.maximum(jnp.maximum(v1, v2), v3) - @jit - def get_greedy(v): - "Computes the v-greedy policy" +@jax.jit +def get_greedy_policy(model, v): + """ + Computes the greedy policy given the value function. + * Policy function where 1=stay put, 2=new job, 3=new life + """ + θ, ε, β = model.θ, model.ε, model.β + F_probs, G_probs = model.F_probs, model.G_probs + F_mean, G_mean = model.F_mean, model.G_mean - σ = np.empty(v.shape) + v1, v2, v3 = Q( + *jnp.meshgrid(θ, ε, indexing='ij'), + β, v, F_probs, G_probs, F_mean, G_mean + ) - for i in range(len(v)): - for j in range(len(v)): - v1 = θ[i] + ϵ[j] + β * v[i, j] - v2 = θ[i] + G_mean + β * v[i, :] @ G_probs - v3 = G_mean + F_mean + β * F_probs @ v @ G_probs - if v1 > max(v2, v3): - action = 1 - elif v2 > max(v1, v3): - action = 2 - else: - action = 3 - σ[i, j] = action + # Stack the value arrays and find argmax along first axis + values = jnp.stack([v1, v2, v3], axis=0) - return σ + # +1 because actions are 1, 2, 3 not 0, 1, 2 + policy = jnp.argmax(values, axis=0) + 1 - return T, get_greedy + return policy ``` -Lastly, `solve_model` will take an instance of `CareerWorkerProblem` and +Lastly, `solve_model` will take an instance of `CareerWorkerProblem` and iterate using the Bellman operator to find the fixed point of the Bellman equation. -```{code-cell} python3 -def solve_model(cw, - use_parallel=True, - tol=1e-4, - max_iter=1000, - verbose=True, - print_skip=25): - - T, _ = operator_factory(cw, parallel_flag=use_parallel) - - # Set up loop - v = np.full((cw.grid_size, cw.grid_size), 100.) # Initial guess - i = 0 +```{code-cell} ipython3 +def solve_model(model, tol=1e-4, max_iter=1000): + """ + Solve the career choice model using JAX. 
+ """ + # Initial guess + v = jnp.full((model.grid_size, model.grid_size), 100.0) error = tol + 1 + i = 0 while i < max_iter and error > tol: - v_new = T(v) - error = np.max(np.abs(v - v_new)) - i += 1 - if verbose and i % print_skip == 0: - print(f"Error at iteration {i} is {error}.") + v_new = bellman_operator(model, v) + error = jnp.max(jnp.abs(v_new - v)) v = v_new + i += 1 - if error > tol: - print("Failed to converge!") - - elif verbose: - print(f"\nConverged in {i} iterations.") - - return v_new + return v ``` Here's the solution to the model -- an approximate value function -```{code-cell} python3 -cw = CareerWorkerProblem() -T, get_greedy = operator_factory(cw) -v_star = solve_model(cw, verbose=False) -greedy_star = get_greedy(v_star) +```{code-cell} ipython3 +model = create_career_worker_problem() +v_star = solve_model(model) +greedy_star = get_greedy_policy(model, v_star) fig = plt.figure(figsize=(8, 6)) ax = fig.add_subplot(111, projection='3d') -tg, eg = np.meshgrid(cw.θ, cw.ϵ) +tg, eg = jnp.meshgrid(model.θ, model.ε) ax.plot_surface(tg, eg, v_star.T, cmap=cm.jet, alpha=0.5, linewidth=0.25) -ax.set(xlabel='θ', ylabel='ϵ', zlim=(150, 200)) +ax.set(xlabel='θ', ylabel='ε', zlim=(150, 200)) ax.view_init(ax.elev, 225) plt.show() ``` And here is the optimal policy -```{code-cell} python3 +```{code-cell} ipython3 fig, ax = plt.subplots(figsize=(6, 6)) -tg, eg = np.meshgrid(cw.θ, cw.ϵ) +tg, eg = jnp.meshgrid(model.θ, model.ε) lvls = (0.5, 1.5, 2.5, 3.5) ax.contourf(tg, eg, greedy_star.T, levels=lvls, cmap=cm.winter, alpha=0.5) ax.contour(tg, eg, greedy_star.T, colors='k', levels=lvls, linewidths=2) -ax.set(xlabel='θ', ylabel='ϵ') +ax.set(xlabel='θ', ylabel='ε') ax.text(1.8, 2.5, 'new life', fontsize=14) ax.text(4.5, 2.5, 'new job', fontsize=14, rotation='vertical') ax.text(4.0, 4.5, 'stay put', fontsize=14) @@ -375,7 +375,7 @@ In particular, modulo randomness, reproduce the following figure (where the hori ```{hint} :class: dropdown -To generate the draws from the distributions $F$ and $G$, use `quantecon.random.draw()`. +To generate the draws from the distributions $F$ and $G$, use `quantecon.jr.draw()`. ``` ```{exercise-end} @@ -392,39 +392,53 @@ In reading the code, recall that `optimal_policy[i, j]` = policy at $(\theta_i, \epsilon_j)$ = either 1, 2 or 3; meaning 'stay put', 'new job' and 'new life'. 
-```{code-cell} python3
-F = np.cumsum(cw.F_probs)
-G = np.cumsum(cw.G_probs)
-v_star = solve_model(cw, verbose=False)
-T, get_greedy = operator_factory(cw)
-greedy_star = get_greedy(v_star)
-
-def gen_path(optimal_policy, F, G, t=20):
-    i = j = 0
-    θ_index = []
-    ϵ_index = []
-    for t in range(t):
-        if optimal_policy[i, j] == 1:    # Stay put
-            pass
-
-        elif greedy_star[i, j] == 2:     # New job
-            j = qe.random.draw(G)
-
-        else:                            # New life
-            i, j = qe.random.draw(F), qe.random.draw(G)
-        θ_index.append(i)
-        ϵ_index.append(j)
-    return cw.θ[θ_index], cw.ϵ[ϵ_index]
+```{code-cell} ipython3
+model = create_career_worker_problem()
+F = jnp.cumsum(jnp.asarray(model.F_probs))
+G = jnp.cumsum(jnp.asarray(model.G_probs))
+v_star = solve_model(model)
+greedy_star = jnp.asarray(get_greedy_policy(model, v_star))
+
+def draw_from_cdf(key, cdf):
+    u = jr.uniform(key)
+    return jnp.searchsorted(cdf, u, side="left")
+
+def gen_path(optimal_policy, F, G, model, t=20, key=None):
+    if key is None:
+        key = jr.PRNGKey(0)
+    i = 0
+    j = 0
+    theta_idx = []
+    eps_idx = []
+    for _ in range(t):
+        a = optimal_policy[i, j]
+        key, k1, k2 = jr.split(key, 3)
+        if a == 1:      # Stay put
+            pass
+        elif a == 2:    # New job
+            j = draw_from_cdf(k1, G)
+        else:           # New life
+            i = draw_from_cdf(k1, F)
+            j = draw_from_cdf(k2, G)
+        theta_idx.append(i)
+        eps_idx.append(j)
+
+    theta_idx = jnp.array(theta_idx, dtype=jnp.int32)
+    eps_idx = jnp.array(eps_idx, dtype=jnp.int32)
+    return model.θ[theta_idx], model.ε[eps_idx], key
+
+key = jr.PRNGKey(42)
 fig, axes = plt.subplots(2, 1, figsize=(10, 8))
+
 for ax in axes:
-    θ_path, ϵ_path = gen_path(greedy_star, F, G)
-    ax.plot(ϵ_path, label='ϵ')
+    key, subkey = jr.split(key)
+    θ_path, ε_path, _ = gen_path(greedy_star, F, G, model, key=subkey)
+    ax.plot(ε_path, label='ε')
     ax.plot(θ_path, label='θ')
     ax.set_ylim(0, 6)
+    ax.legend(loc='upper right')

-plt.legend()
+plt.tight_layout()
 plt.show()
 ```
@@ -464,40 +478,42 @@ Repeat the exercise with $\beta=0.99$ and interpret the change.

 The median for the original parameterization can be computed as follows

-```{code-cell} python3
-cw = CareerWorkerProblem()
-F = np.cumsum(cw.F_probs)
-G = np.cumsum(cw.G_probs)
-T, get_greedy = operator_factory(cw)
-v_star = solve_model(cw, verbose=False)
-greedy_star = get_greedy(v_star)
-
-@jit
-def passage_time(optimal_policy, F, G):
-    t = 0
-    i = j = 0
-    while True:
-        if optimal_policy[i, j] == 1:    # Stay put
-            return t
-        elif optimal_policy[i, j] == 2:  # New job
-            j = qe.random.draw(G)
-        else:                            # New life
-            i, j = qe.random.draw(F), qe.random.draw(G)
-        t += 1
-
-@jit(parallel=True)
-def median_time(optimal_policy, F, G, M=25000):
-    samples = np.empty(M)
-    for i in prange(M):
-        samples[i] = passage_time(optimal_policy, F, G)
-    return np.median(samples)
+```{code-cell} ipython3
+model = create_career_worker_problem()
+F = jnp.cumsum(jnp.asarray(model.F_probs))
+G = jnp.cumsum(jnp.asarray(model.G_probs))
+v_star = solve_model(model)
+greedy_star = jnp.asarray(get_greedy_policy(model, v_star))
+
+def passage_time(optimal_policy, F, G, key):
+    def cond(state):
+        i, j, t, key = state
+        return optimal_policy[i, j] != 1
+
+    def body(state):
+        i, j, t, key = state
+        a = optimal_policy[i, j]
+        key, k1, k2 = jr.split(key, 3)
+        new_j = draw_from_cdf(k1, G)
+        new_i = draw_from_cdf(k2, F)
+        i = jnp.where(a == 3, new_i, i)
+        j = jnp.where((a == 2) | (a == 3), new_j, j)
+        return i, j, t + 1, key
+
+    i, j, t, _ = jax.lax.while_loop(cond, body, (0, 0, 0, key))
+    return t
+
+def median_time(optimal_policy, F, G, M=25000, seed=0):
+    keys = jr.split(jr.PRNGKey(seed), M)
+    times = jax.vmap(lambda k: passage_time(optimal_policy, F, G, k))(keys)
+    return jnp.median(times)

 median_time(greedy_star, F, G)
 ```

 To compute the median with $\beta=0.99$ instead of the default
-value $\beta=0.95$, replace `cw = CareerWorkerProblem()` with
-`cw = CareerWorkerProblem(β=0.99)`.
+value $\beta=0.95$, replace `model = create_career_worker_problem()` with
+`model = create_career_worker_problem(β=0.99)`.
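+
+For example, one way to run that computation, reusing `F`, `G` and `median_time`
+from above (the `_99` names are just for illustration):
+
+```{code-cell} ipython3
+model_99 = create_career_worker_problem(β=0.99)    # a more patient worker
+v_star_99 = solve_model(model_99)
+greedy_star_99 = jnp.asarray(get_greedy_policy(model_99, v_star_99))
+median_time(greedy_star_99, F, G)
+```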

 The medians are subject to randomness but should be about 7 and 14 respectively.
@@ -520,18 +536,17 @@ figure -- interpret.

 Here is one solution

-```{code-cell} python3
-cw = CareerWorkerProblem(G_a=100, G_b=100)
-T, get_greedy = operator_factory(cw)
-v_star = solve_model(cw, verbose=False)
-greedy_star = get_greedy(v_star)
+```{code-cell} ipython3
+model = create_career_worker_problem(G_a=100, G_b=100)
+v_star = solve_model(model)
+greedy_star = get_greedy_policy(model, v_star)

 fig, ax = plt.subplots(figsize=(6, 6))
-tg, eg = np.meshgrid(cw.θ, cw.ϵ)
+tg, eg = jnp.meshgrid(model.θ, model.ε)
 lvls = (0.5, 1.5, 2.5, 3.5)
 ax.contourf(tg, eg, greedy_star.T, levels=lvls, cmap=cm.winter, alpha=0.5)
 ax.contour(tg, eg, greedy_star.T, colors='k', levels=lvls, linewidths=2)
-ax.set(xlabel='θ', ylabel='ϵ')
+ax.set(xlabel='θ', ylabel='ε')
 ax.text(1.8, 2.5, 'new life', fontsize=14)
 ax.text(4.5, 1.5, 'new job', fontsize=14, rotation='vertical')
 ax.text(4.0, 4.5, 'stay put', fontsize=14)