@@ -4,11 +4,11 @@ jupytext:
44 extension : .md
55 format_name : myst
66 format_version : 0.13
7- jupytext_version : 1.17.1
7+ jupytext_version : 1.16.6
88kernelspec :
9- name : python3
109 display_name : Python 3 (ipykernel)
1110 language : python
11+ name : python3
1212---
1313
1414(career)=
@@ -34,7 +34,15 @@ In addition to what's in Anaconda, this lecture will need the following librarie
3434``` {code-cell} ipython3
3535:tags: [hide-output]
3636
37- !pip install quantecon
37+ !pip install --upgrade quantecon
38+ ```
39+
40+ We also need to install JAX to run this lecture
41+
42+ ``` {code-cell} ipython3
43+ :tags: [skip-execution]
44+
45+ !pip install -U jax
3846```
3947
4048## Overview
@@ -51,15 +59,15 @@ We begin with some imports:
5159import matplotlib.pyplot as plt
5260import jax.numpy as jnp
5361import jax
54- import jax.random as jr
62+ import jax.random as jr
5563from typing import NamedTuple
5664from quantecon.distributions import BetaBinomial
5765from scipy.special import binom, beta
5866from mpl_toolkits.mplot3d.axes3d import Axes3D
5967from matplotlib import cm
6068```
6169
62- ### Model Features
70+ ### Model features
6371
6472* Career and job within career both chosen to maximize expected discounted wage flow.
6573* Infinite horizon dynamic programming with two state variables.
@@ -71,7 +79,7 @@ In what follows we distinguish between a career and a job, where
7179* a * career* is understood to be a general field encompassing many possible jobs, and
7280* a * job* is understood to be a position with a particular firm
7381
74- For workers, wages can be decomposed into the contribution of job and career
82+ For workers, wages can be decomposed into the contributions of job and career
7583
7684* $w_t = \theta_t + \epsilon_t$, where
7785 * $\theta_t$ is the contribution of career at time $t$
@@ -134,14 +142,14 @@ Evidently $I$, $II$ and $III$ correspond to "stay put", "new job" and "new life"
134142As in {cite}` Ljungqvist2012 ` , section 6.5, we will focus on a discrete version of the model, parameterized as follows:
135143
136144* both $\theta$ and $\epsilon$ take values in the set
137- ` np .linspace(0, B, grid_size)` --- an even grid of points between
145+ ` jnp .linspace(0, B, grid_size)` --- an even grid of points between
138146 $0$ and $B$ inclusive
139147* ` grid_size = 50 `
140148* ` B = 5 `
141149* ` β = 0.95 `
142150
143151The distributions $F$ and $G$ are discrete distributions
144- generating draws from the grid points ` np .linspace(0, B, grid_size)` .
152+ generating draws from the grid points ` jnp .linspace(0, B, grid_size)` .
145153
146154A very useful family of discrete distributions is the Beta-binomial family,
147155with probability mass function
@@ -229,31 +237,34 @@ $I$, $II$ and $III$ are as given in {eq}`eyes`.
229237
230238``` {code-cell} ipython3
231239@jax.jit
232- def bellman_operator(model, v):
233- """
234- The Bellman operator for the career choice model.
235- """
236- θ, ε, β = model.θ, model.ε, model.β
237- F_probs, G_probs = model.F_probs, model.G_probs
238- F_mean, G_mean = model.F_mean, model.G_mean
239-
240- # Vectorized computation
241- # Broadcasting θ and ε to create all combinations
242- θ_grid, ε_grid = jnp.meshgrid(θ, ε, indexing='ij')
243-
240+ def Q(θ_grid, ε_grid, β, v, F_probs, G_probs, F_mean, G_mean):
244241 # Option 1: Stay put
245242 v1 = θ_grid + ε_grid + β * v
246243
247244 # Option 2: New job (keep θ, new ε)
248- # For each θ[i], compute expected value over new ε
249245 ev_new_job = jnp.dot(v, G_probs) # Expected value for each θ
250246 v2 = θ_grid + G_mean + β * ev_new_job[:, jnp.newaxis]
251247
252248 # Option 3: New life (new θ and new ε)
253- # Expected value over both θ and ε
254249 ev_new_life = jnp.dot(F_probs, jnp.dot(v, G_probs))
255250 v3 = jnp.full_like(v, G_mean + F_mean + β * ev_new_life)
256251
252+ return v1, v2, v3
253+
254+ @jax.jit
255+ def bellman_operator(model, v):
256+ """
257+ The Bellman operator for the career choice model.
258+ """
259+ θ, ε, β = model.θ, model.ε, model.β
260+ F_probs, G_probs = model.F_probs, model.G_probs
261+ F_mean, G_mean = model.F_mean, model.G_mean
262+
263+ v1, v2, v3 = Q(
264+ *jnp.meshgrid(θ, ε, indexing='ij'),
265+ β, v, F_probs, G_probs, F_mean, G_mean
266+ )
267+
257268 return jnp.maximum(jnp.maximum(v1, v2), v3)
258269
259270@jax.jit
@@ -266,20 +277,10 @@ def get_greedy_policy(model, v):
266277 F_probs, G_probs = model.F_probs, model.G_probs
267278 F_mean, G_mean = model.F_mean, model.G_mean
268279
269- # Vectorized computation
270- # Broadcasting θ and ε to create all combinations
271- θ_grid, ε_grid = jnp.meshgrid(θ, ε, indexing='ij')
272-
273- # Option 1: Stay put
274- v1 = θ_grid + ε_grid + β * v
275-
276- # Option 2: New job (keep θ, new ε)
277- ev_new_job = jnp.dot(v, G_probs) # Expected value for each θ
278- v2 = θ_grid + G_mean + β * ev_new_job[:, jnp.newaxis]
279-
280- # Option 3: New life (new θ and new ε)
281- ev_new_life = jnp.dot(F_probs, jnp.dot(v, G_probs))
282- v3 = jnp.full_like(v, G_mean + F_mean + β * ev_new_life)
280+ v1, v2, v3 = Q(
281+ *jnp.meshgrid(θ, ε, indexing='ij'),
282+ β, v, F_probs, G_probs, F_mean, G_mean
283+ )
283284
284285 # Stack the value arrays and find argmax along first axis
285286 values = jnp.stack([v1, v2, v3], axis=0)
0 commit comments