@@ -26,7 +26,7 @@ kernelspec:
2626:depth: 2
2727```
2828
29- In addition to what is in Anaconda, this lecture needs an extra package.
29+ In addition to what is in Anaconda, this lecture needs extra packages.
3030
3131```{code-cell} ipython3
3232:tags: [hide-output]
@@ -64,6 +64,8 @@ import matplotlib.pyplot as plt
6464import numpy as np
6565import jax
6666import jax.numpy as jnp
67+ import jax.random as jr
68+ import jax.scipy.optimize as jsp
6769from typing import NamedTuple
6870import quantecon as qe
6971```
@@ -101,14 +103,13 @@ We store primitives in a `NamedTuple` built for JAX and create a factory functio
101103
102104```{code-cell} ipython3
103105class OptimalGrowthModel(NamedTuple):
104- α: float # production parameter
105- β: float # discount factor
106- μ: float # shock location parameter
107- s: float # shock scale parameter
108- γ: float # CRRA parameter (γ = 1 gives log)
109- y_grid: jnp.ndarray # grid for output/income
110- shocks: jnp.ndarray # Monte Carlo draws of ξ
111- c_grid_frac: jnp.ndarray # fractional consumption grid in (0, 1)
106+ α: float # production parameter
107+ β: float # discount factor
108+ μ: float # shock location parameter
109+ s: float # shock scale parameter
110+ γ: float # CRRA parameter (γ = 1 gives log)
111+ y_grid: jnp.ndarray # grid for output/income
112+ shocks: jnp.ndarray # Monte Carlo draws of ξ
112113
113114
114115def create_optgrowth_model(α=0.4,
@@ -119,68 +120,104 @@ def create_optgrowth_model(α=0.4,
119120 grid_max=4.0,
120121 grid_size=120,
121122 shock_size=250,
122- c_grid_size=200,
123123 seed=0):
124124 """Factory function to create an OptimalGrowthModel instance."""
125125
126- key = jax.random.PRNGKey(seed)
126+ key = jr.PRNGKey(seed)
127127 y_grid = jnp.linspace(1e-5, grid_max, grid_size)
128- z = jax.random.normal(key, (shock_size,))
128+ z = jr.normal(key, (shock_size,))
129129 shocks = jnp.exp(μ + s * z)
130130
131- # Avoid endpoints 0 and 1 to keep feasibility and positivity
132- c_grid_frac = jnp.linspace(1e-6, 1.0 - 1e-6, c_grid_size)
133131 return OptimalGrowthModel(α=α, β=β, μ=μ, s=s, γ=γ,
134- y_grid=y_grid, shocks=shocks,
135- c_grid_frac=c_grid_frac)
132+ y_grid=y_grid, shocks=shocks)
136133```
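For reference, the factory can be exercised directly — a minimal usage sketch (an aside, not part of the diff), relying only on the defaults shown above:

```python
og = create_optgrowth_model()    # default parameters
print(og.y_grid.shape)           # (120,) — income grid, grid_size=120
print(og.shocks.shape)           # (250,) — Monte Carlo draws of ξ, shock_size=250
```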
137134
138- We now implement the CRRA utility function, the Bellman operator and the value function iteration loop using JAX
135+ We now implement the CRRA utility function, the Bellman operator, and the value function iteration loop using JAX.
136+
137+ We also implement a golden section search routine for the scalar maximization over consumption inside the Bellman equation.
139138
140139```{code-cell} ipython3
141- @jax.jit
142140def u(c, γ):
143- # CRRA utility with log at γ = 1
144- return jnp.where(jnp.isclose(γ, 1.0),
145- jnp.log(c), (c**(1.0 - γ) - 1.0) / (1.0 - γ))
146-
141+ return jnp.where(jnp.isclose(γ, 1.0),
142+ jnp.log(c), (c**(1.0 - γ) - 1.0) / (1.0 - γ))
147143
148- @jax.jit
149- def T(v, model):
144+ def state_action_value(c, y, v, model):
150145 """
151- Bellman operator returning greedy policy and updated value
146+ Right hand side of the Bellman equation.
152147 """
153148 α, β, γ, shocks = model.α, model.β, model.γ, model.shocks
154- y_grid, c_grid_frac = model.y_grid, model.c_grid_frac
149+ y_grid = model.y_grid
150+
151+ # Compute capital
152+ k = y - c
153+
154+ # Compute next period income for all shocks
155+ y_next = (k**α) * shocks
156+
157+ # Interpolate to get continuation values
158+ continuation = jnp.interp(y_next, y_grid, v).mean()
159+
160+ return u(c, γ) + β * continuation
161+
162+ def golden_max(f, a, b, args=(), tol=1e-5, max_iter=100):
163+ """
164+ Golden section search for maximum of f on [a, b].
165+ """
166+ golden_ratio = (jnp.sqrt(5.0) - 1.0) / 2.0
167+
168+ # Initialize
169+ x1 = b - golden_ratio * (b - a)
170+ x2 = a + golden_ratio * (b - a)
171+ f1 = f(x1, *args)
172+ f2 = f(x2, *args)
173+
174+ def body(state):
175+ a, b, x1, x2, f1, f2, i = state
176+
177+ # Update interval based on function values
178+ use_right = f2 > f1
179+
180+ a_new = jnp.where(use_right, x1, a)
181+ b_new = jnp.where(use_right, b, x2)
182+ x1_new = jnp.where(use_right, x2,
183+ b_new - golden_ratio * (b_new - a_new))
184+ x2_new = jnp.where(use_right,
185+ a_new + golden_ratio * (b_new - a_new), x1)
186+ f1_new = jnp.where(use_right, f2, f(x1_new, *args))
187+ f2_new = jnp.where(use_right, f(x2_new, *args), f1)
155188
156- # Interpolant for value function on the state grid
157- vf = lambda x: jnp.interp(x, y_grid, v)
189+ return a_new, b_new, x1_new, x2_new, f1_new, f2_new, i + 1
158190
159-    def solve_state(y):
160- # Candidate consumptions scaled by income
161- c = c_grid_frac * y
191+    def cond(state):
192+ a, b, x1, x2, f1, f2, i = state
193+ return (jnp.abs(b - a) > tol) & (i < max_iter)
162194
163- # Next income for each c and each shock
164-        k = jnp.maximum(y - c, 1e-12)
165- y_next = (k**α)[:, None] * shocks[None, :]
195+ a_f, b_f, x1_f, x2_f, f1_f, f2_f, _ = jax.lax.while_loop(
196+        cond, body, (a, b, x1, x2, f1, f2, 0)
197+ )
166198
167- # Expected continuation value via Monte Carlo
168- v_next = vf(y_next.reshape(-1)).reshape(
169-            c.shape[0], shocks.shape[0]).mean(axis=1)
199+ # Return the best point
200+ x_max = jnp.where(f1_f > f2_f, x1_f, x2_f)
201+    f_max = jnp.maximum(f1_f, f2_f)
170202
171- # Objective on the consumption grid
172- obj = u(c, γ) + β * v_next
203+ return x_max, f_max
173204
174- # Maximize over c-grid
175- idx = jnp.argmax(obj)
205+ @jax.jit
206+ def T(v, model):
207+ """
208+ Bellman operator returning greedy policy and updated value
209+ """
210+ y_grid = model.y_grid
176211
177- c_star = c[idx]
178- v_val = obj[idx]
179- return c_star, v_val
212+ def maximize_at_state(y):
213+ # Maximize RHS of Bellman equation at state y
214+ c_star, v_max = golden_max(state_action_value,
215+ 1e-10, y - 1e-10,
216+ args=(y, v, model))
217+ return c_star, v_max
180218
181- # Vectorize across states
182- c_star, v_new = jax.vmap(solve_state)(y_grid)
183- return c_star, v_new
219+ v_greedy, v_new = jax.vmap(maximize_at_state)(y_grid)
220+ return v_greedy, v_new
184221
185222
186223@jax.jit
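As a quick aside (not part of the diff), `golden_max` can be sanity-checked on a concave function with a known maximizer before it is relied on inside `T` — a minimal sketch assuming the definitions above:

```python
# Illustrative check: a concave quadratic whose peak is at x = 1.5.
f = lambda x: -(x - 1.5)**2
x_star, f_star = golden_max(f, 0.0, 4.0)
print(x_star)   # ≈ 1.5
print(f_star)   # ≈ 0.0
```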
@@ -210,19 +247,20 @@ Let us compute the approximate solution at the default parameters
210247og = create_optgrowth_model()
211248
212249with qe.Timer(unit="milliseconds"):
213- v_greedy = vfi(og)[0].block_until_ready()
250+ c_greedy, _ = vfi(og)
251+ c_greedy.block_until_ready()
214252```
215253
216254Here is a plot of the resulting policy, compared with the true policy:
217255
218256```{code-cell} ipython3
219257fig, ax = plt.subplots()
220258
221- ax.plot(og.y_grid, v_greedy, lw=2, alpha=0.8,
222- label='approximate policy function')
259+ ax.plot(og.y_grid, c_greedy, lw=2, alpha=0.8,
260+ label='approximate policy function')
223261
224- ax.plot(og.y_grid, (1 - og.α * og.β) * og.y_grid,
225- 'k--', lw=2, alpha=0.8, label='true policy function')
262+ ax.plot(og.y_grid, (1 - og.α * og.β) * og.y_grid,
263+ 'k--', lw=2, alpha=0.8, label='true policy function')
226264
227265ax.legend()
228266plt.show()
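For reference (an aside, not part of the diff): with log utility and Cobb–Douglas production, the exact optimal policy has the known closed form

$$
\sigma^*(y) = (1 - \alpha \beta)\, y,
$$

which is what the dashed "true policy function" line plots.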
@@ -234,26 +272,23 @@ the algorithm.
234272The maximal absolute deviation between the two policies is
235273
236274```{code-cell} ipython3
237- np.max(np.abs(np.asarray(v_greedy)
238- - np.asarray((1 - og.α * og.β) * og.y_grid)))
275+ jnp.max(jnp.abs(c_greedy - (1 - og.α * og.β) * og.y_grid))
239276```
240277
241278## Exercises
242279
243280```{exercise-start}
244281:label: ogfast_ex1
245282```
246-
247283Time how long it takes to iterate with the Bellman operator 20 times, starting from initial condition $v(y) = u(y)$.
248284
249- Use the default parameterization.
285+ Use the default parameterization and [`jax.lax.fori_loop`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.fori_loop.html#jax.lax.fori_loop) for the iteration.
250286```{exercise-end}
251287```
252288
253289```{solution-start} ogfast_ex1
254290:class: dropdown
255291```
256-
257292Let's set up the initial condition.
258293
259294```{code-cell} ipython3
@@ -264,14 +299,14 @@ Here is the timing.
264299
265300```{code-cell} ipython3
266301with qe.Timer(unit="milliseconds"):
267- for _ in range(20):
268- _, v = T(v, og)
302+ def bellman_step(_, v_curr):
303+        return T(v_curr, og)[1]  # keep only the updated value, drop the greedy policy
304+ v = jax.lax.fori_loop(0, 20, bellman_step, v)
269305 v.block_until_ready()
270306```
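A note on the design choice (an aside, not part of the diff): `jax.lax.fori_loop` stages the whole 20-step iteration out to XLA, rather than dispatching the jitted `T` from Python on every pass as the removed `for` loop did.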
271307
272308Compared with our {ref}`timing <og_ex2>` for the non-compiled version of
273309value function iteration, the JIT-compiled code is usually an order of magnitude faster.
274-
275310```{solution-end}
276311```
277312
@@ -296,11 +331,9 @@ Compare execution time as well.
296331```{exercise-end}
297332```
298333
299-
300334```{solution-start} ogfast_ex2
301335:class: dropdown
302336```
303-
304337Here is the CRRA variant using the same code path
305338
306339```{code-cell} ipython3
@@ -311,16 +344,17 @@ Let's solve and time the model
311344
312345```{code-cell} ipython3
313346with qe.Timer(unit="milliseconds"):
314- v_greedy = vfi(og_crra)[0].block_until_ready()
347+ c_greedy, _ = vfi(og_crra)
348+ c_greedy.block_until_ready()
315349```
316350
317351Here is a plot of the resulting policy
318352
319353```{code-cell} ipython3
320354fig, ax = plt.subplots()
321355
322- ax.plot(og_crra.y_grid, v_greedy, lw=2, alpha=0.6,
323- label='approximate policy function')
356+ ax.plot(og_crra.y_grid, c_greedy, lw=2, alpha=0.6,
357+ label='approximate policy function')
324358
325359ax.legend(loc='lower right')
326360plt.show()
@@ -329,15 +363,12 @@ plt.show()
329363This matches the solution obtained in the non-jitted code in {ref}`the earlier exercise <og_ex1>`.
330364
331365Execution time is an order of magnitude faster.
332-
333366```{solution-end}
334367```
335368
336-
337369```{exercise-start}
338370:label: ogfast_ex3
339371```
340-
341372In this exercise we return to the original log utility specification.
342373
343374Once an optimal consumption policy $\sigma$ is given, income follows
@@ -363,28 +394,26 @@ Other parameters match the log-linear model discussed earlier.
363394Notice that more patient agents typically have higher wealth.
364395
365396Replicate the figure modulo randomness.
366-
367397```{exercise-end}
368398```
369399
370400```{solution-start} ogfast_ex3
371401:class: dropdown
372402```
373-
374403Here is one solution.
375404
376405```{code-cell} ipython3
377- import jax.random as jr
378-
379406def simulate_og(σ_func, og_model, y0=0.1, ts_length=100, seed=0):
380- """Compute a time series given consumption policy σ."""
407+ """
408+ Compute a time series given consumption policy σ.
409+ """
381410 key = jr.PRNGKey(seed)
382411 ξ = jr.normal(key, (ts_length - 1,))
383412 y = np.empty(ts_length)
384413 y[0] = y0
385414 for t in range(ts_length - 1):
386415 y[t+1] = (y[t] - σ_func(y[t]))**og_model.α \
387- * np.exp(og_model.μ + og_model.s * ξ[t])
416+ * np.exp(og_model.μ + og_model.s * ξ[t])
388417 return y
389418```
390419
@@ -394,10 +423,9 @@ fig, ax = plt.subplots()
394423for β in (0.8, 0.9, 0.98):
395424
396425 og_temp = create_optgrowth_model(β=β, s=0.05)
397- v_greedy, v_solution = vfi(og_temp)
426+ c_greedy_temp, _ = vfi(og_temp)
398427
399- # Define an optimal policy function
400- σ_func = lambda x: np.interp(x, og_temp.y_grid, np.asarray(v_greedy))
428+    σ_func = lambda x: np.interp(x, og_temp.y_grid, np.asarray(c_greedy_temp))  # greedy policy via interpolation
401429 y = simulate_og(σ_func, og_temp)
402430 ax.plot(y, lw=2, alpha=0.6, label=rf'$\beta = {β}$')
403431