
Commit 1ededf4

updates
1 parent 6232e18 commit 1ededf4

File tree: 1 file changed

lectures/career.md: 37 additions, 36 deletions
@@ -4,11 +4,11 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.17.1
+    jupytext_version: 1.16.6
 kernelspec:
-  name: python3
   display_name: Python 3 (ipykernel)
   language: python
+  name: python3
 ---

 (career)=
@@ -34,7 +34,15 @@ In addition to what's in Anaconda, this lecture will need the following libraries
 ```{code-cell} ipython3
 :tags: [hide-output]

-!pip install quantecon
+!pip install --upgrade quantecon
+```
+
+We also need to install JAX to run this lecture
+
+```{code-cell} ipython3
+:tags: [skip-execution]
+
+!pip install -U jax
 ```

 ## Overview
@@ -51,15 +59,15 @@ We begin with some imports:
 import matplotlib.pyplot as plt
 import jax.numpy as jnp
 import jax
-import jax.random as jr
+import jax.random as jr
 from typing import NamedTuple
 from quantecon.distributions import BetaBinomial
 from scipy.special import binom, beta
 from mpl_toolkits.mplot3d.axes3d import Axes3D
 from matplotlib import cm
 ```

-### Model Features
+### Model features

 * Career and job within career both chosen to maximize expected discounted wage flow.
 * Infinite horizon dynamic programming with two state variables.
@@ -71,7 +79,7 @@ In what follows we distinguish between a career and a job, where
 * a *career* is understood to be a general field encompassing many possible jobs, and
 * a *job* is understood to be a position with a particular firm

-For workers, wages can be decomposed into the contribution of job and career
+For workers, wages can be decomposed into the contributions of job and career

 * $w_t = \theta_t + \epsilon_t$, where
     * $\theta_t$ is the contribution of career at time $t$
@@ -134,14 +142,14 @@ Evidently $I$, $II$ and $III$ correspond to "stay put", "new job" and "new life"
 As in {cite}`Ljungqvist2012`, section 6.5, we will focus on a discrete version of the model, parameterized as follows:

 * both $\theta$ and $\epsilon$ take values in the set
-  `np.linspace(0, B, grid_size)` --- an even grid of points between
+  `jnp.linspace(0, B, grid_size)` --- an even grid of points between
   $0$ and $B$ inclusive
 * `grid_size = 50`
 * `B = 5`
 * `β = 0.95`

 The distributions $F$ and $G$ are discrete distributions
-generating draws from the grid points `np.linspace(0, B, grid_size)`.
+generating draws from the grid points `jnp.linspace(0, B, grid_size)`.

 A very useful family of discrete distributions is the Beta-binomial family,
 with probability mass function
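
As a rough sketch of this parameterization, the grid and Beta-binomial weights referenced above can be built as follows; the shape parameters `a = b = 1` are illustrative placeholders (they give a uniform distribution over the grid) and are not taken from the lecture's calibration:

```python
import jax.numpy as jnp
from quantecon.distributions import BetaBinomial

B, grid_size = 5, 50
grid = jnp.linspace(0, B, grid_size)    # shared grid of values for θ and ε

# Beta-binomial weights over the grid_size outcomes 0, ..., grid_size - 1;
# a = b = 1 is an illustrative choice, not the lecture's parameters
F_probs = jnp.array(BetaBinomial(grid_size - 1, 1, 1).pdf())
G_probs = jnp.array(BetaBinomial(grid_size - 1, 1, 1).pdf())

F_mean = grid @ F_probs    # E[θ] under F
G_mean = grid @ G_probs    # E[ε] under G
```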
@@ -229,31 +237,34 @@ $I$, $II$ and $III$ are as given in {eq}`eyes`.

 ```{code-cell} ipython3
 @jax.jit
-def bellman_operator(model, v):
-    """
-    The Bellman operator for the career choice model.
-    """
-    θ, ε, β = model.θ, model.ε, model.β
-    F_probs, G_probs = model.F_probs, model.G_probs
-    F_mean, G_mean = model.F_mean, model.G_mean
-
-    # Vectorized computation
-    # Broadcasting θ and ε to create all combinations
-    θ_grid, ε_grid = jnp.meshgrid(θ, ε, indexing='ij')
-
+def Q(θ_grid, ε_grid, β, v, F_probs, G_probs, F_mean, G_mean):
     # Option 1: Stay put
     v1 = θ_grid + ε_grid + β * v

     # Option 2: New job (keep θ, new ε)
-    # For each θ[i], compute expected value over new ε
     ev_new_job = jnp.dot(v, G_probs)  # Expected value for each θ
     v2 = θ_grid + G_mean + β * ev_new_job[:, jnp.newaxis]

     # Option 3: New life (new θ and new ε)
-    # Expected value over both θ and ε
     ev_new_life = jnp.dot(F_probs, jnp.dot(v, G_probs))
     v3 = jnp.full_like(v, G_mean + F_mean + β * ev_new_life)

+    return v1, v2, v3
+
+@jax.jit
+def bellman_operator(model, v):
+    """
+    The Bellman operator for the career choice model.
+    """
+    θ, ε, β = model.θ, model.ε, model.β
+    F_probs, G_probs = model.F_probs, model.G_probs
+    F_mean, G_mean = model.F_mean, model.G_mean
+
+    v1, v2, v3 = Q(
+        *jnp.meshgrid(θ, ε, indexing='ij'),
+        β, v, F_probs, G_probs, F_mean, G_mean
+    )
+
     return jnp.maximum(jnp.maximum(v1, v2), v3)

 @jax.jit
@@ -266,20 +277,10 @@ def get_greedy_policy(model, v):
     F_probs, G_probs = model.F_probs, model.G_probs
     F_mean, G_mean = model.F_mean, model.G_mean

-    # Vectorized computation
-    # Broadcasting θ and ε to create all combinations
-    θ_grid, ε_grid = jnp.meshgrid(θ, ε, indexing='ij')
-
-    # Option 1: Stay put
-    v1 = θ_grid + ε_grid + β * v
-
-    # Option 2: New job (keep θ, new ε)
-    ev_new_job = jnp.dot(v, G_probs)  # Expected value for each θ
-    v2 = θ_grid + G_mean + β * ev_new_job[:, jnp.newaxis]
-
-    # Option 3: New life (new θ and new ε)
-    ev_new_life = jnp.dot(F_probs, jnp.dot(v, G_probs))
-    v3 = jnp.full_like(v, G_mean + F_mean + β * ev_new_life)
+    v1, v2, v3 = Q(
+        *jnp.meshgrid(θ, ε, indexing='ij'),
+        β, v, F_probs, G_probs, F_mean, G_mean
+    )

     # Stack the value arrays and find argmax along first axis
     values = jnp.stack([v1, v2, v3], axis=0)
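
For context on how the refactored operators above are typically driven, here is a minimal value function iteration sketch. The solver loop and the `model` object are assumptions based only on the signatures visible in this diff (fields `θ`, `ε`, `β`, `F_probs`, `G_probs`, `F_mean`, `G_mean`); they are not code from the lecture itself:

```python
import jax.numpy as jnp

def solve_model(model, tol=1e-6, max_iter=5000):
    # Hypothetical solver: iterate the Bellman operator defined above
    # until the sup-norm change falls below `tol`
    v = jnp.full((len(model.θ), len(model.ε)), 100.0)   # initial guess
    for _ in range(max_iter):
        v_new = bellman_operator(model, v)
        if jnp.max(jnp.abs(v_new - v)) < tol:
            return v_new
        v = v_new
    return v

# Example usage, assuming `model` is the lecture's parameter NamedTuple:
# v_star = solve_model(model)
# greedy_policy = get_greedy_policy(model, v_star)
```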
