Skip to content

Commit 6144e81

Browse files
jstacclaude
andcommitted
Refine JAX EGM lecture: improve imports and add CRRA exercise
- Moved all imports to top of file (NamedTuple with other imports) - Removed unused Callable import - Added block_until_ready() to timing for accurate JAX benchmarking - Improved error output formatting with print statement - Added comprehensive CRRA utility exercise demonstrating convergence Exercise improvements: - Uses correct CRRA form: u(c) = (c^(1-γ) - 1)/(1-γ) that converges to log - Focuses on γ values approaching 1 from above (1.05, 1.1, 1.2) - Plots γ=1 (log case) in black with clear labeling - Includes explanation of endogenous grid coverage differences - Shows numerical convergence with maximum deviation metrics 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 10343f6 commit 6144e81

File tree

1 file changed

+156
-4
lines changed

1 file changed

+156
-4
lines changed

lectures/cake_eating_egm_jax.md

Lines changed: 156 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ We'll also use JAX's `vmap` function to fully vectorize the Coleman-Reffett oper
3737
Let's start with some standard imports:
3838

3939
```{code-cell} ipython
40+
from typing import NamedTuple
41+
4042
import matplotlib.pyplot as plt
4143
import jax
4244
import jax.numpy as jnp
@@ -80,8 +82,6 @@ The `Model` class stores only the data (grids, shocks, and parameters).
8082
Utility and production functions will be defined globally to work with JAX's JIT compiler.
8183

8284
```{code-cell} python3
83-
from typing import NamedTuple, Callable
84-
8585
class Model(NamedTuple):
8686
β: float # discount factor
8787
μ: float # shock location parameter
@@ -225,14 +225,14 @@ plt.show()
225225
The fit is excellent.
226226

227227
```{code-cell} python3
228-
jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β)))
228+
print(f"Maximum absolute deviation: {jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β))):.6e}")
229229
```
230230

231231
The JAX implementation is very fast thanks to JIT compilation and vectorization.
232232

233233
```{code-cell} python3
234234
with qe.Timer():
235-
σ = solve_model_time_iter(model, σ_init)
235+
σ = solve_model_time_iter(model, σ_init).block_until_ready()
236236
```
237237

