@@ -58,7 +58,9 @@ We'll need the following imports:
5858import matplotlib.pyplot as plt
5959import numpy as np
6060from quantecon import MarkovChain
61+ import jax
6162import jax.numpy as jnp
63+ from typing import NamedTuple
6264```
6365
6466### References
@@ -306,13 +308,12 @@ Here we build a class called `IFP` that stores the model primitives.
306308
307309``` {code-cell} python3
308310class IFP(NamedTuple):
309- R: float, # Interest rate 1 + r
310- β: float, # Discount factor
311- γ: float, # Preference parameter
312- Π: jnp.array # Markov matrix
313- z_grid: jnp.array # Markov state values for Z_t
314- asset_grid: jnp.array # Exogenous asset grid
315- ]
311+ R: float # Interest rate 1 + r
312+ β: float # Discount factor
313+ γ: float # Preference parameter
314+ Π: jnp.ndarray # Markov matrix
315+ z_grid: jnp.ndarray # Markov state values for Z_t
316+ asset_grid: jnp.ndarray # Exogenous asset grid
316317
317318def create_ifp(r=0.01,
318319 β=0.96,
@@ -323,11 +324,12 @@ def create_ifp(r=0.01,
323324 asset_grid_max=16,
324325 asset_grid_size=50):
325326
326- assert self.R * self.β < 1, "Stability condition violated."
327-
328327 asset_grid = jnp.linspace(0, asset_grid_max, asset_grid_size)
329328 Π, z_grid = jnp.array(Π), jnp.array(z_grid)
330329 R = 1 + r
330+
331+ assert R * β < 1, "Stability condition violated."
332+
331333 return IFP(R=R, β=β, γ=γ, Π=Π, z_grid=z_grid, asset_grid=asset_grid)
332334
333335# Set y(z) = exp(z)
@@ -351,29 +353,111 @@ u_prime_inv = lambda c, γ: c**(-1/γ)
351353``` {code-cell} python3
352354def K(σ: jnp.ndarray, ifp: IFP) -> jnp.ndarray:
353355 """
354- The Coleman-Reffett operator for the IFP model using EGM
355-
356+ The Coleman-Reffett operator for the IFP model using the Endogenous Grid Method.
357+
358+ This operator implements one iteration of the EGM algorithm to update the
359+ consumption policy function.
360+
361+ Parameters
362+ ----------
363+ σ : jnp.ndarray, shape (n_a, n_z)
364+ Current guess of consumption policy where σ[i, j] is consumption
365+ when assets = asset_grid[i] and income state = z_grid[j]
366+ ifp : IFP
367+ Model parameters
368+
369+ Returns
370+ -------
371+ σ_new : jnp.ndarray, shape (n_a, n_z)
372+ Updated consumption policy
373+
374+ Algorithm
375+ ---------
376+ The EGM works backwards from next period:
377+ 1. Given σ(a', z'), compute current consumption c that satisfies Euler equation
378+ 2. Compute the endogenous current asset level a that leads to (c, a')
379+ 3. Interpolate back to exogenous grid to get σ_new(a, z)
356380 """
357381 R, β, γ, Π, z_grid, asset_grid = ifp
382+ n_a = len(asset_grid)
383+ n_z = len(z_grid)
384+
385+ def compute_c_for_state(j):
386+ """
387+ Compute updated consumption policy for income state z_j.
388+
389+ The asset_grid here represents a' (next period assets), not current assets.
390+ """
391+
392+ # Step 1: Compute expected marginal utility of consumption tomorrow
393+ # ----------------------------------------------------------------
394+ # For each level of a' (next period assets), compute:
395+ # E_j[u'(c_{t+1})] = Σ_{z'} u'(σ(a', z')) * Π(z_j, z')
396+ # where the expectation is over tomorrow's income state z'
397+ # conditional on today's income state z_j
398+
399+ u_prime_vals = u_prime(σ, γ) # u'(σ(a', z')) for all (a', z')
400+ # Shape: (n_a, n_z) where n_a is # of a' values
401+
402+ expected_marginal = u_prime_vals @ Π[j, :] # Matrix multiply to get expectation
403+ # Π[j, :] are transition probs from z_j
404+ # Result shape: (n_a,) - one value per a'
358405
359- # Determine endogenous grid associated with consumption choices in σ_array
360- ae = (1/R) (σ_array + a_grid - y(z_grid))
406+ # Step 2: Use Euler equation to find today's consumption
407+ # -------------------------------------------------------
408+ # The Euler equation is: u'(c_t) = β R E_t[u'(c_{t+1})]
409+ # Inverting: c_t = (u')^{-1}(β R E_t[u'(c_{t+1})])
410+ # This gives consumption today (c_ij) for each next period asset a'_i
361411
362- # Linear interpolation of policy using endogenous grid.
363- def σ_interp(ap):
364- return [jnp.interp(ap, ae[:, j], σ[:, j]) for j in range(len(z_grid))]
412+ c_vals = u_prime_inv(β * R * expected_marginal, γ)
413+ # c_vals[i] is consumption today that's optimal when planning to
414+ # have a'_i assets tomorrow, given income state z_j today
415+ # Shape: (n_a,)
365416
366- # Define function to compute consumption at a single grid pair (a'_i, z_j)
367- def compute_c(i, j):
368- ap = ae[i]
369- rhs = jnp.sum( u_prime(σ_interp(ap)) * Π[j, :] )
370- return u_prime_inv(β * R * rhs)
417+ # Step 3: Compute endogenous grid of current assets
418+ # --------------------------------------------------
419+ # The budget constraint is: a_{t+1} + c_t = R * a_t + Y_t
420+ # Rearranging: a_t = (a_{t+1} + c_t - Y_t) / R
421+ # For each (a'_i, c_i) pair, find the current asset level a^e_i that
422+ # makes this budget constraint hold
371423
372- # Vectorize over grid using vmap
373- compute_c_vectorized = jax.vmap(compute_c)
374- next_σ_array = compute_c_vectorized(asset_grid, z_grid)
424+ a_endogenous = (1/R) * (asset_grid + c_vals - y(z_grid[j]))
425+ # asset_grid[i] is a'_i, c_vals[i] is c_i, y(z_grid[j]) is income today
426+ # a_endogenous[i] is the current asset level that leads to this (c_i, a'_i) pair
427+ # Shape: (n_a,)
375428
376- return next_σ_array
429+ # Step 4: Interpolate back to exogenous grid
430+ # -------------------------------------------
431+ # We now have consumption as a function of the *endogenous* grid a^e
432+ # But we need it on the *exogenous* grid (asset_grid)
433+ # Use linear interpolation: σ_new(a) ≈ c(a) where a ∈ asset_grid
434+
435+ σ_new = jnp.interp(asset_grid, a_endogenous, c_vals)
436+ # For each point in asset_grid, interpolate to find consumption
437+ # Shape: (n_a,)
438+
439+ # Step 5: Handle borrowing constraint
440+ # ------------------------------------
441+ # For asset levels below the minimum endogenous grid point,
442+ # the household is constrained and consumes all available resources
443+ # c = R*a + y(z) (save nothing)
444+
445+ σ_new = jnp.where(asset_grid < a_endogenous[0],
446+ R * asset_grid + y(z_grid[j]),
447+ σ_new)
448+ # When a < a_endogenous[0], set c = R*a + y (consume everything)
449+
450+ return σ_new # Shape: (n_a,)
451+
452+ # Vectorize computation over all income states using vmap
453+ # --------------------------------------------------------
454+ # Instead of a Python loop over j, use JAX's vmap for efficiency
455+ # This computes compute_c_for_state(j) for all j in parallel
456+
457+ σ_new = jax.vmap(compute_c_for_state)(jnp.arange(n_z))
458+ # Result shape: (n_z, n_a) - one row per income state
459+
460+ return σ_new.T # Transpose to get (n_a, n_z) to match input format
377461```
378462
379463
@@ -395,7 +479,7 @@ def solve_model(ifp: IFP,
395479
396480 def body(loop_state):
397481 i, σ, error = loop_state
398- σ_new = K(σ, model )
482+ σ_new = K(σ, ifp )
399483 error = jnp.max(jnp.abs(σ_new - σ))
400484 return i + 1, σ_new, error
401485
@@ -413,9 +497,9 @@ def solve_model(ifp: IFP,
413497Let's road test the EGM code.
414498
415499``` {code-cell} python3
416- ifp = IFP ()
500+ ifp = create_ifp ()
417501R, β, γ, Π, z_grid, asset_grid = ifp
418- σ_init = R * a_grid + z_grid
502+ σ_init = R * asset_grid[:, None] + y( z_grid)
419503σ_star = solve_model(ifp, σ_init)
420504```
421505
@@ -424,8 +508,8 @@ Here's a plot of the optimal policy for each $z$ state
424508
425509``` {code-cell} python3
426510fig, ax = plt.subplots()
427- ax.plot(a_grid , σ_star[:, 0], label='bad state')
428- ax.plot(a_grid , σ_star[:, 1], label='good state')
511+ ax.plot(asset_grid , σ_star[:, 0], label='bad state')
512+ ax.plot(asset_grid , σ_star[:, 1], label='good state')
429513ax.set(xlabel='assets', ylabel='consumption')
430514ax.legend()
431515plt.show()
@@ -456,14 +540,14 @@ def v_star(x, β, γ):
456540Let's see if we match up:
457541
458542``` {code-cell} python3
459- ifp_cake_eating = IFP (r=0.0, z_grid=(-jnp.inf, -jnp.inf))
543+ ifp_cake_eating = create_ifp (r=0.0, z_grid=(-jnp.inf, -jnp.inf))
460544R, β, γ, Π, z_grid, asset_grid = ifp_cake_eating
461- σ_init = R * a_grid + z_grid
462- σ_star = solve_model_time_iter (ifp_cake_eating, σ_init)
545+ σ_init = R * asset_grid[:, None] + y( z_grid)
546+ σ_star = solve_model (ifp_cake_eating, σ_init)
463547
464548fig, ax = plt.subplots()
465- ax.plot(a_grid , σ_star[:, 0], label='numerical')
466- ax.plot(a_grid , c_star(a_grid, ifp .β, ifp .γ), '--', label='analytical')
549+ ax.plot(asset_grid , σ_star[:, 0], label='numerical')
550+ ax.plot(asset_grid , c_star(asset_grid, ifp_cake_eating .β, ifp_cake_eating .γ), '--', label='analytical')
467551ax.set(xlabel='assets', ylabel='consumption')
468552ax.legend()
469553plt.show()
0 commit comments