@@ -93,7 +93,7 @@ subject to
9393``` {math}
9494:label: eqst
9595
96- a_{t+1} + c_t \leq R a_t + Y_t
96+ a_{t+1} = R ( a_t - c_t) + Y_{t+1}
9797\quad c_t \geq 0,
9898\quad a_t \geq 0
9999\quad t = 0, 1, \ldots
@@ -109,9 +109,10 @@ Here
109109
110110The timing here is as follows:
111111
112- 1 . At the start of period $t$, the household observes labor income $Y_t$ and financial assets $R a_t$ .
112+ 1 . At the start of period $t$, the household observes current asset holdings $ a_t$.
1131131 . The household chooses current consumption $c_t$.
114- 1 . Time shifts to $t+1$ and the process repeats.
114+ 1 . Savings $(a_t - c_t)$ earn interest at rate $r$.
115+ 1 . Labor income $Y_ {t+1}$ is realized and time shifts to $t+1$.
115116
116117Non-capital income $Y_t$ is given by $Y_t = y(Z_t)$, where
117118
@@ -246,7 +247,7 @@ random variables:
246247
247248 (u' \circ \sigma) (a, z)
248249 = \beta R \, \sum_{z'} (u' \circ \sigma)
249- [R a + y(z) - \sigma(a, z)), \, z'] \Pi(z, z')
250+ [R (a - \sigma(a, z)) + y(z' ), \, z'] \Pi(z, z')
250251```
251252
252253Here
@@ -259,35 +260,44 @@ We aim to find a fixed point $\sigma$ of {eq}`eqeul1`.
259260
260261To do so we use the EGM.
261262
262- We begin with an exogenous grid $G = \{ a' _ 0 , \ldots, a' _ {m-1}\} $ with $a' _ 0 = 0$.
263+ We begin with an exogenous grid $G = \{ s_0 , \ldots, s _ {m-1}\} $ with $s_0 > 0$, where each $s_i$ represents savings .
263264
264- Fix a current guess of the policy function $\sigma$.
265+ The relationship between current assets $a$, consumption $c$, and savings $s$ is
265266
266- For each $a'_ i$ and $z_j$ we set
267+ $$
268+ a = c + s
269+ $$
270+
271+ and next period assets are given by
272+
273+ $$
274+ a' = R s + y(z').
275+ $$
276+
277+ Fix a current guess of the policy function $\sigma$.
278+
279+ For each savings level $s_i$ and current state $z_j$, we set
267280
268281$$
269282 c_{ij} = (u')^{-1}
270283 \left[
271- \beta R \, \sum_{z'}
272- u' [ \sigma(a'_i , z') ] \Pi(z_j, z')
284+ \beta R \, \sum_{z'}
285+ u' [ \sigma(R s_i + y(z') , z') ] \Pi(z_j, z')
273286 \right]
274287$$
275288
276- and then $a^e_ {ij}$ as the current asset level $a_t$ that solves the budget constraint
277- $a'_ {ij} + c_ {ij} = R a_t + y(z_j)$.
278-
279- That is,
289+ and then obtain the endogenous grid of current assets via
280290
281291$$
282- a^e_{ij} = \frac{1}{R} [a'_{ ij} + c_{ij} - y(z_j)].
292+ a^e_{ij} = c_{ ij} + s_i.
283293$$
284294
285295Our next guess policy function, which we write as $K\sigma$, is the linear interpolation of
286296$(a^e_ {ij}, c_ {ij})$ over $i$, for each $j$.
287297
288298(The number of one dimensional linear interpolations is equal to ` len(z_grid) ` .)
289299
290- For $a < a^e_ {ij }$ we use the budget constraint to set $(K \sigma)(a, z_j) = Ra + y(z_j) $.
300+ For $a < a^e_ {i0 }$ (i.e., below the minimum endogenous grid point), the household consumes everything, so we set $(K \sigma)(a, z_j) = a $.
291301
292302
293303
@@ -355,7 +365,7 @@ guess $K\sigma$.
355365We understand $\sigma$ is an array of shape $(n_a, n_z)$, where $n_a$ and $n_z$
356366are the respective grid sizes.
357367
358- The value ` σ[i,j] ` corresponds to $\sigma(a' _ i , z_j)$.
368+ The value ` σ[i,j] ` corresponds to $\sigma(a_i , z_j)$, where $a_i$ is a point on the asset grid .
359369
360370``` {code-cell} ipython3
361371def K(σ: jnp.ndarray, ifp: IFP) -> jnp.ndarray:
@@ -368,46 +378,66 @@ def K(σ: jnp.ndarray, ifp: IFP) -> jnp.ndarray:
368378
369379 Algorithm
370380 ---------
371- The EGM works backwards from next period:
372- 1. Given σ(a', z'), compute current consumption c that
381+ The EGM works with a savings grid:
382+ 1. Use exogenous savings grid s_i
383+ 2. For each (s_i, z_j), compute next period assets a' = R*s_i + y(z')
384+ 3. Given σ(a', z'), compute current consumption c that
373385 satisfies Euler equation
374- 2. Compute the endogenous current asset level a^e that leads
375- to (c, a')
376- 3. Interpolate back to exogenous grid to get σ_new(a', z')
386+ 4. Compute the endogenous current asset level a^e = c + s
387+ 5. Interpolate back to asset grid to get σ_new(a, z)
377388
378389 """
379390 R, β, γ, Π, z_grid, asset_grid = ifp
380391 n_a = len(asset_grid)
381392 n_z = len(z_grid)
382393
394+ # Create savings grid (exogenous grid for EGM)
395+ # We use the asset grid as the savings grid
396+ savings_grid = asset_grid
397+
383398 def compute_c_for_fixed_income_state(j):
384399 """
385400 Compute updated consumption policy for income state z_j.
386401
387- The asset_grid here represents a' (next period assets ).
402+ The savings_grid represents s (savings), where a' = R*s + y(z' ).
388403
389404 """
390405
391- # Compute u'(σ(a', z')) for all (a', z')
392- u_prime_vals = u_prime(σ, γ)
406+ # For each savings level s_i, compute expected marginal utility
407+ # We need to evaluate σ at next period assets a' = R*s + y(z')
408+
409+ # Compute next period assets for all (s_i, z') combinations
410+ # Shape: (n_a, n_z) where savings_grid has n_a points
411+ a_next_grid = R * savings_grid[:, None] + y(z_grid)
412+
413+ # Interpolate to get consumption at each (a', z')
414+ # For each z', interpolate over the a' values
415+ def interp_for_z(z_idx):
416+ return jnp.interp(a_next_grid[:, z_idx], asset_grid, σ[:, z_idx])
393417
394- # Calculate the sum Σ_{z'} u'(σ(a', z')) * Π(z_j, z') at each a'
395- expected_marginal = u_prime_vals @ Π[j, :]
418+ c_next_grid = jax.vmap(interp_for_z)(jnp.arange(n_z)) # Shape: (n_z, n_a)
419+ c_next_grid = c_next_grid.T # Shape: (n_a, n_z)
420+
421+ # Compute u'(c') for all points
422+ u_prime_next = u_prime(c_next_grid, γ)
423+
424+ # Take expectation over z' for each s, given current state z_j
425+ expected_marginal = u_prime_next @ Π[j, :] # Shape: (n_a,)
396426
397427 # Use Euler equation to find today's consumption
398428 c_vals = u_prime_inv(β * R * expected_marginal, γ)
399429
400- # Compute endogenous grid of current assets using the
401- a_endogenous = (1/R) * (asset_grid + c_vals - y(z_grid[j]))
430+ # Compute endogenous grid of current assets: a = c + s
431+ a_endogenous = c_vals + savings_grid
402432
403- # Interpolate back to exogenous grid
433+ # Interpolate back to exogenous asset grid
404434 σ_new = jnp.interp(asset_grid, a_endogenous, c_vals)
405435
406436 # For asset levels below the minimum endogenous grid point,
407- # the household is constrained and c = R*a + y(z)
437+ # the household is constrained and consumes everything: c = a
408438
409439 σ_new = jnp.where(asset_grid < a_endogenous[0],
410- R * asset_grid + y(z_grid[j]) ,
440+ asset_grid,
411441 σ_new)
412442
413443 return σ_new # Consumption over the asset grid given z[j]
@@ -416,7 +446,7 @@ def K(σ: jnp.ndarray, ifp: IFP) -> jnp.ndarray:
416446 c_vmap = jax.vmap(compute_c_for_fixed_income_state)
417447 σ_new = c_vmap(jnp.arange(n_z)) # Shape (n_z, n_a), one row per income state
418448
419- return σ_new.T # Transpose to get (n_a, n_z)
449+ return σ_new.T # Transpose to get (n_a, n_z)
420450```
421451
422452``` {code-cell} ipython3
@@ -454,7 +484,7 @@ Let's road test the EGM code.
454484``` {code-cell} ipython3
455485ifp = create_ifp()
456486R, β, γ, Π, z_grid, asset_grid = ifp
457- σ_init = R * asset_grid[:, None] + y( z_grid)
487+ σ_init = asset_grid[:, None] * jnp.ones(len( z_grid) )
458488σ_star = solve_model(ifp, σ_init)
459489```
460490
@@ -475,13 +505,13 @@ To begin to understand the long run asset levels held by households under the de
475505``` {code-cell} ipython3
476506ifp = create_ifp()
477507R, β, γ, Π, z_grid, asset_grid = ifp
478- σ_init = R * asset_grid[:, None] + y( z_grid)
508+ σ_init = asset_grid[:, None] * jnp.ones(len( z_grid) )
479509σ_star = solve_model(ifp, σ_init)
480510a = asset_grid
481511
482512fig, ax = plt.subplots()
483513for z, lb in zip((0, 1), ('low income', 'high income')):
484- ax.plot(a, R * (a - σ_star[:, z]) + y(z ) , label=lb)
514+ ax.plot(a, R * (a - σ_star[:, z]) + y(z_grid[z] ) , label=lb)
485515
486516ax.plot(a, a, 'k--')
487517ax.set(xlabel='current assets', ylabel='next period assets')
@@ -493,9 +523,11 @@ plt.show()
493523The unbroken lines show the update function for assets at each $z$, which is
494524
495525$$
496- a \mapsto R (a - \sigma^*(a, z)) + y(z)
526+ a \mapsto R (a - \sigma^*(a, z)) + y(z' )
497527$$
498528
529+ where we plot this for a particular realization $z' = z$.
530+
499531The dashed line is the 45 degree line.
500532
501533The figure suggests that the dynamics will be stable --- assets do not diverge
@@ -533,7 +565,7 @@ Let's see if we match up:
533565``` {code-cell} ipython3
534566ifp_cake_eating = create_ifp(r=0.0, z_grid=(-jnp.inf, -jnp.inf))
535567R, β, γ, Π, z_grid, asset_grid = ifp_cake_eating
536- σ_init = R * asset_grid[:, None] + y( z_grid)
568+ σ_init = asset_grid[:, None] * jnp.ones(len( z_grid) )
537569σ_star = solve_model(ifp_cake_eating, σ_init)
538570
539571fig, ax = plt.subplots()
@@ -581,7 +613,7 @@ fig, ax = plt.subplots()
581613for r_val in r_vals:
582614 ifp = create_ifp(r=r_val)
583615 R, β, γ, Π, z_grid, asset_grid = ifp
584- σ_init = R * asset_grid[:, None] + y( z_grid)
616+ σ_init = asset_grid[:, None] * jnp.ones(len( z_grid) )
585617 σ_star = solve_model(ifp, σ_init)
586618 ax.plot(asset_grid, σ_star[:, 0], label=f'$r = {r_val:.3f}$')
587619
@@ -643,12 +675,13 @@ def compute_asset_stationary(ifp, σ_star, num_households=50_000, T=500, seed=12
643675 # Simulate forward T periods
644676 def step(state, key_t):
645677 a_current, z_current = state
678+ # Consume based on current state
679+ c = σ_interp(a_current, z_current)
646680 # Draw next shock
647681 z_next = jax.random.choice(key_t, n_z, p=Π[z_current])
648- # Update assets
682+ # Update assets: a' = R*(a - c) + Y'
649683 z_val = z_grid[z_next]
650- c = σ_interp(a_current, z_next)
651- a_next = R * a_current + y(z_val) - c
684+ a_next = R * (a_current - c) + y(z_val)
652685 return (a_next, z_next), None
653686
654687 keys = jax.random.split(key2, T)
@@ -668,7 +701,7 @@ Now we call the function, generate the asset distribution and histogram it:
668701``` {code-cell} ipython3
669702ifp = create_ifp()
670703R, β, γ, Π, z_grid, asset_grid = ifp
671- σ_init = R * asset_grid[:, None] + y( z_grid)
704+ σ_init = asset_grid[:, None] * jnp.ones(len( z_grid) )
672705σ_star = solve_model(ifp, σ_init)
673706assets = compute_asset_stationary(ifp, σ_star)
674707
@@ -733,7 +766,7 @@ for r in r_vals:
733766 print(f'Solving model at r = {r}')
734767 ifp = create_ifp(r=r)
735768 R, β, γ, Π, z_grid, asset_grid = ifp
736- σ_init = R * asset_grid[:, None] + y( z_grid)
769+ σ_init = asset_grid[:, None] * jnp.ones(len( z_grid) )
737770 σ_star = solve_model(ifp, σ_init)
738771 assets = compute_asset_stationary(ifp, σ_star, num_households=10_000, T=500)
739772 mean = np.mean(assets)
0 commit comments