@@ -399,8 +399,8 @@ linear interpolation of $(a^e_{ij}, c_{ij})$ over $i$ for each $j$.
399399
400400``` {code-cell} ipython3
401401def K_numpy(
402- c_vals: np.ndarray,
403- ae_vals: np.ndarray,
402+ c_vals: np.ndarray, # Initial guess of σ on grid endogenous grid
403+ ae_vals: np.ndarray, # Initial endogenous grid
404404 ifp_numpy: IFPNumPy
405405 ) -> np.ndarray:
406406 """
@@ -441,7 +441,8 @@ To solve the model we use a simple while loop.
441441``` {code-cell} ipython3
442442def solve_model_numpy(
443443 ifp_numpy: IFPNumPy,
444- c_vals: np.ndarray,
444+ ae_vals_init: np.ndarray,
445+ c_vals_init: np.ndarray,
445446 tol: float = 1e-5,
446447 max_iter: int = 1_000
447448 ) -> np.ndarray:
@@ -450,7 +451,6 @@ def solve_model_numpy(
450451
451452 """
452453 i = 0
453- ae_vals = c_vals # Initial condition
454454 error = tol + 1
455455
456456 while error > tol and i < max_iter:
@@ -467,8 +467,13 @@ Let's road test the EGM code.
467467``` {code-cell} ipython3
468468ifp_numpy = create_ifp()
469469R, β, γ, Π, z_grid, s = ifp_numpy
470- initial_c_vals = s[:, None] * np.ones(len(z_grid))
471- c_vals, ae_vals = solve_model_numpy(ifp_numpy, initial_c_vals)
470+ # Initial conditions -- agent consumes everything
471+ ae_vals_init = s[:, None] * np.ones(len(z_grid))
472+ c_vals_init = ae_vals_init
473+ # Solve from these initial conditions
474+ c_vals, ae_vals = solve_model_numpy(
475+ ifp_numpy, c_vals_init, ae_vals_init
476+ )
472477```
473478
474479Here's a plot of the optimal consumption policy for each $z$ state
@@ -601,10 +606,13 @@ Here's a jit-accelerated iterative routine to solve the model using this operato
601606
602607``` {code-cell} ipython3
603608@jax.jit
604- def solve_model(ifp: IFP,
605- c_vals: jnp.ndarray,
606- tol: float = 1e-5,
607- max_iter: int = 1000) -> jnp.ndarray:
609+ def solve_model(
610+ ifp: IFP,
611+ c_vals_init: jnp.ndarray, # Initial guess of σ on grid endogenous grid
612+ ae_vals_init: jnp.ndarray, # Initial endogenous grid
613+ tol: float = 1e-5,
614+ max_iter: int = 1000
615+ ) -> jnp.ndarray:
608616 """
609617 Solve the model using time iteration with EGM.
610618
@@ -621,8 +629,8 @@ def solve_model(ifp: IFP,
621629 i += 1
622630 return new_c_vals, new_ae_vals, i, error
623631
624- ae_vals = c_vals
625- initial_state = (c_vals, ae_vals, 0, tol + 1 )
632+ i, error = 0, tol + 1
633+ initial_state = (c_vals_init, ae_vals_init, i, error )
626634 final_loop_state = jax.lax.while_loop(condition, body, initial_state)
627635 c_vals, ae_vals, i, error = final_loop_state
628636
@@ -637,8 +645,11 @@ Let's road test the EGM code.
637645``` {code-cell} ipython3
638646ifp = create_ifp()
639647R, β, γ, Π, z_grid, s = ifp
640- c_vals_init = s[:, None] * jnp.ones(len(z_grid))
641- c_vals_jax, ae_vals_jax = solve_model(ifp, c_vals_init)
648+ # Set initial conditions where the agent consumes everything
649+ ae_vals_init = s[:, None] * jnp.ones(len(z_grid))
650+ c_vals_init = ae_vals_init
651+ # Solve starting from these initial conditions
652+ c_vals_jax, ae_vals_jax = solve_model(ifp, c_vals_init, ae_vals_init)
642653```
643654
644655To verify the correctness of our JAX implementation, let's compare it with the NumPy version we developed earlier.
@@ -735,8 +746,9 @@ Let's see if we match up:
735746``` {code-cell} ipython3
736747ifp_cake_eating = create_ifp(r=0.0, z_grid=(-jnp.inf, -jnp.inf))
737748R, β, γ, Π, z_grid, s = ifp_cake_eating
738- c_vals_init = s[:, None] * jnp.ones(len(z_grid))
739- c_vals, ae_vals = solve_model(ifp_cake_eating, c_vals_init)
749+ ae_vals_init = s[:, None] * jnp.ones(len(z_grid))
750+ c_vals_init = ae_vals_init
751+ c_vals, ae_vals = solve_model(ifp_cake_eating, c_vals_init, ae_vals_init)
740752
741753fig, ax = plt.subplots()
742754ax.plot(ae_vals[:, 0], c_vals[:, 0], label='numerical')
@@ -758,11 +770,6 @@ Our plan is to run a large number of households forward for $T$ periods and then
758770histogram the cross-sectional distribution of assets.
759771
760772Set ` num_households=50_000, T=500 ` .
761- ```
762-
763- ```{solution-start} ifp_egm_ex2
764- :class: dropdown
765- ```
766773
767774First we write a function to run a single household forward in time and record
768775the final value of assets.
@@ -773,11 +780,11 @@ as representing an optimal policy associated with a given model `ifp`
773780``` {code-cell} ipython3
774781@jax.jit
775782def simulate_household(
776- key, a_0, z_idx_0, c_vals, ae_vals, ifp, num_households, T
783+ key, a_0, z_idx_0, c_vals, ae_vals, ifp, T
777784 ):
778785 """
779- Simulates num_households households for T periods to approximate
780- the stationary distribution of assets.
786+ Simulates a single household for T periods to approximate the stationary
787+ distribution of assets.
781788
782789 - key is the state of the random number generator
783790 - ifp is an instance of IFP
@@ -793,13 +800,12 @@ def simulate_household(
793800 # Simulate forward T periods
794801 def update(state, t):
795802 a, z_idx = state
796- c = σ(a, z_idx)
797803 # Draw next shock z' from Π[z, z']
798804 current_key = jax.random.fold_in(t, key)
799805 z_next_idx = jax.random.choice(current_key, n_z, p=Π[z_idx])
800806 z_next = z_grid[z_next_idx]
801807 # Update assets: a' = R * (a - c) + Y'
802- a_next = R * (a - c ) + y(z_next)
808+ a_next = R * (a - σ(a, z_idx) ) + y(z_next)
803809 # Return updated state
804810 return a_next, z_next_idx
805811
@@ -819,12 +825,10 @@ def compute_asset_stationary(
819825 Simulates num_households households for T periods to approximate
820826 the stationary distribution of assets.
821827
822- By ergodicity, simulating many households for moderate time is equivalent to
823- simulating one household for very long time, but parallelizes better.
828+ Returns the final cross-section of asset holdings.
824829
825- ifp is an instance of IFP
826- c_vals, ae_vals are the consumption policy and endogenous grid from
827- solve_model
830+ - ifp is an instance of IFP
831+ - c_vals, ae_vals are the optimal consumption policy and endogenous grid.
828832
829833 """
830834 R, β, γ, Π, z_grid, s = ifp
@@ -856,8 +860,9 @@ Now we call the function, generate the asset distribution and histogram it:
856860``` {code-cell} ipython3
857861ifp = create_ifp()
858862R, β, γ, Π, z_grid, s = ifp
859- c_vals_init = s[:, None] * jnp.ones(len(z_grid))
860- c_vals, ae_vals = solve_model(ifp, c_vals_init)
863+ ae_vals_init = s[:, None] * jnp.ones(len(z_grid))
864+ c_vals_init = ae_vals_init
865+ c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init)
861866assets = compute_asset_stationary(ifp, c_vals, ae_vals)
862867
863868fig, ax = plt.subplots()
@@ -906,9 +911,14 @@ fig, ax = plt.subplots()
906911for r_val in r_vals:
907912 ifp = create_ifp(r=r_val)
908913 R, β, γ, Π, z_grid, s = ifp
909- c_vals_init = s[:, None] * jnp.ones(len(z_grid))
914+ ae_vals_init = s[:, None] * jnp.ones(len(z_grid))
915+ c_vals_init = ae_vals_init
910916 c_vals, ae_vals = solve_model(ifp, c_vals_init)
917+ # Plot policy
911918 ax.plot(ae_vals[:, 0], c_vals[:, 0], label=f'$r = {r_val:.3f}$')
919+ # Start next round with last solution
920+ c_vals_init = c_vals
921+ ae_vals_init = ae_vals
912922
913923ax.set(xlabel='asset level', ylabel='consumption (low income)')
914924ax.legend()
@@ -921,7 +931,7 @@ plt.show()
921931
922932
923933``` {exercise-start}
924- :label: ifp_egm_ex3
934+ :label: ifp_egm_ex2
925935```
926936
927937Following on from Exercises 1, let's look at how savings and aggregate
@@ -953,7 +963,7 @@ r_vals = np.linspace(0, 0.015, M)
953963```
954964
955965
956- ``` {solution-start} ifp_egm_ex3
966+ ``` {solution-start} ifp_egm_ex2
957967:class: dropdown
958968```
959969
@@ -967,12 +977,16 @@ for r in r_vals:
967977 print(f'Solving model at r = {r}')
968978 ifp = create_ifp(r=r)
969979 R, β, γ, Π, z_grid, s = ifp
970- c_vals_init = s[:, None] * jnp.ones(len(z_grid))
971- c_vals, ae_vals = solve_model(ifp, c_vals_init)
980+ ae_vals_init = s[:, None] * jnp.ones(len(z_grid))
981+ c_vals_init = ae_vals_init
982+ c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init)
972983 assets = compute_asset_stationary(ifp, c_vals, ae_vals, num_households=10_000, T=500)
973984 mean = np.mean(assets)
974985 asset_mean.append(mean)
975986 print(f' Mean assets: {mean:.4f}')
987+ # Start next round with last solution
988+ c_vals_init = c_vals
989+ ae_vals_init = ae_vals
976990ax.plot(r_vals, asset_mean)
977991
978992ax.set(xlabel='interest rate', ylabel='capital')
0 commit comments