@@ -245,8 +245,8 @@ class Household(NamedTuple):
245245def create_household(β=0.96, # Discount factor
246246 Π=[[0.9, 0.1], [0.1, 0.9]], # Markov chain
247247 z_grid=[0.1, 1.0], # Exogenous states
248- a_min=1e-10, a_max=20, # Asset grid
249- a_size=200 ):
248+ a_min=1e-10, a_max=12.5, # Asset grid
249+ a_size=100 ):
250250 """
251251 Create a Household namedtuple with custom grids.
252252 """
278278for all $(a, z, a')$.
279279
280280``` {code-cell} ipython3
281- @jax.jit
282281def B(v, household, prices):
283282 # Unpack
284283 β, a_grid, z_grid, Π = household
@@ -303,125 +302,54 @@ def B(v, household, prices):
303302The next function computes greedy policies
304303
305304``` {code-cell} ipython3
306- @jax.jit
307305def get_greedy(v, household, prices):
308306 """
309- Computes a v-greedy policy σ, returned as a set of indices. If
307+ Computes a v-greedy policy σ, returned as a set of indices. If
310308 σ[i, j] equals ip, then a_grid[ip] is the maximizer at i, j.
311309 """
312310 # argmax over ap
313311 return jnp.argmax(B(v, household, prices), axis=-1)
314312```
315313
316- The following function computes the array $r _ {\sigma}$ which gives current rewards given policy $\sigma$
314+ We define the Bellman operator $T$, which takes a value function $v$ and returns $Tv$ as given in the Bellman equation
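
Concretely, using the budget constraint $c = (1+r)a + wz - a'$, this operator acts as

$$
(Tv)(a, z)
= \max_{a'} \left\{ u((1+r)a + wz - a') + \beta \sum_{z'} v(a', z') \, \Pi(z, z') \right\}
$$

where the maximum is over asset choices $a'$ on the grid that keep consumption nonnegative.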
317315
318316``` {code-cell} ipython3
319- @jax.jit
320- def compute_r_σ(σ, household, prices):
317+ def T(v, household, prices):
321318 """
322- Compute current rewards at each i, j under policy σ. In particular,
323-
324- r_σ[i, j] = u((1 + r)a[i] + wz[j] - a'[ip])
325-
326- when ip = σ[i, j].
319+ The Bellman operator. Takes a value function v and returns Tv.
327320 """
328- # Unpack
329- β, a_grid, z_grid, Π = household
330- a_size, z_size = len(a_grid), len(z_grid)
331- r, w = prices
332-
333- # Compute r_σ[i, j]
334- a = jnp.reshape(a_grid, (a_size, 1))
335- z = jnp.reshape(z_grid, (1, z_size))
336- ap = a_grid[σ]
337- c = (1 + r) * a + w * z - ap
338- r_σ = u(c)
339-
340- return r_σ
321+ return jnp.max(B(v, household, prices), axis=-1)
341322```
342323
343- The value $v_ {\sigma}$ of a policy $\sigma$ is defined as
344-
345- $$
346- v_{\sigma} = (I - \beta P_{\sigma})^{-1} r_{\sigma}
347- $$
348-
349- (See Ch 5 of [Dynamic Programming](https://dp.quantecon.org/) for notation and background on Howard policy iteration.)
350-
351- To compute this vector, we set up the linear map $v \rightarrow R_ {\sigma} v$, where $R_ {\sigma} := I - \beta P_ {\sigma}$.
352-
353- This map can be expressed as
354-
355- $$
356- (R_{\sigma} v)(a, z) = v(a, z) - \beta \sum_{z'} v(\sigma(a, z), z') \Pi(z, z')
357- $$
358-
359- (Notice that $R_ \sigma$ is expressed as a linear operator rather than a matrix—this is much easier and cleaner to code, and also exploits sparsity.)
324+ Here's value function iteration, which repeatedly applies the Bellman operator until convergence
360325
361326``` {code-cell} ipython3
362327@jax.jit
363- def R_σ(v, σ, household):
364- # Unpack
328+ def value_function_iteration(household, prices, tol=1e-4, max_iter=10_000):
329+ """
330+ Implements value function iteration using a compiled JAX loop.
331+ """
365332 β, a_grid, z_grid, Π = household
366333 a_size, z_size = len(a_grid), len(z_grid)
367334
368- # Set up the array v[σ[i, j], jp]
369- zp_idx = jnp.arange(z_size)
370- zp_idx = jnp.reshape(zp_idx, (1, 1, z_size))
371- σ = jnp.reshape(σ, (a_size, z_size, 1))
372- V = v[σ, zp_idx]
373-
374- # Expand Π[j, jp] to Π[i, j, jp]
375- Π = jnp.reshape(Π, (1, z_size, z_size))
376-
377- # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Π[j, jp]
378- return v - β * jnp.sum(V * Π, axis=-1)
379- ```
335+ def condition_function(loop_state):
336+ i, v, error = loop_state
337+ return jnp.logical_and(error > tol, i < max_iter)
380338
381- The next function computes the lifetime value of a given policy
339+ def update(loop_state):
340+ i, v, error = loop_state
341+ v_new = T(v, household, prices)
342+ error = jnp.max(jnp.abs(v_new - v))
343+ return i + 1, v_new, error
382344
383- ``` {code-cell} ipython3
384- @jax.jit
385- def get_value(σ, household, prices):
386- """
387- Get the lifetime value of policy σ by computing
345+ # Initial loop state
346+ v_init = jnp.zeros((a_size, z_size))
347+ loop_state_init = (0, v_init, tol + 1)
388348
389- v_σ = R_σ^{-1} r_σ
390- """
391- r_σ = compute_r_σ(σ, household, prices)
392-
393- # Reduce R_σ to a function in v
394- _R_σ = lambda v: R_σ(v, σ, household)
349+ # Run the fixed point iteration
350+ i, v, error = jax.lax.while_loop(condition_function, update, loop_state_init)
395351
396- # Compute v_σ = R_σ^{-1} r_σ using an iterative routine.
397- return jax.scipy.sparse.linalg.bicgstab(_R_σ, r_σ)[0]
398- ```
399-
400- Here's the Howard policy iteration
401-
402- ``` {code-cell} ipython3
403- def howard_policy_iteration(household, prices,
404- tol=1e-4, max_iter=10_000, verbose=False):
405- """
406- Howard policy iteration routine.
407- """
408- β, a_grid, z_grid, Π = household
409- a_size, z_size = len(a_grid), len(z_grid)
410- σ = jnp.zeros((a_size, z_size), dtype=int)
411-
412- v_σ = get_value(σ, household, prices)
413- i = 0
414- error = tol + 1
415- while error > tol and i < max_iter:
416- σ_new = get_greedy(v_σ, household, prices)
417- v_σ_new = get_value(σ_new, household, prices)
418- error = jnp.max(jnp.abs(v_σ_new - v_σ))
419- σ = σ_new
420- v_σ = v_σ_new
421- i = i + 1
422- if verbose:
423- print(f"iteration {i} with error {error}.")
424- return σ
352+ return get_greedy(v, household, prices)
425353```
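
Because the loop is written with `jax.lax.while_loop`, the entire fixed-point iteration is staged into a single compiled computation under `jax.jit`, rather than returning to Python between iterations.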
426354
427355As a first example of what we can do, let's compute and plot an optimal accumulation policy at fixed prices
@@ -437,8 +365,7 @@ print(f"Interest rate: {r}, Wage: {w}")
437365
438366``` {code-cell} ipython3
439367with qe.Timer():
440- σ_star = howard_policy_iteration(
441- household, prices, verbose=True).block_until_ready()
368+ σ_star = value_function_iteration(household, prices).block_until_ready()
442369```
443370
444371The next plot shows asset accumulation policies at different values of the exogenous state
@@ -560,7 +487,7 @@ def G(K, firm, household):
560487 # Generate a household object with these prices, compute
561488 # aggregate capital.
562489 prices = Prices(r=r, w=w)
563- σ_star = howard_policy_iteration (household, prices)
490+ σ_star = value_function_iteration(household, prices)
564491 return capital_supply(σ_star, household)
565492```
566493
@@ -640,8 +567,8 @@ def prices_to_capital_stock(household, r, firm):
640567 prices = Prices(r=r, w=w)
641568
642569 # Compute the optimal policy
643- σ_star = howard_policy_iteration (household, prices)
644-
570+ σ_star = value_function_iteration(household, prices)
571+
645572 # Compute capital supply
646573 return capital_supply(σ_star, household)
647574
@@ -752,3 +679,189 @@ plt.show()
752679
753680``` {solution-end}
754681```
682+
683+ ``` {exercise-start}
684+ :label: aiyagari_ex3
685+ ```
686+
687+ In this lecture, we used value function iteration to solve the household problem.
688+
689+ An alternative is Howard policy iteration (HPI), which is discussed in detail in {doc}`opt_savings_2`.
690+
691+ HPI can be faster than VFI for some problems because it uses fewer but more computationally intensive iterations.
692+
693+ Your task is to implement Howard policy iteration and compare the results with value function iteration.
694+
695+ **Key concepts you'll need:**
696+
697+ Howard policy iteration requires computing the value $v_{\sigma}$ of a policy $\sigma$, defined as:
698+
699+ $$
700+ v_{\sigma} = (I - \beta P_{\sigma})^{-1} r_{\sigma}
701+ $$
702+
703+ where $r_{\sigma}$ is the reward vector under policy $\sigma$, and $P_{\sigma}$ is the transition matrix induced by $\sigma$.
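
This expression follows from the policy-value recursion $v_{\sigma} = r_{\sigma} + \beta P_{\sigma} v_{\sigma}$, which rearranges to $(I - \beta P_{\sigma}) v_{\sigma} = r_{\sigma}$. Here the transition matrix induced by $\sigma$ is

$$
P_{\sigma}((a, z), (a', z')) = \mathbb{1}\{a' = \sigma(a, z)\} \, \Pi(z, z')
$$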
704+
705+ To solve this, you'll need to:
706+ 1. Compute current rewards $r_{\sigma}(a, z) = u((1 + r)a + wz - \sigma(a, z))$
707+ 2. Set up the linear operator $R_{\sigma}$ where $(R_{\sigma} v)(a, z) = v(a, z) - \beta \sum_{z'} v(\sigma(a, z), z') \Pi(z, z')$
708+ 3. Solve $v_{\sigma} = R_{\sigma}^{-1} r_{\sigma}$ using `jax.scipy.sparse.linalg.bicgstab`
709+
710+ You can use the `get_greedy` function that's already defined in this lecture.
711+
712+ Implement the following Howard policy iteration routine:
713+
714+ ``` python
715+ def howard_policy_iteration(household, prices,
716+                             tol=1e-4, max_iter=10_000, verbose=False):
717+ """
718+ Howard policy iteration routine.
719+ """
720+ # Your code here
721+ pass
722+ ```
723+
724+ Once implemented, compute the equilibrium capital stock using HPI and verify that it produces approximately the same result as VFI at the default parameter values.
725+
726+ ``` {exercise-end}
727+ ```
728+
729+ ``` {solution-start} aiyagari_ex3
730+ :class: dropdown
731+ ```
732+
733+ First, we need to implement the helper functions for Howard policy iteration.
734+
735+ The following function computes the array $r_{\sigma}$ which gives current rewards given policy $\sigma$:
736+
737+ ``` {code-cell} ipython3
738+ def compute_r_σ(σ, household, prices):
739+ """
740+ Compute current rewards at each i, j under policy σ. In particular,
741+
742+ r_σ[i, j] = u((1 + r)a[i] + wz[j] - a'[ip])
743+
744+ when ip = σ[i, j].
745+ """
746+ # Unpack
747+ β, a_grid, z_grid, Π = household
748+ a_size, z_size = len(a_grid), len(z_grid)
749+ r, w = prices
750+
751+ # Compute r_σ[i, j]
752+ a = jnp.reshape(a_grid, (a_size, 1))
753+ z = jnp.reshape(z_grid, (1, z_size))
754+ ap = a_grid[σ]
755+ c = (1 + r) * a + w * z - ap
756+ r_σ = u(c)
757+
758+ return r_σ
759+ ```
760+
761+ The linear operator $R_{\sigma}$ is defined as:
762+
763+ ``` {code-cell} ipython3
764+ def R_σ(v, σ, household):
765+ # Unpack
766+ β, a_grid, z_grid, Π = household
767+ a_size, z_size = len(a_grid), len(z_grid)
768+
769+ # Set up the array v[σ[i, j], jp]
770+ zp_idx = jnp.arange(z_size)
771+ zp_idx = jnp.reshape(zp_idx, (1, 1, z_size))
772+ σ = jnp.reshape(σ, (a_size, z_size, 1))
773+ V = v[σ, zp_idx]
774+
775+ # Expand Π[j, jp] to Π[i, j, jp]
776+ Π = jnp.reshape(Π, (1, z_size, z_size))
777+
778+ # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Π[j, jp]
779+ return v - β * jnp.sum(V * Π, axis=-1)
780+ ```
781+
782+ The next function computes the lifetime value of a given policy:
783+
784+ ``` {code-cell} ipython3
785+ def get_value(σ, household, prices):
786+ """
787+ Get the lifetime value of policy σ by computing
788+
789+ v_σ = R_σ^{-1} r_σ
790+ """
791+ r_σ = compute_r_σ(σ, household, prices)
792+
793+ # Reduce R_σ to a function in v
794+ _R_σ = lambda v: R_σ(v, σ, household)
795+
796+ # Compute v_σ = R_σ^{-1} r_σ using an iterative routine.
797+ return jax.scipy.sparse.linalg.bicgstab(_R_σ, r_σ)[0]
798+ ```
799+
800+ Now we can implement Howard policy iteration:
801+
802+ ``` {code-cell} ipython3
803+ @jax.jit
804+ def howard_policy_iteration(household, prices, tol=1e-4, max_iter=10_000):
805+ """
806+ Howard policy iteration routine using a compiled JAX loop.
807+ """
808+ β, a_grid, z_grid, Π = household
809+ a_size, z_size = len(a_grid), len(z_grid)
810+
811+ def condition_function(loop_state):
812+ i, σ, v_σ, error = loop_state
813+ return jnp.logical_and(error > tol, i < max_iter)
814+
815+ def update(loop_state):
816+ i, σ, v_σ, error = loop_state
817+ σ_new = get_greedy(v_σ, household, prices)
818+ v_σ_new = get_value(σ_new, household, prices)
819+ error = jnp.max(jnp.abs(v_σ_new - v_σ))
820+ return i + 1, σ_new, v_σ_new, error
821+
822+ # Initial loop state
823+ σ_init = jnp.zeros((a_size, z_size), dtype=int)
824+ v_σ_init = get_value(σ_init, household, prices)
825+ loop_state_init = (0, σ_init, v_σ_init, tol + 1)
826+
827+ # Run the fixed point iteration
828+ i, σ, v_σ, error = jax.lax.while_loop(condition_function, update, loop_state_init)
829+
830+ return σ
831+ ```
832+
833+ Now let's create a modified version of the G function that uses HPI:
834+
835+ ``` {code-cell} ipython3
836+ def G_hpi(K, firm, household):
837+ # Get prices r, w associated with K
838+ r = r_given_k(K, firm)
839+ w = r_to_w(r, firm)
840+
841+ # Generate prices and compute aggregate capital using HPI.
842+ prices = Prices(r=r, w=w)
843+ σ_star = howard_policy_iteration(household, prices)
844+ return capital_supply(σ_star, household)
845+ ```
846+
847+ And compute the equilibrium using HPI:
848+
849+ ``` {code-cell} ipython3
850+ def compute_equilibrium_bisect_hpi(firm, household, a=1.0, b=20.0):
851+ K = bisect(lambda k: k - G_hpi(k, firm, household), a, b, xtol=1e-4)
852+ return K
853+
854+ firm = Firm()
855+ household = create_household()
856+ print("\nComputing equilibrium capital stock using HPI")
857+ with qe.Timer():
858+ K_star_hpi = compute_equilibrium_bisect_hpi(firm, household)
859+ print(f"Computed equilibrium capital stock with HPI: {K_star_hpi:.5}")
860+ print(f"Previous equilibrium capital stock with VFI: {K_star:.5}")
861+ print(f"Difference: {abs(K_star_hpi - K_star):.6}")
862+ ```
863+
864+ The results show that both methods produce approximately the same equilibrium, confirming that HPI is a valid alternative to VFI.
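
As a quick consistency check, you can also compare the two methods directly at fixed prices. The sketch below assumes that the `household` and `prices` objects defined earlier in the lecture are still in scope; at the default tolerance, the two greedy policies should coincide (up to possible ties in the argmax).

```{code-cell} ipython3
# Compare the VFI and HPI policies at the fixed prices used earlier
σ_vfi = value_function_iteration(household, prices)
σ_hpi = howard_policy_iteration(household, prices)
print("Max difference in policy indices:", jnp.max(jnp.abs(σ_vfi - σ_hpi)))
```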
865+
866+ ``` {solution-end}
867+ ```