@@ -37,19 +37,18 @@ We'll also use JAX's `vmap` function to fully vectorize the Coleman-Reffett oper
3737Let's start with some standard imports:
3838
3939``` {code-cell} ipython
40- from typing import NamedTuple
41-
4240import matplotlib.pyplot as plt
4341import jax
4442import jax.numpy as jnp
4543import quantecon as qe
44+ from typing import NamedTuple
4645```
4746
4847## Implementation
4948
50- For details on the endogenous grid method, please see {doc}` cake_eating_egm ` .
49+ For details on the savings problem and the endogenous grid method (EGM) , please see {doc}` cake_eating_egm ` .
5150
52- Here we focus on the JAX implementation.
51+ Here we focus on the JAX implementation of EGM .
5352
5453We use the same setting as in {doc}` cake_eating_egm ` :
5554
@@ -222,16 +221,17 @@ ax.legend()
222221plt.show()
223222```
224223
225- The fit is excellent .
224+ The fit is very good .
226225
227226``` {code-cell} python3
228- print(f"Maximum absolute deviation: {jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β))):.6e}")
227+ max_dev = jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β)))
228+ print(f"Maximum absolute deviation: {max_dev:.7}")
229229```
230230
231231The JAX implementation is very fast thanks to JIT compilation and vectorization.
232232
233233``` {code-cell} python3
234- with qe.Timer():
234+ with qe.Timer(precision=8 ):
235235 σ = solve_model_time_iter(model, σ_init).block_until_ready()
236236```
237237
@@ -250,7 +250,7 @@ This speed comes from:
250250Solve the stochastic cake eating problem with CRRA utility
251251
252252$$
253- u(c) = \frac{c^{1 - \gamma} - 1}{1 - \gamma}
253+ u(c) = \frac{c^{1 - \gamma} - 1}{1 - \gamma}
254254$$
255255
256256Compare the optimal policies for values of $\gamma$ approaching 1 from above (e.g., 1.05, 1.1, 1.2).
@@ -375,8 +375,6 @@ ax.set_title('Optimal policies: CRRA utility approaching log case')
375375plt.show()
376376```
377377
378- Since the endogenous grids are similar for $\gamma$ values close to 1, the policies overlap nicely.
379-
380378Note that the plots for $\gamma > 1$ do not cover the entire x-axis range shown.
381379
382380This is because the endogenous grid $x = k + \sigma(k)$ depends on the consumption policy, which varies with $\gamma$.
@@ -386,7 +384,7 @@ Let's check the maximum deviation between the log utility case ($\gamma = 1.0$)
386384``` {code-cell} python3
387385for γ in [1.05, 1.1, 1.2]:
388386 max_diff = jnp.max(jnp.abs(policies[1.0] - policies[γ]))
389- print(f"Max difference between γ=1.0 and γ={γ}: {max_diff:.6e }")
387+ print(f"Max difference between γ=1.0 and γ={γ}: {max_diff:.6 }")
390388```
391389
392390As expected, the differences decrease as $\gamma$ approaches 1 from above, confirming convergence.
0 commit comments