@@ -114,7 +114,7 @@ The timing here is as follows:
114114
115115Non-capital income $Y_t$ is given by $Y_t = y(Z_t)$, where
116116
117- * $\{ Z_t\} $ is an exogeneous state process and
117+ * $\{ Z_t\} $ is an exogenous state process and
118118* $y$ is a given function taking values in $\mathbb{R}_ +$.
119119
120120As is common in the literature, we take $\{ Z_t\} $ to be a finite state
@@ -258,7 +258,7 @@ We aim to find a fixed point $\sigma$ of {eq}`eqeul1`.
258258
259259To do so we use the EGM.
260260
261- We begin with an exogeneous grid $G = \{ a'_ 0, \ldots, a'_ {m-1}\} $ with $a'_ 0 = 0$.
261+ We begin with an exogenous grid $G = \{ a'_ 0, \ldots, a'_ {m-1}\} $ with $a'_ 0 = 0$.
262262
263263Fix a current guess of the policy function $\sigma$.
264264
306306
307307Here we build a class called ` IFP ` that stores the model primitives.
308308
309- ``` {code-cell} python3
309+ ``` {code-cell} ipython
310310class IFP(NamedTuple):
311- R: float # Interest rate 1 + r
312- β: float # Discount factor
313- γ: float # Preference parameter
314- Π: jnp.ndarray # Markov matrix
315- z_grid: jnp.ndarray # Markov state values for Z_t
311+ R: float # Gross interest rate R = 1 + r
312+ β: float # Discount factor
313+ γ: float # Preference parameter
314+ Π: jnp.ndarray # Markov matrix for exogenous shock
315+ z_grid: jnp.ndarray # Markov state values for Z_t
316316 asset_grid: jnp.ndarray # Exogenous asset grid
317317
318+
318319def create_ifp(r=0.01,
319- β=0.96 ,
320+ β=0.98 ,
320321 γ=1.5,
321322 Π=((0.6, 0.4),
322- (0.05, 0.95)),
323- z_grid=(0.0, 0.1 ),
324- asset_grid_max=16 ,
323+ (0.05, 0.95)),
324+ z_grid=(0.0, 0.2 ),
325+ asset_grid_max=40 ,
325326 asset_grid_size=50):
326327
327328 asset_grid = jnp.linspace(0, asset_grid_max, asset_grid_size)
328329 Π, z_grid = jnp.array(Π), jnp.array(z_grid)
329330 R = 1 + r
330-
331331 assert R * β < 1, "Stability condition violated."
332-
333332 return IFP(R=R, β=β, γ=γ, Π=Π, z_grid=z_grid, asset_grid=asset_grid)
334333
335334# Set y(z) = exp(z)
336335y = jnp.exp
337336```
338337
339- The exogeneous state process $\{ Z_t\} $ defaults to a two-state Markov chain
338+ The exogenous state process $\{ Z_t\} $ defaults to a two-state Markov chain
340339with transition matrix $\Pi$.
341340
342341We define utility globally:
343342
344- ``` {code-cell} python3
343+ ``` {code-cell} ipython
345344# Define utility function derivatives
346345u_prime = lambda c, γ: c**(-γ)
347346u_prime_inv = lambda c, γ: c**(-1/γ)
@@ -350,7 +349,7 @@ u_prime_inv = lambda c, γ: c**(-1/γ)
350349
351350### Solver
352351
353- ``` {code-cell} python3
352+ ``` {code-cell} ipython
354353def K(σ: jnp.ndarray, ifp: IFP) -> jnp.ndarray:
355354 """
356355 The Coleman-Reffett operator for the IFP model using the
@@ -362,7 +361,7 @@ def K(σ: jnp.ndarray, ifp: IFP) -> jnp.ndarray:
362361 Parameters
363362 ----------
364363 σ : jnp.ndarray, shape (n_a, n_z)
365- Current guess of consumption policy where σ[i, j] is consumption
364+ Current guess of consumption policy, σ[i, j] is consumption
366365 when assets = asset_grid[i] and income state = z_grid[j]
367366 ifp : IFP
368367 Model parameters
@@ -389,87 +388,44 @@ def K(σ: jnp.ndarray, ifp: IFP) -> jnp.ndarray:
389388 """
390389 Compute updated consumption policy for income state z_j.
391390
392- The asset_grid here represents a' (next period assets),
393- not current assets.
394- """
391+ The asset_grid here represents a' (next period assets).
395392
396- # Step 1: Compute expected marginal utility of consumption tomorrow
397- # ----------------------------------------------------------------
398- # For each level of a' (next period assets), compute:
399- # E_j[u'(c_{t+1})] = Σ_{z'} u'(σ(a', z')) * Π(z_j, z')
400- # where the expectation is over tomorrow's income state z'
401- # conditional on today's income state z_j
393+ """
402394
403- # u'(σ(a', z')) for all (a', z')
404- # Shape: (n_a, n_z) where n_a is # of a' values
395+ # Compute u'(σ(a', z')) for all (a', z')
405396 u_prime_vals = u_prime(σ, γ)
406397
407- # Matrix multiply to get expectation
408- # Π[j, :] are transition probs from z_j
409- # Result shape: (n_a,) - one value per a'
398+ # Calculate the sum Σ_{z'} u'(σ(a', z')) * Π(z_j, z') at each a'
410399 expected_marginal = u_prime_vals @ Π[j, :]
411400
412- # Step 2: Use Euler equation to find today's consumption
413- # -------------------------------------------------------
414- # The Euler equation is: u'(c_t) = β R E_t[u'(c_{t+1})]
415- # Inverting: c_t = (u')^{-1}(β R E_t[u'(c_{t+1})])
416- # This gives consumption today (c_ij) for each next period asset a'_i
417-
401+ # Use Euler equation to find today's consumption
418402 c_vals = u_prime_inv(β * R * expected_marginal, γ)
419- # c_vals[i] is consumption today that's optimal when planning to
420- # have a'_i assets tomorrow, given income state z_j today
421- # Shape: (n_a,)
422-
423- # Step 3: Compute endogenous grid of current assets
424- # --------------------------------------------------
425- # The budget constraint is: a_{t+1} + c_t = R * a_t + Y_t
426- # Rearranging: a_t = (a_{t+1} + c_t - Y_t) / R
427- # For each (a'_i, c_i) pair, find the current asset
428- # level a^e_i that makes this budget constraint hold
429-
430- # asset_grid[i] is a'_i, c_vals[i] is c_i
431- # y(z_grid[j]) is income today
432- # a_endogenous[i] is the current asset level that
433- # leads to this (c_i, a'_i) pair. Shape: (n_a,)
434- a_endogenous = (1/R) * (asset_grid + c_vals - y(z_grid[j]))
435403
436- # Step 4: Interpolate back to exogenous grid
437- # -------------------------------------------
438- # We now have consumption as a function of the *endogenous* grid a^e
439- # But we need it on the *exogenous* grid (asset_grid)
440- # Use linear interpolation: σ_new(a) ≈ c(a) where a ∈ asset_grid
404+ # Compute endogenous grid of current assets using the
405+ a_endogenous = (1/R) * (asset_grid + c_vals - y(z_grid[j]))
441406
407+ # Interpolate back to exogenous grid
442408 σ_new = jnp.interp(asset_grid, a_endogenous, c_vals)
443- # For each point in asset_grid, interpolate to find consumption
444- # Shape: (n_a,)
445409
446- # Step 5: Handle borrowing constraint
447- # ------------------------------------
448410 # For asset levels below the minimum endogenous grid point,
449- # the household is constrained and consumes all available resources
450- # c = R*a + y(z) (save nothing)
411+ # the household is constrained and c = R*a + y(z)
451412
452413 σ_new = jnp.where(asset_grid < a_endogenous[0],
453414 R * asset_grid + y(z_grid[j]),
454415 σ_new)
455- # When a < a_endogenous[0], set c = R*a + y (consume everything)
456-
457- return σ_new # Shape: (n_a,)
458416
459- # Vectorize computation over all income states using vmap
460- # --------------------------------------------------------
461- # Instead of a Python loop over j, use JAX's vmap for efficiency
462- # This computes compute_c_for_state(j) for all j in parallel
417+ return σ_new # Consumption over the asset grid given z[j]
463418
419+ # Vectorize computation over all exogenous states using vmap
420+ # Resulting shape is (n_z, n_a), so one row per income state
464421 σ_new = jax.vmap(compute_c_for_state)(jnp.arange(n_z))
465- # Result shape: (n_z, n_a) - one row per income state
466422
467- return σ_new.T # Transpose to get (n_a, n_z) to match input format
423+ return σ_new.T # Transpose to get (n_a, n_z)
468424```
469425
470426
471427
472- ``` {code-cell} python3
428+ ``` {code-cell} ipython
473429@jax.jit
474430def solve_model(ifp: IFP,
475431 σ_init: jnp.ndarray,
@@ -503,7 +459,7 @@ def solve_model(ifp: IFP,
503459
504460Let's road test the EGM code.
505461
506- ``` {code-cell} python3
462+ ``` {code-cell} ipython
507463ifp = create_ifp()
508464R, β, γ, Π, z_grid, asset_grid = ifp
509465σ_init = R * asset_grid[:, None] + y(z_grid)
@@ -513,7 +469,7 @@ R, β, γ, Π, z_grid, asset_grid = ifp
513469Here's a plot of the optimal policy for each $z$ state
514470
515471
516- ``` {code-cell} python3
472+ ``` {code-cell} ipython
517473fig, ax = plt.subplots()
518474ax.plot(asset_grid, σ_star[:, 0], label='bad state')
519475ax.plot(asset_grid, σ_star[:, 1], label='good state')
@@ -535,7 +491,7 @@ In this case, our income fluctuation problem is just a CRRA cake eating problem.
535491
536492Then the value function and optimal consumption policy are given by
537493
538- ``` {code-cell} python3
494+ ``` {code-cell} ipython
539495def c_star(x, β, γ):
540496 return (1 - β ** (1/γ)) * x
541497
@@ -546,7 +502,7 @@ def v_star(x, β, γ):
546502
547503Let's see if we match up:
548504
549- ``` {code-cell} python3
505+ ``` {code-cell} ipython
550506ifp_cake_eating = create_ifp(r=0.0, z_grid=(-jnp.inf, -jnp.inf))
551507R, β, γ, Π, z_grid, asset_grid = ifp_cake_eating
552508σ_init = R * asset_grid[:, None] + y(z_grid)
@@ -589,8 +545,9 @@ Your figure should show that higher interest rates boost savings and suppress co
589545
590546Here's one solution:
591547
592- ``` {code-cell} python3
593- r_vals = np.linspace(0, 0.04, 4)
548+ ``` {code-cell} ipython
549+ # With β=0.98, we need R*β < 1, so r < 0.0204
550+ r_vals = np.linspace(0, 0.015, 4)
594551
595552fig, ax = plt.subplots()
596553for r_val in r_vals:
@@ -618,7 +575,7 @@ default parameters.
618575
619576The following figure is a 45 degree diagram showing the law of motion for assets when consumption is optimal
620577
621- ``` {code-cell} python3
578+ ``` {code-cell} ipython
622579ifp = create_ifp()
623580R, β, γ, Π, z_grid, asset_grid = ifp
624581σ_init = R * asset_grid[:, None] + y(z_grid)
@@ -639,7 +596,7 @@ plt.show()
639596The unbroken lines show the update function for assets at each $z$, which is
640597
641598$$
642- a \mapsto R (a - \sigma^*(a, z)) + y(z)
599+ a \mapsto R (a - \sigma^*(a, z)) + y(z)
643600$$
644601
645602The dashed line is the 45 degree line.
@@ -671,45 +628,68 @@ Your task is to generate such a histogram.
671628:class: dropdown
672629```
673630
674- First we write a function to compute a long asset series .
631+ First we write a function to simulate many households in parallel using JAX .
675632
676- ``` {code-cell} python3
677- def compute_asset_series (ifp, σ_init, T=500_000 , seed=1234):
633+ ``` {code-cell} ipython
634+ def compute_asset_stationary (ifp, σ_star, num_households=50_000, T=500 , seed=1234):
678635 """
679- Simulates a time series of length T for assets, given optimal
680- savings behavior.
636+ Simulates num_households households for T periods to approximate
637+ the stationary distribution of assets.
638+
639+ By ergodicity, simulating many households for moderate time is equivalent
640+ to simulating one household for very long time, but parallelizes better.
681641
682642 ifp is an instance of IFP
643+ σ_star is the optimal consumption policy
683644 """
684645 R, β, γ, Π, z_grid, asset_grid = ifp
646+ n_z = len(z_grid)
685647
686- # Solve for the optimal policy
687- σ_star = solve_model(ifp, σ_init)
688- σ = lambda a, z: np.interp(a, asset_grid, σ_star[:, z])
648+ # Create interpolation function for consumption policy
649+ σ_interp = lambda a, z_idx: jnp.interp(a, asset_grid, σ_star[:, z_idx])
650+
651+ # Simulate one household forward
652+ def simulate_one_household(key):
653+ # Random initial state (both z and a)
654+ key1, key2, key3 = jax.random.split(key, 3)
655+ z_idx = jax.random.choice(key1, n_z)
656+ # Start with random assets drawn uniformly from [0, asset_grid_max/2]
657+ a = jax.random.uniform(key3, minval=0.0, maxval=asset_grid[-1]/2)
689658
690- # Simulate the exogeneous state process
691- mc = MarkovChain(Π)
692- z_seq = mc.simulate(T, random_state=seed)
659+ # Simulate forward T periods
660+ def step(state, key_t):
661+ a_current, z_current = state
662+ # Draw next shock
663+ z_next = jax.random.choice(key_t, n_z, p=Π[z_current])
664+ # Update assets
665+ z_val = z_grid[z_next]
666+ c = σ_interp(a_current, z_next)
667+ a_next = R * a_current + y(z_val) - c
668+ return (a_next, z_next), None
693669
694- # Simulate the asset path
695- a = np.zeros(T+1)
696- for t in range(T):
697- z_idx = z_seq[t]
698- z_val = z_grid[z_idx]
699- a[t+1] = R * a[t] + y(z_val) - σ(a[t], z_idx)
700- return a
670+ keys = jax.random.split(key2, T)
671+ (a_final, _), _ = jax.lax.scan(step, (a, z_idx), keys)
672+ return a_final
673+
674+ # Vectorize over many households
675+ key = jax.random.PRNGKey(seed)
676+ keys = jax.random.split(key, num_households)
677+ assets = jax.vmap(simulate_one_household)(keys)
678+
679+ return np.array(assets)
701680```
702681
703- Now we call the function, generate the series and then histogram it:
682+ Now we call the function, generate the asset distribution and histogram it:
704683
705- ``` {code-cell} python3
684+ ``` {code-cell} ipython
706685ifp = create_ifp()
707686R, β, γ, Π, z_grid, asset_grid = ifp
708687σ_init = R * asset_grid[:, None] + y(z_grid)
709- a = compute_asset_series(ifp, σ_init)
688+ σ_star = solve_model(ifp, σ_init)
689+ assets = compute_asset_stationary(ifp, σ_star)
710690
711691fig, ax = plt.subplots()
712- ax.hist(a , bins=20, alpha=0.5, density=True)
692+ ax.hist(assets , bins=20, alpha=0.5, density=True)
713693ax.set(xlabel='assets')
714694plt.show()
715695```
@@ -724,6 +704,8 @@ more realistic features to the model.
724704``` {solution-end}
725705```
726706
707+
708+
727709``` {exercise-start}
728710:label: ifp_ex3
729711```
@@ -756,9 +738,10 @@ stationary distribution given the interest rate.
756738
757739Here's one solution
758740
759- ``` {code-cell} python3
741+ ``` {code-cell} ipython
760742M = 25
761- r_vals = np.linspace(0, 0.02, M)
743+ # With β=0.98, we need R*β < 1, so R < 1/0.98 ≈ 1.0204, thus r < 0.0204
744+ r_vals = np.linspace(0, 0.015, M)
762745fig, ax = plt.subplots()
763746
764747asset_mean = []
@@ -767,11 +750,14 @@ for r in r_vals:
767750 ifp = create_ifp(r=r)
768751 R, β, γ, Π, z_grid, asset_grid = ifp
769752 σ_init = R * asset_grid[:, None] + y(z_grid)
770- mean = np.mean(compute_asset_series(ifp, σ_init, T=250_000))
753+ σ_star = solve_model(ifp, σ_init)
754+ assets = compute_asset_stationary(ifp, σ_star, num_households=10_000, T=500)
755+ mean = np.mean(assets)
771756 asset_mean.append(mean)
772- ax.plot(asset_mean, r_vals)
757+ print(f' Mean assets: {mean:.4f}')
758+ ax.plot(r_vals, asset_mean)
773759
774- ax.set(xlabel='capital ', ylabel='interest rate ')
760+ ax.set(xlabel='interest rate ', ylabel='capital ')
775761
776762plt.show()
777763```
0 commit comments