238238
This speed comes from:
@@ -241,3 +241,155 @@ This speed comes from:
241241
* Vectorization via `vmap` in the Coleman-Reffett operator
242242
* Use of `jax.lax.while_loop` instead of a Python loop
243243
* Efficient JAX array operations throughout
244+
245+
## Exercises
246+
247+
```{exercise}
248+
:label: cake_egm_jax_ex1
249+
250+
Solve the stochastic cake eating problem with CRRA utility
251+
252+
$$
253+
u(c) = \frac{c^{1 - \gamma} - 1}{1 - \gamma}
254+
$$
255+
256+
Compare the optimal policies for values of $\gamma$ approaching 1 from above (e.g., 1.05, 1.1, 1.2).
257+
258+
Show that as $\gamma \to 1$, the optimal policy converges to the policy obtained with log utility ($\gamma = 1$).
259+
260+
Hint: Use values of $\gamma$ close to 1 to ensure the endogenous grids have similar coverage and make visual comparison easier.
261+
```
262+
263+
```{solution-start} cake_egm_jax_ex1
264+
:class: dropdown
265+
```
266+
267+
We need to create a version of the Coleman-Reffett operator and solver that work with CRRA utility.
268+
269+
The key is to parameterize the utility functions by $\gamma$.
270+
271+
```{code-cell} python3
272+
def u_crra(c, γ):
273+
return (c**(1 - γ) - 1) / (1 - γ)
274+
275+
def u_prime_crra(c, γ):
276+
return c**(-γ)
277+
278+
def u_prime_inv_crra(x, γ):
279+
return x**(-1/γ)
280+
```
281+
282+
Now we create a version of the Coleman-Reffett operator that takes $\gamma$ as a parameter.
283+
284+
```{code-cell} python3
285+
def K_crra(σ_array: jnp.ndarray, model: Model, γ: float) -> jnp.ndarray:
286+
"""
287+
The Coleman-Reffett operator using EGM with CRRA utility
288+
"""
289+
# Simplify names
290+
β, α = model.β, model.α
291+
grid, shocks = model.grid, model.shocks
292+
293+
# Determine endogenous grid
294+
x = grid + σ_array
295+
296+
# Linear interpolation of policy using endogenous grid
297+
σ = lambda x_val: jnp.interp(x_val, x, σ_array)
298+
299+
# Define function to compute consumption at a single grid point
300+
def compute_c(k):
301+
vals = u_prime_crra(σ(f(k, α) * shocks), γ) * f_prime(k, α) * shocks
302+
return u_prime_inv_crra(β * jnp.mean(vals), γ)
303+
304+
# Vectorize over grid using vmap
305+
compute_c_vectorized = jax.vmap(compute_c)
306+
c = compute_c_vectorized(grid)
307+
308+
return c
309+
```
310+
311+
We also need a solver that uses this operator.
312+
313+
```{code-cell} python3
314+
@jax.jit
315+
def solve_model_crra(model: Model,
316+
σ_init: jnp.ndarray,
317+
γ: float,
318+
tol: float = 1e-5,
319+
max_iter: int = 1000) -> jnp.ndarray:
320+
"""
321+
Solve the model using time iteration with EGM and CRRA utility.
322+
"""
323+
324+
def condition(loop_state):
325+
i, σ, error = loop_state
326+
return (error > tol) & (i < max_iter)
327+
328+
def body(loop_state):
329+
i, σ, error = loop_state
330+
σ_new = K_crra(σ, model, γ)
331+
error = jnp.max(jnp.abs(σ_new - σ))
332+
return i + 1, σ_new, error
333+
334+
# Initialize loop state
335+
initial_state = (0, σ_init, tol + 1)
336+
337+
# Run the loop
338+
i, σ, error = jax.lax.while_loop(condition, body, initial_state)
339+
340+
return σ
341+
```
342+
343+
Now we solve for $\gamma = 1$ (log utility) and values approaching 1 from above.
344+
345+
```{code-cell} python3
346+
γ_values = [1.0, 1.05, 1.1, 1.2]
347+
policies = {}
348+
349+
model_crra = create_model(α=α)
350+
351+
for γ in γ_values:
352+
σ_init = jnp.copy(model_crra.grid)
353+
σ_gamma = solve_model_crra(model_crra, σ_init, γ).block_until_ready()
354+
policies[γ] = σ_gamma
355+
print(f"Solved for γ = {γ}")
356+
```
357+
358+
Plot the policies on their endogenous grids.
359+
360+
```{code-cell} python3
361+
fig, ax = plt.subplots()
362+
363+
for γ in γ_values:
364+
x = model_crra.grid + policies[γ]
365+
if γ == 1.0:
366+
ax.plot(x, policies[γ], 'k-', linewidth=2,
367+
label=f'γ = {γ:.2f} (log utility)', alpha=0.8)
368+
else:
369+
ax.plot(x, policies[γ], label=f'γ = {γ:.2f}', alpha=0.8)
370+
371+
ax.set_xlabel('State x')
372+
ax.set_ylabel('Consumption σ(x)')
373+
ax.legend()
374+
ax.set_title('Optimal policies: CRRA utility approaching log case')
375+
plt.show()
376+
```
377+
378+
Since the endogenous grids are similar for $\gamma$ values close to 1, the policies overlap nicely.
379+
380+
Note that the plots for $\gamma > 1$ do not cover the entire x-axis range shown.
381+
382+
This is because the endogenous grid $x = k + \sigma(k)$ depends on the consumption policy, which varies with $\gamma$.
383+
384+
Let's check the maximum deviation between the log utility case ($\gamma = 1.0$) and values approaching from above.
385+
386+
```{code-cell} python3
387+
for γ in [1.05, 1.1, 1.2]:
388+
max_diff = jnp.max(jnp.abs(policies[1.0] - policies[γ]))
389+
print(f"Max difference between γ=1.0 and γ={γ}: {max_diff:.6e}")
390+
```
391+
392+
As expected, the differences decrease as $\gamma$ approaches 1 from above, confirming convergence.
393+
394+
```{solution-end}
395+
```

0 commit comments

Comments
 (0)