@@ -3,10 +3,12 @@ jupytext:
33 text_representation :
44 extension : .md
55 format_name : myst
6+ format_version : 0.13
7+ jupytext_version : 1.17.1
68kernelspec :
7- display_name : Python 3
8- language : python
99 name : python3
10+ display_name : Python 3 (ipykernel)
11+ language : python
1012---
1113
1214(optgrowth_fast)=
@@ -26,10 +28,9 @@ kernelspec:
2628
2729In addition to what is in Anaconda, this lecture needs an extra package.
2830
29- ``` {code-cell} ipython
30- ---
31- tags: [hide-output]
32- ---
31+ ``` {code-cell} ipython3
32+ :tags: [hide-output]
33+
3334!pip install quantecon jax
3435```
3536
@@ -58,15 +59,13 @@ accelerate our code.
5859
5960Let's start with some imports:
6061
61- ``` {code-cell} ipython
62+ ``` {code-cell} ipython3
6263import matplotlib.pyplot as plt
6364import numpy as np
6465import jax
6566import jax.numpy as jnp
6667from typing import NamedTuple
6768import quantecon as qe
68-
69- jax.config.update("jax_platform_name", "cpu")
7069```
7170
7271## The model
@@ -90,15 +89,17 @@ The algorithm is unchanged, but the implementation uses JAX.
9089
9190As before, we will be able to compare with the true solutions
9291
93- ``` {code-cell} python3
92+ ``` {code-cell} ipython3
9493:load: _static/lecture_specific/optgrowth/cd_analytical.py
94+
95+
9596```
9697
9798## Computation
9899
99100We store primitives in a ` NamedTuple ` built for JAX and create a factory function to generate instances.
100101
101- ``` {code-cell} python3
102+ ``` {code-cell} ipython3
102103class OptimalGrowthModel(NamedTuple):
103104 α: float # production parameter
104105 β: float # discount factor
@@ -127,7 +128,7 @@ def create_optgrowth_model(α=0.4,
127128 z = jax.random.normal(key, (shock_size,))
128129 shocks = jnp.exp(μ + s * z)
129130
130- # Avoid endpoints 0 and 1 to keep feasibility and positivity.
131+ # Avoid endpoints 0 and 1 to keep feasibility and positivity
131132 c_grid_frac = jnp.linspace(1e-6, 1.0 - 1e-6, c_grid_size)
132133 return OptimalGrowthModel(α=α, β=β, μ=μ, s=s, γ=γ,
133134 y_grid=y_grid, shocks=shocks,
@@ -136,50 +137,50 @@ def create_optgrowth_model(α=0.4,
136137
137138We now implement the CRRA utility function, the Bellman operator and the value function iteration loop using JAX
138139
139- ``` {code-cell} python3
140+ ``` {code-cell} ipython3
140141@jax.jit
141142def u(c, γ):
142- # CRRA utility with log at γ = 1.
143+ # CRRA utility with log at γ = 1
143144 return jnp.where(jnp.isclose(γ, 1.0),
144145 jnp.log(c), (c**(1.0 - γ) - 1.0) / (1.0 - γ))
145146
146147
147148@jax.jit
148149def T(v, model):
149150 """
150- Bellman operator returning greedy policy and updated value.
151+ Bellman operator returning greedy policy and updated value
151152 """
152153 α, β, γ, shocks = model.α, model.β, model.γ, model.shocks
153154 y_grid, c_grid_frac = model.y_grid, model.c_grid_frac
154155
155- # Interpolant for value function on the state grid.
156+ # Interpolant for value function on the state grid
156157 vf = lambda x: jnp.interp(x, y_grid, v)
157158
158159 def solve_state(y):
159- # Candidate consumptions scaled by income.
160+ # Candidate consumptions scaled by income
160161 c = c_grid_frac * y
161162
162- # Next income for each c and each shock.
163+ # Next income for each c and each shock
163164 k = jnp.maximum(y - c, 1e-12)
164165 y_next = (k**α)[:, None] * shocks[None, :]
165166
166- # Expected continuation value via Monte Carlo.
167+ # Expected continuation value via Monte Carlo
167168 v_next = vf(y_next.reshape(-1)).reshape(
168169 c.shape[0], shocks.shape[0]).mean(axis=1)
169170
170- # Objective on the consumption grid.
171+ # Objective on the consumption grid
171172 obj = u(c, γ) + β * v_next
172173
173- # Maximize over c-grid.
174+ # Maximize over c-grid
174175 idx = jnp.argmax(obj)
175176
176177 c_star = c[idx]
177178 v_val = obj[idx]
178179 return c_star, v_val
179180
180- # Vectorize across states.
181- c_star_vec, v_new_vec = jax.vmap(solve_state)(y_grid)
182- return c_star_vec, v_new_vec
181+ # Vectorize across states
182+ c_star, v_new = jax.vmap(solve_state)(y_grid)
183+ return c_star, v_new
183184
184185
185186@jax.jit
@@ -205,7 +206,7 @@ def vfi(model, tol=1e-4, max_iter=1_000):
205206
206207Let us compute the approximate solution at the default parameters
207208
208- ``` {code-cell} python3
209+ ``` {code-cell} ipython3
209210og = create_optgrowth_model()
210211
211212with qe.Timer(unit="milliseconds"):
@@ -214,7 +215,7 @@ with qe.Timer(unit="milliseconds"):
214215
215216Here is a plot of the resulting policy, compared with the true policy:
216217
217- ``` {code-cell} python3
218+ ``` {code-cell} ipython3
218219fig, ax = plt.subplots()
219220
220221ax.plot(og.y_grid, v_greedy, lw=2, alpha=0.8,
@@ -232,7 +233,7 @@ the algorithm.
232233
233234The maximal absolute deviation between the two policies is
234235
235- ``` {code-cell} python3
236+ ``` {code-cell} ipython3
236237np.max(np.abs(np.asarray(v_greedy)
237238 - np.asarray((1 - og.α * og.β) * og.y_grid)))
238239```
@@ -264,7 +265,8 @@ Here is the timing.
264265``` {code-cell} ipython3
265266with qe.Timer(unit="milliseconds"):
266267 for _ in range(20):
267- v = T(v, og)[1].block_until_ready()
268+ _, v = T(v, og)
269+ v.block_until_ready()
268270```
269271
270272Compared with our {ref}` timing <og_ex2> ` for the non-compiled version of
@@ -301,20 +303,20 @@ Compare execution time as well.
301303
302304Here is the CRRA variant using the same code path
303305
304- ``` {code-cell} python3
306+ ``` {code-cell} ipython3
305307og_crra = create_optgrowth_model(γ=1.5)
306308```
307309
308310Let's solve and time the model
309311
310- ``` {code-cell} python3
312+ ``` {code-cell} ipython3
311313with qe.Timer(unit="milliseconds"):
312314 v_greedy = vfi(og_crra)[0].block_until_ready()
313315```
314316
315317Here is a plot of the resulting policy
316318
317- ``` {code-cell} python3
319+ ``` {code-cell} ipython3
318320fig, ax = plt.subplots()
319321
320322ax.plot(og_crra.y_grid, v_greedy, lw=2, alpha=0.6,
@@ -371,7 +373,7 @@ Replicate the figure modulo randomness.
371373
372374Here is one solution.
373375
374- ``` {code-cell} python3
376+ ``` {code-cell} ipython3
375377import jax.random as jr
376378
377379def simulate_og(σ_func, og_model, y0=0.1, ts_length=100, seed=0):
@@ -386,7 +388,7 @@ def simulate_og(σ_func, og_model, y0=0.1, ts_length=100, seed=0):
386388 return y
387389```
388390
389- ``` {code-cell} python3
391+ ``` {code-cell} ipython3
390392fig, ax = plt.subplots()
391393
392394for β in (0.8, 0.9, 0.98):
0 commit comments