Skip to content

Commit 2016211

Browse files
committed
misc
1 parent 6144e81 commit 2016211

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

lectures/cake_eating_egm_jax.md

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,18 @@ We'll also use JAX's `vmap` function to fully vectorize the Coleman-Reffett oper
3737
Let's start with some standard imports:
3838

3939
```{code-cell} ipython
40-
from typing import NamedTuple
41-
4240
import matplotlib.pyplot as plt
4341
import jax
4442
import jax.numpy as jnp
4543
import quantecon as qe
44+
from typing import NamedTuple
4645
```
4746

4847
## Implementation
4948

50-
For details on the endogenous grid method, please see {doc}`cake_eating_egm`.
49+
For details on the savings problem and the endogenous grid method (EGM), please see {doc}`cake_eating_egm`.
5150

52-
Here we focus on the JAX implementation.
51+
Here we focus on the JAX implementation of EGM.
5352

5453
We use the same setting as in {doc}`cake_eating_egm`:
5554

@@ -222,16 +221,17 @@ ax.legend()
222221
plt.show()
223222
```
224223

225-
The fit is excellent.
224+
The fit is very good.
226225

227226
```{code-cell} python3
228-
print(f"Maximum absolute deviation: {jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β))):.6e}")
227+
max_dev = jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β)))
228+
print(f"Maximum absolute deviation: {max_dev:.7}")
229229
```
230230

231231
The JAX implementation is very fast thanks to JIT compilation and vectorization.
232232

233233
```{code-cell} python3
234-
with qe.Timer():
234+
with qe.Timer(precision=8):
235235
σ = solve_model_time_iter(model, σ_init).block_until_ready()
236236
```
237237

@@ -250,7 +250,7 @@ This speed comes from:
250250
Solve the stochastic cake eating problem with CRRA utility
251251
252252
$$
253-
u(c) = \frac{c^{1 - \gamma} - 1}{1 - \gamma}
253+
u(c) = \frac{c^{1 - \gamma} - 1}{1 - \gamma}
254254
$$
255255
256256
Compare the optimal policies for values of $\gamma$ approaching 1 from above (e.g., 1.05, 1.1, 1.2).
@@ -375,8 +375,6 @@ ax.set_title('Optimal policies: CRRA utility approaching log case')
375375
plt.show()
376376
```
377377

378-
Since the endogenous grids are similar for $\gamma$ values close to 1, the policies overlap nicely.
379-
380378
Note that the plots for $\gamma > 1$ do not cover the entire x-axis range shown.
381379

382380
This is because the endogenous grid $x = k + \sigma(k)$ depends on the consumption policy, which varies with $\gamma$.
@@ -386,7 +384,7 @@ Let's check the maximum deviation between the log utility case ($\gamma = 1.0$)
386384
```{code-cell} python3
387385
for γ in [1.05, 1.1, 1.2]:
388386
max_diff = jnp.max(jnp.abs(policies[1.0] - policies[γ]))
389-
print(f"Max difference between γ=1.0 and γ={γ}: {max_diff:.6e}")
387+
print(f"Max difference between γ=1.0 and γ={γ}: {max_diff:.6}")
390388
```
391389

392390
As expected, the differences decrease as $\gamma$ approaches 1 from above, confirming convergence.

0 commit comments

Comments
 (0)