
Commit 10343f6

jstac and claude committed
Add JAX-based cake eating EGM lecture and refine related lectures
- Created new lecture cake_eating_egm_jax.md implementing EGM with JAX
  - Uses JAX's vmap for vectorization instead of for loops
  - JIT-compiled solver with jax.lax.while_loop
  - Global utility/production functions for JAX compatibility
  - Streamlined Model class to contain only arrays and scalars
  - Focuses on JAX implementation patterns, refers to cake_eating_egm for theory
- Improved cake_eating_time_iter.md
  - Moved imports to top of file
  - Removed timing benchmarks, added clearer performance discussion
  - Explained why time iteration is faster (exploits differentiability/FOCs)
  - Referenced EGM as even faster variation
- Enhanced cake_eating_egm.md
  - Removed default values from Model class (now only in create_model)
  - Aligned all comments in Model class definition
  - Replaced %%timeit with qe.Timer for consistency
  - Simplified timing discussion
- Updated _toc.yml to include new lecture after cake_eating_egm

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent 8f5d9bf commit 10343f6

File tree

- lectures/_toc.yml
- lectures/cake_eating_egm.md
- lectures/cake_eating_egm_jax.md
- lectures/cake_eating_time_iter.md

4 files changed: +266 −29 lines

lectures/_toc.yml

Lines changed: 1 addition & 0 deletions

@@ -80,6 +80,7 @@ parts:
   - file: cake_eating_stochastic
   - file: cake_eating_time_iter
   - file: cake_eating_egm
+  - file: cake_eating_egm_jax
   - file: ifp
   - file: ifp_advanced
 - caption: LQ Control

lectures/cake_eating_egm.md

Lines changed: 17 additions & 19 deletions

@@ -44,6 +44,7 @@ Let's start with some standard imports:
 ```{code-cell} ipython
 import matplotlib.pyplot as plt
 import numpy as np
+import quantecon as qe
 ```

 ## Key Idea

@@ -169,17 +170,17 @@ We reuse the `Model` structure from {doc}`Cake Eating IV <cake_eating_time_iter>`
 from typing import NamedTuple, Callable

 class Model(NamedTuple):
-    u: Callable       # utility function
-    f: Callable       # production function
-    β: float          # discount factor
-    μ: float          # shock location parameter
-    s: float          # shock scale parameter
-    grid: np.ndarray  # state grid
-    shocks: np.ndarray  # shock draws
-    α: float = 0.4    # production function parameter
-    u_prime: Callable = None  # derivative of utility
-    f_prime: Callable = None  # derivative of production
-    u_prime_inv: Callable = None  # inverse of u_prime
+    u: Callable            # utility function
+    f: Callable            # production function
+    β: float               # discount factor
+    μ: float               # shock location parameter
+    s: float               # shock scale parameter
+    grid: np.ndarray       # state grid
+    shocks: np.ndarray     # shock draws
+    α: float               # production function parameter
+    u_prime: Callable      # derivative of utility
+    f_prime: Callable      # derivative of production
+    u_prime_inv: Callable  # inverse of u_prime


 def create_model(u: Callable,

@@ -322,16 +323,13 @@ The maximal absolute deviation between the two policies is
 np.max(np.abs(σ - σ_star(x, model.α, model.β)))
 ```

-How long does it take to converge?
+Here's the execution time:

 ```{code-cell} python3
-%%timeit -n 3 -r 1
-σ = solve_model_time_iter(model, σ_init, verbose=False)
+with qe.Timer():
+    σ = solve_model_time_iter(model, σ_init, verbose=False)
 ```

-Relative to time iteration, which was already found to be highly efficient, EGM
-has managed to shave off still more run time without compromising accuracy.
+EGM is faster than time iteration because it avoids numerical root-finding.

-This is due to the lack of a numerical root-finding step.
-
-We can now solve the stochastic cake eating problem at given parameters extremely fast.
+Instead, we invert the marginal utility function directly, which is much more efficient.
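
To make the inversion step concrete (an illustrative aside, not part of the commit, using the log-utility case from these lectures): with $u(c) = \ln c$ we have $u'(c) = 1/c$, so the inverse marginal utility is available in closed form and the EGM update never calls a root-finder.

```python
import numpy as np

# Log utility: inverting marginal utility is a one-line formula.
u_prime = lambda c: 1 / c
u_prime_inv = lambda m: 1 / m   # exact inverse of u'

# Time iteration would instead solve u'(c) = m numerically (e.g. with
# brentq) at every grid point -- the step that EGM eliminates.
m = 2.5
c = u_prime_inv(m)
assert np.isclose(u_prime(c), m)
```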

lectures/cake_eating_egm_jax.md

Lines changed: 243 additions & 0 deletions

@@ -0,0 +1,243 @@
---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
kernelspec:
  display_name: Python 3
  language: python
  name: python3
---

```{raw} jupyter
<div id="qe-notebook-header" align="right" style="text-align:right;">
        <a href="https://quantecon.org/" title="quantecon.org">
                <img style="width:250px;display:inline;" width="250px" src="https://assets.quantecon.org/img/qe-menubar-logo.svg" alt="QuantEcon">
        </a>
</div>
```

# {index}`Cake Eating VI: EGM with JAX <single: Cake Eating VI: EGM with JAX>`

```{contents} Contents
:depth: 2
```

## Overview

In this lecture, we'll implement the endogenous grid method (EGM) using JAX.

This lecture builds on {doc}`cake_eating_egm`, which introduced EGM using NumPy.

By converting to JAX, we can leverage fast linear algebra, hardware accelerators, and JIT compilation for improved performance.

We'll also use JAX's `vmap` function to fully vectorize the Coleman-Reffett operator.

Let's start with some standard imports:

```{code-cell} ipython
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import quantecon as qe
```

## Implementation

For details on the endogenous grid method, please see {doc}`cake_eating_egm`.

Here we focus on the JAX implementation.

We use the same setting as in {doc}`cake_eating_egm`:

* $u(c) = \ln c$,
* production is Cobb-Douglas, and
* the shocks are lognormal.

Here are the analytical solutions for comparison.

```{code-cell} python3
def v_star(x, α, β, μ):
    """
    True value function
    """
    c1 = jnp.log(1 - α * β) / (1 - β)
    c2 = (μ + α * jnp.log(α * β)) / (1 - α)
    c3 = 1 / (1 - β)
    c4 = 1 / (1 - α * β)
    return c1 + c2 * (c3 - c4) + c4 * jnp.log(x)

def σ_star(x, α, β):
    """
    True optimal policy
    """
    return (1 - α * β) * x
```

The `Model` class stores only the data (grids, shocks, and parameters).

Utility and production functions will be defined globally to work with JAX's JIT compiler.

```{code-cell} python3
from typing import NamedTuple, Callable

class Model(NamedTuple):
    β: float             # discount factor
    μ: float             # shock location parameter
    s: float             # shock scale parameter
    grid: jnp.ndarray    # state grid
    shocks: jnp.ndarray  # shock draws
    α: float             # production function parameter


def create_model(β: float = 0.96,
                 μ: float = 0.0,
                 s: float = 0.1,
                 grid_max: float = 4.0,
                 grid_size: int = 120,
                 shock_size: int = 250,
                 seed: int = 1234,
                 α: float = 0.4) -> Model:
    """
    Creates an instance of the cake eating model.
    """
    # Set up grid
    grid = jnp.linspace(1e-4, grid_max, grid_size)

    # Store shocks (with a seed, so results are reproducible)
    key = jax.random.PRNGKey(seed)
    shocks = jnp.exp(μ + s * jax.random.normal(key, shape=(shock_size,)))

    return Model(β=β, μ=μ, s=s, grid=grid, shocks=shocks, α=α)
```
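
A side note on this design (our gloss, not text from the commit): JAX treats `NamedTuple` instances as pytrees, so a `Model` built from arrays and scalars can be passed directly into JIT-compiled functions. A minimal sketch:

```python
# Sketch: the Model NamedTuple is a pytree, so it can cross the jit
# boundary; JAX flattens it into its array/scalar leaves.
@jax.jit
def mean_shock(m: Model):
    return jnp.mean(m.shocks)

mean_shock(create_model())
```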
Here's the Coleman-Reffett operator using EGM.

The key JAX feature here is `vmap`, which vectorizes the computation over the grid points.

```{code-cell} python3
def K(σ_array: jnp.ndarray, model: Model) -> jnp.ndarray:
    """
    The Coleman-Reffett operator using EGM

    """

    # Simplify names
    β, α = model.β, model.α
    grid, shocks = model.grid, model.shocks

    # Determine endogenous grid
    x = grid + σ_array  # x_i = k_i + c_i

    # Linear interpolation of policy using endogenous grid
    σ = lambda x_val: jnp.interp(x_val, x, σ_array)

    # Define function to compute consumption at a single grid point
    def compute_c(k):
        vals = u_prime(σ(f(k, α) * shocks)) * f_prime(k, α) * shocks
        return u_prime_inv(β * jnp.mean(vals))

    # Vectorize over grid using vmap
    compute_c_vectorized = jax.vmap(compute_c)
    c = compute_c_vectorized(grid)

    return c
```
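
For readers new to `vmap`, here is a self-contained toy example (ours, not from the commit): it lifts a scalar function to operate elementwise over an array, with the loop handled inside JAX rather than in Python.

```python
# Toy example: a scalar rule mapped over a grid without a Python loop.
def scalar_rule(k):
    return jnp.log(1.0 + k) * 2.0

vector_rule = jax.vmap(scalar_rule)
vector_rule(jnp.linspace(0.0, 1.0, 5))   # one output per grid point
```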
We define utility and production functions globally.

Note that `f` and `f_prime` take `α` as an explicit argument, allowing them to work with JAX's functional programming model.

```{code-cell} python3
# Define utility and production functions with derivatives
u = lambda c: jnp.log(c)
u_prime = lambda c: 1 / c
u_prime_inv = lambda x: 1 / x
f = lambda k, α: k**α
f_prime = lambda k, α: α * k**(α - 1)
```
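
One reason for passing `α` explicitly (an illustrative aside, not commit text): under `jax.jit`, a captured Python global is baked in as a constant at trace time, whereas an explicit argument is traced and can vary across calls.

```python
# Sketch of the trace-time pitfall that explicit arguments avoid.
A = 0.4
g_global = jax.jit(lambda k: k**A)      # A frozen when first traced
g_arg = jax.jit(lambda k, α: k**α)      # α is a traced argument

g_global(2.0)       # uses A = 0.4
A = 0.5
g_global(2.0)       # still uses 0.4 -- the cached constant
g_arg(2.0, 0.5)     # picks up 0.5 as expected
```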
Now we create a model instance.

```{code-cell} python3
α = 0.4

model = create_model(α=α)
grid = model.grid
```

The solver uses JAX's `jax.lax.while_loop` for the iteration and is JIT-compiled for speed.

```{code-cell} python3
@jax.jit
def solve_model_time_iter(model: Model,
                          σ_init: jnp.ndarray,
                          tol: float = 1e-5,
                          max_iter: int = 1000) -> jnp.ndarray:
    """
    Solve the model using time iteration with EGM.
    """

    def condition(loop_state):
        i, σ, error = loop_state
        return (error > tol) & (i < max_iter)

    def body(loop_state):
        i, σ, error = loop_state
        σ_new = K(σ, model)
        error = jnp.max(jnp.abs(σ_new - σ))
        return i + 1, σ_new, error

    # Initialize loop state
    initial_state = (0, σ_init, tol + 1)

    # Run the loop
    i, σ, error = jax.lax.while_loop(condition, body, initial_state)

    return σ
```
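
As a standalone illustration of the loop primitive (a toy sketch, independent of the lecture code): `jax.lax.while_loop(cond_fun, body_fun, init_val)` applies `body_fun` repeatedly while `cond_fun` holds, and stays fully traceable under `jit`.

```python
# Toy example: halve x until it drops below 1, counting the steps.
@jax.jit
def halve_until_small(x):
    cond = lambda state: state[1] >= 1.0
    body = lambda state: (state[0] + 1, state[1] / 2.0)
    n, y = jax.lax.while_loop(cond, body, (0, x))
    return n, y

halve_until_small(10.0)   # -> (4, 0.625)
```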
We solve the model starting from an initial guess.

```{code-cell} python3
σ_init = jnp.copy(grid)
σ = solve_model_time_iter(model, σ_init)
```

Let's plot the resulting policy against the analytical solution.

```{code-cell} python3
x = grid + σ  # x_i = k_i + c_i

fig, ax = plt.subplots()

ax.plot(x, σ, lw=2,
        alpha=0.8, label='approximate policy function')

ax.plot(x, σ_star(x, model.α, model.β), 'k--',
        lw=2, alpha=0.8, label='true policy function')

ax.legend()
plt.show()
```

The fit is excellent.

```{code-cell} python3
jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β)))
```

The JAX implementation is very fast thanks to JIT compilation and vectorization.

```{code-cell} python3
with qe.Timer():
    σ = solve_model_time_iter(model, σ_init)
```

This speed comes from:

* JIT compilation of the entire solver
* Vectorization via `vmap` in the Coleman-Reffett operator
* Use of `jax.lax.while_loop` instead of a Python loop
* Efficient JAX array operations throughout
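
One timing caveat worth adding here (our note, assuming the standard JAX asynchronous-dispatch behavior): JAX calls return before the computation finishes, so a wall-clock timer can understate the work unless the result is blocked on.

```python
# Sketch: block on the result so the timer measures the computation,
# not just the (asynchronous) dispatch of the work.
with qe.Timer():
    σ = solve_model_time_iter(model, σ_init)
    σ.block_until_ready()
```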

lectures/cake_eating_time_iter.md

Lines changed: 5 additions & 10 deletions

@@ -62,6 +62,7 @@ Let's start with some imports:
 import matplotlib.pyplot as plt
 import numpy as np
 from scipy.optimize import brentq
+from typing import NamedTuple, Callable
 ```

 ## The Euler Equation

@@ -285,8 +286,6 @@ For this we need access to the functions $u'$ and $f, f'$.
 We use the same `Model` structure from {doc}`Cake Eating III <cake_eating_stochastic>`.

 ```{code-cell} python3
-from typing import NamedTuple, Callable
-
 class Model(NamedTuple):
     u: Callable  # utility function
     f: Callable  # production function

@@ -481,17 +480,13 @@ The maximal absolute deviation between the two policies is
 np.max(np.abs(σ - σ_star(model.grid, model.α, model.β)))
 ```

-How long does it take to converge?
+Time iteration runs faster than value function iteration, as discussed in {doc}`cake_eating_stochastic`.

-```{code-cell} python3
-%%timeit -n 3 -r 1
-σ = solve_model_time_iter(model, σ_init, verbose=False)
-```
+This is because time iteration exploits differentiability and the first order conditions, while value function iteration does not use this available structure.

-Convergence is very fast, even compared to the JIT-compiled value function iteration we used in {doc}`Cake Eating III <cake_eating_stochastic>`.
+At the same time, there is a variation of time iteration that runs even faster.

-Overall, we find that time iteration provides a very high degree of efficiency
-and accuracy for the stochastic cake eating problem.
+This is the endogenous grid method, which we will introduce in {doc}`cake_eating_egm`.

 ## Exercises
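
For reference, the first order condition that the revised text appeals to is the Euler equation of these lectures; in their notation (policy $\sigma$, shock $z$ with density $\phi$), it reads roughly as follows (our sketch, not part of the commit):

$$
(u' \circ \sigma)(x) = \beta \int (u' \circ \sigma)\bigl(f(x - \sigma(x))\, z\bigr)\, f'(x - \sigma(x))\, z\, \phi(dz)
$$

Time iteration updates a guess of $\sigma$ by solving this equation at each grid point, which is why differentiability of $u$ and $f$ is essential to the method.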
