@@ -35,18 +35,18 @@ In this lecture we continue examining a version of the IFP from
3535
3636We will make three changes.
3737
38- 1 . We will add a transient shock component to labor income (as well as a persistent one).
39- 2 . We will change the timing to one that is more efficient for our set up.
40- 3 . To solve the model, we will use the endogenous grid method (EGM).
38+ 1 . Add a transient shock component to labor income (as well as a persistent one).
39+ 2 . Change the timing to one that is more efficient for our set up.
40+ 3 . Use the endogenous grid method (EGM) to solve the model.
4141
42- We use the EGM because we know it to be fast and accurate from {doc}` os_egm_jax ` .
42+ We use EGM because we know it to be fast and accurate from {doc}` os_egm_jax ` .
4343
4444In addition to what's in Anaconda, this lecture will need the following libraries:
4545
4646``` {code-cell} ipython3
4747:tags: [hide-output]
4848
49- !pip install quantecon
49+ !pip install quantecon jax
5050```
5151
5252We'll also need the following imports:
@@ -62,7 +62,7 @@ from typing import NamedTuple
6262```
6363
6464We will use 64-bit precision in JAX because we want to compare NumPy outputs
65- with JAX outputs --- and NumPy arrays default to 64 bits.
65+ with JAX outputs and NumPy arrays default to 64 bits.
6666
6767``` {code-cell} ipython3
6868jax.config.update("jax_enable_x64", True)
@@ -217,7 +217,7 @@ When $c_t$ hits the upper bound $a_t$, the
217217strict inequality $u' (c_t) > \beta R \, \mathbb{E}_ t u'(c_ {t+1})$
218218can occur because $c_t$ cannot increase sufficiently to attain equality.
219219
220- The lower boundary case $c_t = 0$ never arises along the optimal path because $u'(0) = \infty$.
220+ The case $c_t = 0$ never arises along the optimal path because $u'(0) = \infty$.
221221
222222
223223### Optimality Results
@@ -470,6 +470,7 @@ def K_numpy(
470470
471471 for i in range(1, n_a): # Start from 1 for positive savings levels
472472 for j in range(n_z):
473+
473474 # Compute Σ_z' ∫ u'(σ(R s_i + y(z', η'), z')) φ(η') dη' Π[z_j, z']
474475 expectation = 0.0
475476 for k in range(n_z):
@@ -488,6 +489,7 @@ def K_numpy(
488489 inner_mean_k = (inner_sum / len(η_draws))
489490 # Weight by transition probability and add to the expectation
490491 expectation += inner_mean_k * Π[j, k]
492+
491493 # Calculate updated c_{ij} values
492494 new_c_vals[i, j] = u_prime_inv(β * R * expectation)
493495
@@ -776,12 +778,15 @@ def y(z, η):
776778 return jnp.exp(a_y * η + z * b_y)
777779
778780def y_bar(k):
779- """Expected labor income conditional on current state z_grid[k]"""
780- # Compute mean of y(z', η) for each future state z'
781+ """
782+ Taking z = z_grid[k], compute an approximation to
783+
784+ E_z Y' = Σ_{z'} ∫ y(z', η') φ(η') dη' Π[z, z']
785+ """
786+ # Approximate ∫ y(z', η') φ(η') dη' at given z'
781787 def mean_y_at_z(z_prime):
782788 return jnp.mean(y(z_prime, η_draws))
783-
784- # Vectorize over all future states z'
789+ # Evaluate this integral across all z'
785790 y_means = jax.vmap(mean_y_at_z)(z_grid)
786791 # Weight by transition probabilities and sum
787792 return jnp.sum(y_means * Π[k, :])
0 commit comments