Commit 2b606ca

committed: minor updates

1 parent 4b966c1 commit 2b606ca

File tree

1 file changed: +35 -33 lines changed


lectures/optgrowth_fast.md

Lines changed: 35 additions & 33 deletions
@@ -3,10 +3,12 @@ jupytext:
   text_representation:
     extension: .md
     format_name: myst
+    format_version: 0.13
+    jupytext_version: 1.17.1
 kernelspec:
-  display_name: Python 3
-  language: python
   name: python3
+  display_name: Python 3 (ipykernel)
+  language: python
 ---
 
 (optgrowth_fast)=
@@ -26,10 +28,9 @@ kernelspec:
 
 In addition to what is in Anaconda, this lecture needs an extra package.
 
-```{code-cell} ipython
----
-tags: [hide-output]
----
+```{code-cell} ipython3
+:tags: [hide-output]
+
 !pip install quantecon jax
 ```
 
@@ -58,15 +59,13 @@ accelerate our code.
 
 Let's start with some imports:
 
-```{code-cell} ipython
+```{code-cell} ipython3
 import matplotlib.pyplot as plt
 import numpy as np
 import jax
 import jax.numpy as jnp
 from typing import NamedTuple
 import quantecon as qe
-
-jax.config.update("jax_platform_name", "cpu")
 ```
 
 ## The model
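Note on the deletion above: the removed `jax.config.update("jax_platform_name", "cpu")` call pinned JAX to the CPU backend, so dropping it lets JAX choose a GPU or TPU when one is available. A minimal sketch, using standard JAX introspection calls, of how to confirm which backend gets picked:

```python
import jax

# Which backend JAX selected now that the CPU pin is gone
print(jax.default_backend())   # e.g. "cpu", "gpu" or "tpu"
print(jax.devices())           # the devices JAX can dispatch to
```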
@@ -90,15 +89,17 @@ The algorithm is unchanged, but the implementation uses JAX.
 
 As before, we will be able to compare with the true solutions
 
-```{code-cell} python3
+```{code-cell} ipython3
 :load: _static/lecture_specific/optgrowth/cd_analytical.py
+
+
 ```
 
 ## Computation
 
 We store primitives in a `NamedTuple` built for JAX and create a factory function to generate instances.
 
-```{code-cell} python3
+```{code-cell} ipython3
 class OptimalGrowthModel(NamedTuple):
     α: float          # production parameter
     β: float          # discount factor
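Context for the `NamedTuple` line above: JAX treats `NamedTuple` instances as pytrees, which is why an `OptimalGrowthModel` can be passed directly into jitted functions. A minimal sketch of the mechanism, with a hypothetical `Model`/`scale` pair standing in for the lecture's own classes:

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class Model(NamedTuple):       # hypothetical stand-in for OptimalGrowthModel
    β: float
    grid: jnp.ndarray

@jax.jit
def scale(m):
    # jit traces straight through the NamedTuple's fields
    return m.β * m.grid

m = Model(β=0.96, grid=jnp.linspace(0, 1, 5))
print(scale(m))
```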
@@ -127,7 +128,7 @@ def create_optgrowth_model(α=0.4,
     z = jax.random.normal(key, (shock_size,))
     shocks = jnp.exp(μ + s * z)
 
-    # Avoid endpoints 0 and 1 to keep feasibility and positivity.
+    # Avoid endpoints 0 and 1 to keep feasibility and positivity
     c_grid_frac = jnp.linspace(1e-6, 1.0 - 1e-6, c_grid_size)
     return OptimalGrowthModel(α=α, β=β, μ=μ, s=s, γ=γ,
                               y_grid=y_grid, shocks=shocks,
@@ -136,50 +137,50 @@ def create_optgrowth_model(α=0.4,
 
 We now implement the CRRA utility function, the Bellman operator and the value function iteration loop using JAX
 
-```{code-cell} python3
+```{code-cell} ipython3
 @jax.jit
 def u(c, γ):
-    # CRRA utility with log at γ = 1.
+    # CRRA utility with log at γ = 1
     return jnp.where(jnp.isclose(γ, 1.0),
                      jnp.log(c), (c**(1.0 - γ) - 1.0) / (1.0 - γ))
 
 
 @jax.jit
 def T(v, model):
     """
-    Bellman operator returning greedy policy and updated value.
+    Bellman operator returning greedy policy and updated value
     """
     α, β, γ, shocks = model.α, model.β, model.γ, model.shocks
     y_grid, c_grid_frac = model.y_grid, model.c_grid_frac
 
-    # Interpolant for value function on the state grid.
+    # Interpolant for value function on the state grid
     vf = lambda x: jnp.interp(x, y_grid, v)
 
     def solve_state(y):
-        # Candidate consumptions scaled by income.
+        # Candidate consumptions scaled by income
         c = c_grid_frac * y
 
-        # Next income for each c and each shock.
+        # Next income for each c and each shock
         k = jnp.maximum(y - c, 1e-12)
         y_next = (k**α)[:, None] * shocks[None, :]
 
-        # Expected continuation value via Monte Carlo.
+        # Expected continuation value via Monte Carlo
         v_next = vf(y_next.reshape(-1)).reshape(
             c.shape[0], shocks.shape[0]).mean(axis=1)
 
-        # Objective on the consumption grid.
+        # Objective on the consumption grid
        obj = u(c, γ) + β * v_next
 
-        # Maximize over c-grid.
+        # Maximize over c-grid
        idx = jnp.argmax(obj)
 
        c_star = c[idx]
        v_val = obj[idx]
        return c_star, v_val
 
-    # Vectorize across states.
-    c_star_vec, v_new_vec = jax.vmap(solve_state)(y_grid)
-    return c_star_vec, v_new_vec
+    # Vectorize across states
+    c_star, v_new = jax.vmap(solve_state)(y_grid)
+    return c_star, v_new
 
 
 @jax.jit
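The rename from `c_star_vec`/`v_new_vec` to `c_star`/`v_new` is cosmetic, but the line it touches carries the key JAX pattern: `jax.vmap` lifts the scalar-state `solve_state` into a function over the whole grid. A minimal sketch of that pattern, with illustrative names:

```python
import jax
import jax.numpy as jnp

def solve_one(x):
    # Scalar problem: score a few candidates and keep the best (illustrative)
    candidates = x * jnp.array([0.5, 1.0, 2.0])
    return jnp.max(candidates), jnp.argmax(candidates)

# vmap maps solve_one over axis 0 of the grid and stacks both outputs
best, idx = jax.vmap(solve_one)(jnp.linspace(1.0, 4.0, 8))
print(best.shape, idx.shape)   # (8,) (8,)
```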
@@ -205,7 +206,7 @@ def vfi(model, tol=1e-4, max_iter=1_000):
 
 Let us compute the approximate solution at the default parameters
 
-```{code-cell} python3
+```{code-cell} ipython3
 og = create_optgrowth_model()
 
 with qe.Timer(unit="milliseconds"):
@@ -214,7 +215,7 @@ with qe.Timer(unit="milliseconds"):
 
 Here is a plot of the resulting policy, compared with the true policy:
 
-```{code-cell} python3
+```{code-cell} ipython3
 fig, ax = plt.subplots()
 
 ax.plot(og.y_grid, v_greedy, lw=2, alpha=0.8,
@@ -232,7 +233,7 @@ the algorithm.
 
 The maximal absolute deviation between the two policies is
 
-```{code-cell} python3
+```{code-cell} ipython3
 np.max(np.abs(np.asarray(v_greedy)
               - np.asarray((1 - og.α * og.β) * og.y_grid)))
 ```
@@ -264,7 +265,8 @@ Here is the timing.
 ```{code-cell} ipython3
 with qe.Timer(unit="milliseconds"):
     for _ in range(20):
-        v = T(v, og)[1].block_until_ready()
+        _, v = T(v, og)
+        v.block_until_ready()
 ```
 
 Compared with our {ref}`timing <og_ex2>` for the non-compiled version of
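The split into `_, v = T(v, og)` plus `v.block_until_ready()` keeps the timing honest: JAX dispatches work asynchronously, so without the blocking call the timer can stop before the device has finished. A minimal sketch of the hazard with a generic jitted function (names here are illustrative):

```python
import time
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: (x @ x).sum())   # any jitted workload
x = jnp.ones((1000, 1000))
f(x)                                   # warm-up so compilation isn't timed

start = time.perf_counter()
result = f(x)                          # returns immediately (async dispatch)
result.block_until_ready()             # wait until the device has finished
print(time.perf_counter() - start)
```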
@@ -301,20 +303,20 @@ Compare execution time as well.
 
 Here is the CRRA variant using the same code path
 
-```{code-cell} python3
+```{code-cell} ipython3
 og_crra = create_optgrowth_model(γ=1.5)
 ```
 
 Let's solve and time the model
 
-```{code-cell} python3
+```{code-cell} ipython3
 with qe.Timer(unit="milliseconds"):
     v_greedy = vfi(og_crra)[0].block_until_ready()
 ```
 
 Here is a plot of the resulting policy
 
-```{code-cell} python3
+```{code-cell} ipython3
 fig, ax = plt.subplots()
 
 ax.plot(og_crra.y_grid, v_greedy, lw=2, alpha=0.6,
@@ -371,7 +373,7 @@ Replicate the figure modulo randomness.
 
 Here is one solution.
 
-```{code-cell} python3
+```{code-cell} ipython3
 import jax.random as jr
 
 def simulate_og(σ_func, og_model, y0=0.1, ts_length=100, seed=0):
@@ -386,7 +388,7 @@ def simulate_og(σ_func, og_model, y0=0.1, ts_length=100, seed=0):
     return y
 ```
 
-```{code-cell} python3
+```{code-cell} ipython3
 fig, ax = plt.subplots()
 
 for β in (0.8, 0.9, 0.98):
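The `import jax.random as jr` line above points at JAX's explicit PRNG-key style, which is what lets `simulate_og` take a `seed` argument and reproduce a path exactly. A minimal sketch of the key-splitting pattern it relies on (loop and names are illustrative):

```python
import jax.random as jr

key = jr.PRNGKey(0)                 # seed -> key
draws = []
for _ in range(3):
    key, subkey = jr.split(key)     # fresh subkey per draw, carry key forward
    draws.append(float(jr.normal(subkey)))
print(draws)                        # same seed -> same three draws
```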
