Commit 9be4a24

update code to use scalar maximization algos
1 parent 2b606ca commit 9be4a24

File tree

1 file changed: +105 −77 lines

lectures/optgrowth_fast.md

Lines changed: 105 additions & 77 deletions
@@ -26,7 +26,7 @@ kernelspec:
:depth: 2
```

-In addition to what is in Anaconda, this lecture needs an extra package.
+In addition to what is in Anaconda, this lecture needs extra packages.

```{code-cell} ipython3
:tags: [hide-output]
@@ -64,6 +64,8 @@ import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
+import jax.random as jr
+import jax.scipy.optimize as jsp
from typing import NamedTuple
import quantecon as qe
```
@@ -101,14 +103,13 @@ We store primitives in a `NamedTuple` built for JAX and create a factory functio

```{code-cell} ipython3
class OptimalGrowthModel(NamedTuple):
-    α: float                  # production parameter
-    β: float                  # discount factor
-    μ: float                  # shock location parameter
-    s: float                  # shock scale parameter
-    γ: float                  # CRRA parameter (γ = 1 gives log)
-    y_grid: jnp.ndarray       # grid for output/income
-    shocks: jnp.ndarray       # Monte Carlo draws of ξ
-    c_grid_frac: jnp.ndarray  # fractional consumption grid in (0, 1)
+    α: float              # production parameter
+    β: float              # discount factor
+    μ: float              # shock location parameter
+    s: float              # shock scale parameter
+    γ: float              # CRRA parameter (γ = 1 gives log)
+    y_grid: jnp.ndarray   # grid for output/income
+    shocks: jnp.ndarray   # Monte Carlo draws of ξ


def create_optgrowth_model(α=0.4,
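As a side note (not part of the diff): with the factory's defaults shown above, the returned `NamedTuple` should carry scalar parameters plus two JAX arrays, roughly as in this illustrative sketch.

```python
# Illustrative sketch, not from the commit: inspect the model produced by the
# factory defined in the diff above, using its default arguments.
og = create_optgrowth_model()

print(type(og).__name__)   # OptimalGrowthModel
print(og.y_grid.shape)     # (120,) with the default grid_size
print(og.shocks.shape)     # (250,) with the default shock_size
print(og.α)                # 0.4 by default
```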
@@ -119,68 +120,104 @@ def create_optgrowth_model(α=0.4,
                           grid_max=4.0,
                           grid_size=120,
                           shock_size=250,
-                          c_grid_size=200,
                           seed=0):
    """Factory function to create an OptimalGrowthModel instance."""

-    key = jax.random.PRNGKey(seed)
+    key = jr.PRNGKey(seed)
    y_grid = jnp.linspace(1e-5, grid_max, grid_size)
-    z = jax.random.normal(key, (shock_size,))
+    z = jr.normal(key, (shock_size,))
    shocks = jnp.exp(μ + s * z)

-    # Avoid endpoints 0 and 1 to keep feasibility and positivity
-    c_grid_frac = jnp.linspace(1e-6, 1.0 - 1e-6, c_grid_size)
    return OptimalGrowthModel(α=α, β=β, μ=μ, s=s, γ=γ,
-                              y_grid=y_grid, shocks=shocks,
-                              c_grid_frac=c_grid_frac)
+                              y_grid=y_grid, shocks=shocks)
```

-We now implement the CRRA utility function, the Bellman operator and the value function iteration loop using JAX
+We now implement the CRRA utility function, the Bellman operator and the value function iteration loop using JAX.
+
+We also implement a golden section search for scalar maximization needed to solve the Bellman equation.

```{code-cell} ipython3
-@jax.jit
def u(c, γ):
-    # CRRA utility with log at γ = 1
-    return jnp.where(jnp.isclose(γ, 1.0),
-                     jnp.log(c), (c**(1.0 - γ) - 1.0) / (1.0 - γ))
-
+    return jnp.where(jnp.isclose(γ, 1.0),
+                     jnp.log(c), (c**(1.0 - γ) - 1.0) / (1.0 - γ))

-@jax.jit
-def T(v, model):
+def state_action_value(c, y, v, model):
    """
-    Bellman operator returning greedy policy and updated value
+    Right hand side of the Bellman equation.
    """
    α, β, γ, shocks = model.α, model.β, model.γ, model.shocks
-    y_grid, c_grid_frac = model.y_grid, model.c_grid_frac
+    y_grid = model.y_grid
+
+    # Compute capital
+    k = y - c
+
+    # Compute next period income for all shocks
+    y_next = (k**α) * shocks
+
+    # Interpolate to get continuation values
+    continuation = jnp.interp(y_next, y_grid, v).mean()
+
+    return u(c, γ) + β * continuation
+
+def golden_max(f, a, b, args=(), tol=1e-5, max_iter=100):
+    """
+    Golden section search for maximum of f on [a, b].
+    """
+    golden_ratio = (jnp.sqrt(5.0) - 1.0) / 2.0
+
+    # Initialize
+    x1 = b - golden_ratio * (b - a)
+    x2 = a + golden_ratio * (b - a)
+    f1 = f(x1, *args)
+    f2 = f(x2, *args)
+
+    def body(state):
+        a, b, x1, x2, f1, f2, i = state
+
+        # Update interval based on function values
+        use_right = f2 > f1
+
+        a_new = jnp.where(use_right, x1, a)
+        b_new = jnp.where(use_right, b, x2)
+        x1_new = jnp.where(use_right, x2,
+                           b_new - golden_ratio * (b_new - a_new))
+        x2_new = jnp.where(use_right,
+                           a_new + golden_ratio * (b_new - a_new), x1)
+        f1_new = jnp.where(use_right, f2, f(x1_new, *args))
+        f2_new = jnp.where(use_right, f(x2_new, *args), f1)

-    # Interpolant for value function on the state grid
-    vf = lambda x: jnp.interp(x, y_grid, v)
+        return a_new, b_new, x1_new, x2_new, f1_new, f2_new, i + 1

-    def solve_state(y):
-        # Candidate consumptions scaled by income
-        c = c_grid_frac * y
+    def cond(state):
+        a, b, x1, x2, f1, f2, i = state
+        return (jnp.abs(b - a) > tol) & (i < max_iter)

-        # Next income for each c and each shock
-        k = jnp.maximum(y - c, 1e-12)
-        y_next = (k**α)[:, None] * shocks[None, :]
+    a_f, b_f, x1_f, x2_f, f1_f, f2_f, _ = jax.lax.while_loop(
+        cond, body, (a, b, x1, x2, f1, f2, 0)
+    )

-        # Expected continuation value via Monte Carlo
-        v_next = vf(y_next.reshape(-1)).reshape(
-            c.shape[0], shocks.shape[0]).mean(axis=1)
+    # Return the best point
+    x_max = jnp.where(f1_f > f2_f, x1_f, x2_f)
+    f_max = jnp.maximum(f1_f, f2_f)

-        # Objective on the consumption grid
-        obj = u(c, γ) + β * v_next
+    return x_max, f_max

-        # Maximize over c-grid
-        idx = jnp.argmax(obj)
+@jax.jit
+def T(v, model):
+    """
+    Bellman operator returning greedy policy and updated value
+    """
+    y_grid = model.y_grid

-        c_star = c[idx]
-        v_val = obj[idx]
-        return c_star, v_val
+    def maximize_at_state(y):
+        # Maximize RHS of Bellman equation at state y
+        c_star, v_max = golden_max(state_action_value,
+                                   1e-10, y - 1e-10,
+                                   args=(y, v, model))
+        return c_star, v_max

-    # Vectorize across states
-    c_star, v_new = jax.vmap(solve_state)(y_grid)
-    return c_star, v_new
+    v_greedy, v_new = jax.vmap(maximize_at_state)(y_grid)
+    return v_greedy, v_new


@jax.jit
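A quick way to sanity-check the new `golden_max` routine outside the Bellman operator (this snippet is illustrative and not part of the commit): on a concave function with a known maximizer, the search should land on it to within the default `tol`.

```python
# Illustrative check, not from the commit: golden_max on a concave quadratic.
# The unique maximizer of g(x) = -(x - 2)^2 on [0, 5] is x = 2.
def g(x):
    return -(x - 2.0)**2

x_star, g_star = golden_max(g, 0.0, 5.0)
print(x_star)   # approximately 2.0, up to the default tol=1e-5
print(g_star)   # approximately 0.0
```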
@@ -210,19 +247,20 @@ Let us compute the approximate solution at the default parameters
og = create_optgrowth_model()

with qe.Timer(unit="milliseconds"):
-    v_greedy = vfi(og)[0].block_until_ready()
+    c_greedy, _ = vfi(og)
+    c_greedy.block_until_ready()
```

Here is a plot of the resulting policy, compared with the true policy:

```{code-cell} ipython3
fig, ax = plt.subplots()

-ax.plot(og.y_grid, v_greedy, lw=2, alpha=0.8,
-        label='approximate policy function')
+ax.plot(og.y_grid, c_greedy, lw=2, alpha=0.8,
+        label='approximate policy function')

-ax.plot(og.y_grid, (1 - og.α * og.β) * og.y_grid,
-        'k--', lw=2, alpha=0.8, label='true policy function')
+ax.plot(og.y_grid, (1 - og.α * og.β) * og.y_grid,
+        'k--', lw=2, alpha=0.8, label='true policy function')

ax.legend()
plt.show()
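For context (not part of the diff): the dashed "true policy" line in the cell above is the known closed-form optimal consumption policy of the log-utility ($\gamma = 1$) model, which the code expresses as `(1 - og.α * og.β) * og.y_grid`:

$$
\sigma^*(y) = (1 - \alpha \beta)\, y
$$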
@@ -234,26 +272,23 @@ the algorithm.
The maximal absolute deviation between the two policies is

```{code-cell} ipython3
-np.max(np.abs(np.asarray(v_greedy)
-              - np.asarray((1 - og.α * og.β) * og.y_grid)))
+jnp.max(jnp.abs(c_greedy - (1 - og.α * og.β) * og.y_grid))
```

## Exercises

```{exercise-start}
:label: ogfast_ex1
```
-
Time how long it takes to iterate with the Bellman operator 20 times, starting from initial condition $v(y) = u(y)$.

-Use the default parameterization.
+Use the default parameterization and [`jax.lax.fori_loop`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.fori_loop.html#jax.lax.fori_loop) for the iteration.
```{exercise-end}
```

```{solution-start} ogfast_ex1
:class: dropdown
```
-
Let's set up the initial condition.

```{code-cell} ipython3
@@ -264,14 +299,14 @@ Here is the timing.

```{code-cell} ipython3
with qe.Timer(unit="milliseconds"):
-    for _ in range(20):
-        _, v = T(v, og)
+    def bellman_step(_, v_curr):
+        return T(v_curr, og)[1]
+    v = jax.lax.fori_loop(0, 20, bellman_step, v)
    v.block_until_ready()
```

Compared with our {ref}`timing <og_ex2>` for the non-compiled version of
value function iteration, the JIT-compiled code is usually an order of magnitude faster.
-
```{solution-end}
```

@@ -296,11 +331,9 @@ Compare execution time as well.
```{exercise-end}
```

-
```{solution-start} ogfast_ex2
:class: dropdown
```
-
Here is the CRRA variant using the same code path

```{code-cell} ipython3
@@ -311,16 +344,17 @@ Let's solve and time the model

```{code-cell} ipython3
with qe.Timer(unit="milliseconds"):
-    v_greedy = vfi(og_crra)[0].block_until_ready()
+    c_greedy, _ = vfi(og_crra)
+    c_greedy.block_until_ready()
```

Here is a plot of the resulting policy

```{code-cell} ipython3
fig, ax = plt.subplots()

-ax.plot(og_crra.y_grid, v_greedy, lw=2, alpha=0.6,
-        label='approximate policy function')
+ax.plot(og_crra.y_grid, c_greedy, lw=2, alpha=0.6,
+        label='approximate policy function')

ax.legend(loc='lower right')
plt.show()
@@ -329,15 +363,12 @@ plt.show()
This matches the solution obtained in the non-jitted code in {ref}`the earlier exercise <og_ex1>`.

Execution time is an order of magnitude faster.
-
```{solution-end}
```

-
```{exercise-start}
:label: ogfast_ex3
```
-
In this exercise we return to the original log utility specification.

Once an optimal consumption policy $\sigma$ is given, income follows
@@ -363,28 +394,26 @@ Other parameters match the log-linear model discussed earlier.
Notice that more patient agents typically have higher wealth.

Replicate the figure modulo randomness.
-
```{exercise-end}
```

```{solution-start} ogfast_ex3
:class: dropdown
```
-
Here is one solution.

```{code-cell} ipython3
-import jax.random as jr
-
def simulate_og(σ_func, og_model, y0=0.1, ts_length=100, seed=0):
-    """Compute a time series given consumption policy σ."""
+    """
+    Compute a time series given consumption policy σ.
+    """
    key = jr.PRNGKey(seed)
    ξ = jr.normal(key, (ts_length - 1,))
    y = np.empty(ts_length)
    y[0] = y0
    for t in range(ts_length - 1):
        y[t+1] = (y[t] - σ_func(y[t]))**og_model.α \
-            * np.exp(og_model.μ + og_model.s * ξ[t])
+                 * np.exp(og_model.μ + og_model.s * ξ[t])
    return y
```

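For reference (not part of the diff): reading directly off `simulate_og` above, the income process being simulated is

$$
y_{t+1} = \bigl(y_t - \sigma(y_t)\bigr)^{\alpha}\, \xi_{t+1},
\qquad \ln \xi_{t+1} \sim N(\mu, s^2),
$$

where $\sigma$ is the consumption policy passed in as `σ_func`.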
@@ -394,10 +423,9 @@ fig, ax = plt.subplots()
for β in (0.8, 0.9, 0.98):

    og_temp = create_optgrowth_model(β=β, s=0.05)
-    v_greedy, v_solution = vfi(og_temp)
+    c_greedy_temp, _ = vfi(og_temp)

-    # Define an optimal policy function
-    σ_func = lambda x: np.interp(x, og_temp.y_grid, np.asarray(v_greedy))
+    σ_func = lambda x: np.interp(x, og_temp.y_grid, np.asarray(c_greedy_temp))
    y = simulate_og(σ_func, og_temp)
    ax.plot(y, lw=2, alpha=0.6, label=rf'$\beta = {β}$')

0 commit comments