@@ -37,6 +37,8 @@ We'll also use JAX's `vmap` function to fully vectorize the Coleman-Reffett oper
Let's start with some standard imports:

```{code-cell} ipython
+ from typing import NamedTuple
+
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
@@ -80,8 +82,6 @@ The `Model` class stores only the data (grids, shocks, and parameters).
Utility and production functions will be defined globally to work with JAX's JIT compiler.

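+ Concretely, these primitives are ordinary module-level functions. A minimal sketch, assuming log utility and Cobb-Douglas production (the specification behind the closed-form check used later), might look like the following; the lecture defines its own versions of these primitives:
+
+ ```{code-cell} python3
+ # Illustrative sketch only -- the lecture supplies its own global primitives
+ def u_prime(c):
+     return 1 / c                  # marginal utility under log preferences
+
+ def f(k, α):
+     return k**α                   # Cobb-Douglas production
+
+ def f_prime(k, α):
+     return α * k**(α - 1)         # marginal product of capital
+ ```
+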
```{code-cell} python3
- from typing import NamedTuple, Callable
-
class Model(NamedTuple):
    β: float              # discount factor
    μ: float              # shock location parameter
@@ -225,14 +225,14 @@ plt.show()
The fit is excellent.

```{code-cell} python3
- jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β)))
+ print(f"Maximum absolute deviation: {jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β))):.6e}")
```

The JAX implementation is very fast thanks to JIT compilation and vectorization.

```{code-cell} python3
with qe.Timer():
-     σ = solve_model_time_iter(model, σ_init)
+     σ = solve_model_time_iter(model, σ_init).block_until_ready()
```

This speed comes from:
@@ -241,3 +241,155 @@ This speed comes from:
* Vectorization via `vmap` in the Coleman-Reffett operator
* Use of `jax.lax.while_loop` instead of a Python loop (see the sketch after this list)
* Efficient JAX array operations throughout
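+
+ To make the `jax.lax.while_loop` point concrete, here is a minimal sketch of the fixed-point pattern used by the solver, with a toy contraction standing in for the Coleman-Reffett operator (the helper name, toy map, and tolerances below are illustrative, not part of the lecture):
+
+ ```{code-cell} python3
+ def solve_fixed_point_demo(x_init, tol=1e-8, max_iter=500):
+     """
+     Illustrative sketch: iterate the toy contraction x -> 0.5 * x + 1
+     with jax.lax.while_loop.  Each entry converges to the fixed point 2.
+     """
+
+     def condition(loop_state):
+         i, x, error = loop_state
+         return (error > tol) & (i < max_iter)
+
+     def body(loop_state):
+         i, x, error = loop_state
+         x_new = 0.5 * x + 1.0                     # toy update map
+         return i + 1, x_new, jnp.max(jnp.abs(x_new - x))
+
+     i, x, error = jax.lax.while_loop(condition, body, (0, x_init, tol + 1))
+     return x
+
+ print(jax.jit(solve_fixed_point_demo)(jnp.zeros(5)))
+ ```
+
+ Because the loop is expressed with `jax.lax.while_loop`, the whole iteration is traced and compiled once by `jax.jit`, instead of dispatching one JAX operation at a time from a Python `while` loop.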
+
+ ## Exercises
+
+ ```{exercise}
+ :label: cake_egm_jax_ex1
+
+ Solve the stochastic cake eating problem with CRRA utility
+
+ $$
+ u(c) = \frac{c^{1 - \gamma} - 1}{1 - \gamma}
+ $$
+
+ Compare the optimal policies for values of $\gamma$ approaching 1 from above (e.g., 1.05, 1.1, 1.2).
+
+ Show that as $\gamma \to 1$, the optimal policy converges to the policy obtained with log utility ($\gamma = 1$).
+
+ Hint: Using values of $\gamma$ close to 1 keeps the endogenous grids similar, which makes visual comparison easier.
+ ```
+
+ ```{solution-start} cake_egm_jax_ex1
+ :class: dropdown
+ ```
+
+ We need versions of the Coleman-Reffett operator and the solver that work with CRRA utility.
+
+ The key is to parameterize the utility functions by $\gamma$.
+
+ ```{code-cell} python3
+ def u_crra(c, γ):
+     return (c**(1 - γ) - 1) / (1 - γ)
+
+ def u_prime_crra(c, γ):
+     return c**(-γ)
+
+ def u_prime_inv_crra(x, γ):
+     return x**(-1/γ)
+ ```
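+
+ As a quick sanity check, we can verify that `u_prime_inv_crra` inverts `u_prime_crra`, and that at $\gamma = 1$ marginal utility reduces to $1/c$, the marginal utility of log utility (the test values below are arbitrary):
+
+ ```{code-cell} python3
+ c_test = jnp.linspace(0.1, 2.0, 5)    # arbitrary test values, not from the lecture
+ γ_test = 1.2
+
+ # The inverse marginal utility should undo the marginal utility
+ print(jnp.allclose(u_prime_inv_crra(u_prime_crra(c_test, γ_test), γ_test), c_test))
+
+ # At γ = 1, u'(c) = c**(-1) = 1/c, matching log utility
+ print(jnp.allclose(u_prime_crra(c_test, 1.0), 1 / c_test))
+ ```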
+
+ Now we create a version of the Coleman-Reffett operator that takes $\gamma$ as a parameter.
+
+ ```{code-cell} python3
+ def K_crra(σ_array: jnp.ndarray, model: Model, γ: float) -> jnp.ndarray:
+     """
+     The Coleman-Reffett operator using EGM with CRRA utility
+     """
+     # Simplify names
+     β, α = model.β, model.α
+     grid, shocks = model.grid, model.shocks
+
+     # Determine endogenous grid
+     x = grid + σ_array
+
+     # Linear interpolation of policy using endogenous grid
+     σ = lambda x_val: jnp.interp(x_val, x, σ_array)
+
+     # Define function to compute consumption at a single grid point
+     def compute_c(k):
+         vals = u_prime_crra(σ(f(k, α) * shocks), γ) * f_prime(k, α) * shocks
+         return u_prime_inv_crra(β * jnp.mean(vals), γ)
+
+     # Vectorize over grid using vmap
+     compute_c_vectorized = jax.vmap(compute_c)
+     c = compute_c_vectorized(grid)
+
+     return c
+ ```
+
+ We also need a solver that uses this operator.
+
+ ```{code-cell} python3
+ @jax.jit
+ def solve_model_crra(model: Model,
+                      σ_init: jnp.ndarray,
+                      γ: float,
+                      tol: float = 1e-5,
+                      max_iter: int = 1000) -> jnp.ndarray:
+     """
+     Solve the model using time iteration with EGM and CRRA utility.
+     """
+
+     def condition(loop_state):
+         i, σ, error = loop_state
+         return (error > tol) & (i < max_iter)
+
+     def body(loop_state):
+         i, σ, error = loop_state
+         σ_new = K_crra(σ, model, γ)
+         error = jnp.max(jnp.abs(σ_new - σ))
+         return i + 1, σ_new, error
+
+     # Initialize loop state
+     initial_state = (0, σ_init, tol + 1)
+
+     # Run the loop
+     i, σ, error = jax.lax.while_loop(condition, body, initial_state)
+
+     return σ
+ ```
+
+ Now we solve for $\gamma = 1$ (log utility) and values approaching 1 from above.
+
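+ Although $u(c)$ itself is undefined at $\gamma = 1$ (the formula gives $0/0$), EGM only ever uses marginal utility, and
+
+ $$
+ u'(c) = c^{-\gamma} \big|_{\gamma = 1} = \frac{1}{c} = \frac{d}{dc} \ln c,
+ $$
+
+ so passing $\gamma = 1$ to the solver recovers the log utility policy exactly.
+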
+ ```{code-cell} python3
+ γ_values = [1.0, 1.05, 1.1, 1.2]
+ policies = {}
+
+ model_crra = create_model(α=α)
+
+ for γ in γ_values:
+     σ_init = jnp.copy(model_crra.grid)
+     σ_gamma = solve_model_crra(model_crra, σ_init, γ).block_until_ready()
+     policies[γ] = σ_gamma
+     print(f"Solved for γ = {γ}")
+ ```
+
+ Let's plot the policies on their endogenous grids.
+
+ ```{code-cell} python3
+ fig, ax = plt.subplots()
+
+ for γ in γ_values:
+     x = model_crra.grid + policies[γ]
+     if γ == 1.0:
+         ax.plot(x, policies[γ], 'k-', linewidth=2,
+                 label=f'γ = {γ:.2f} (log utility)', alpha=0.8)
+     else:
+         ax.plot(x, policies[γ], label=f'γ = {γ:.2f}', alpha=0.8)
+
+ ax.set_xlabel('State x')
+ ax.set_ylabel('Consumption σ(x)')
+ ax.legend()
+ ax.set_title('Optimal policies: CRRA utility approaching log case')
+ plt.show()
+ ```
+
+ Since the endogenous grids are similar for $\gamma$ values close to 1, the policies overlap nicely.
+
+ Note that the plots for $\gamma > 1$ do not cover the entire x-axis range shown.
+
+ This is because the endogenous grid $x = k + \sigma(k)$ depends on the consumption policy, which varies with $\gamma$.
+
+ Let's check the maximum deviation, grid point by grid point, between the log utility policy ($\gamma = 1.0$) and the policies for $\gamma$ approaching 1 from above.
+
+ ```{code-cell} python3
+ for γ in [1.05, 1.1, 1.2]:
+     max_diff = jnp.max(jnp.abs(policies[1.0] - policies[γ]))
+     print(f"Max difference between γ=1.0 and γ={γ}: {max_diff:.6e}")
+ ```
+
+ As expected, the differences decrease as $\gamma$ approaches 1 from above, confirming convergence.
+
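+ Keep in mind that this comparison is taken grid point by grid point on the exogenous grid, while each policy really lives on its own endogenous grid. As a cross-check (a small sketch, not part of the solution above), we can interpolate every policy onto a common range of $x$ values and compare there:
+
+ ```{code-cell} python3
+ # Cross-check sketch: compare policies on a shared x grid
+ x_grids = {γ: model_crra.grid + policies[γ] for γ in γ_values}
+
+ # Common x values covered by every endogenous grid (avoids extrapolation)
+ x_lo = max(float(x[0]) for x in x_grids.values())
+ x_hi = min(float(x[-1]) for x in x_grids.values())
+ x_common = jnp.linspace(x_lo, x_hi, 200)
+
+ σ_log = jnp.interp(x_common, x_grids[1.0], policies[1.0])
+
+ for γ in [1.05, 1.1, 1.2]:
+     σ_γ = jnp.interp(x_common, x_grids[γ], policies[γ])
+     gap = jnp.max(jnp.abs(σ_γ - σ_log))
+     print(f"Max gap on common grid between γ=1.0 and γ={γ}: {gap:.6e}")
+ ```
+
+ The gaps should again shrink as $\gamma$ approaches 1 from above.
+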
+ ```{solution-end}
+ ```