Skip to content

Commit 1d4aa57

Browse files
jstacclaude
andcommitted
Improve ifp_egm lecture: add JAX dependency, fix grammar, enhance documentation
- Add jax to pip install requirements - Improve y_bar function docstring with clearer mathematical notation - Fix grammar and consistency in introduction - Add spacing in K_numpy function for readability 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent e5e369e commit 1d4aa57

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

lectures/ifp_egm.md

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,18 @@ In this lecture we continue examining a version of the IFP from
3535

3636
We will make three changes.
3737

38-
1. We will add a transient shock component to labor income (as well as a persistent one).
39-
2. We will change the timing to one that is more efficient for our set up.
40-
3. To solve the model, we will use the endogenous grid method (EGM).
38+
1. Add a transient shock component to labor income (as well as a persistent one).
39+
2. Change the timing to one that is more efficient for our set up.
40+
3. Use the endogenous grid method (EGM) to solve the model.
4141

42-
We use the EGM because we know it to be fast and accurate from {doc}`os_egm_jax`.
42+
We use EGM because we know it to be fast and accurate from {doc}`os_egm_jax`.
4343

4444
In addition to what's in Anaconda, this lecture will need the following libraries:
4545

4646
```{code-cell} ipython3
4747
:tags: [hide-output]
4848
49-
!pip install quantecon
49+
!pip install quantecon jax
5050
```
5151

5252
We'll also need the following imports:
@@ -62,7 +62,7 @@ from typing import NamedTuple
6262
```
6363

6464
We will use 64-bit precision in JAX because we want to compare NumPy outputs
65-
with JAX outputs --- and NumPy arrays default to 64 bits.
65+
with JAX outputs and NumPy arrays default to 64 bits.
6666

6767
```{code-cell} ipython3
6868
jax.config.update("jax_enable_x64", True)
@@ -217,7 +217,7 @@ When $c_t$ hits the upper bound $a_t$, the
217217
strict inequality $u' (c_t) > \beta R \, \mathbb{E}_t u'(c_{t+1})$
218218
can occur because $c_t$ cannot increase sufficiently to attain equality.
219219

220-
The lower boundary case $c_t = 0$ never arises along the optimal path because $u'(0) = \infty$.
220+
The case $c_t = 0$ never arises along the optimal path because $u'(0) = \infty$.
221221

222222

223223
### Optimality Results
@@ -470,6 +470,7 @@ def K_numpy(
470470
471471
for i in range(1, n_a): # Start from 1 for positive savings levels
472472
for j in range(n_z):
473+
473474
# Compute Σ_z' ∫ u'(σ(R s_i + y(z', η'), z')) φ(η') dη' Π[z_j, z']
474475
expectation = 0.0
475476
for k in range(n_z):
@@ -488,6 +489,7 @@ def K_numpy(
488489
inner_mean_k = (inner_sum / len(η_draws))
489490
# Weight by transition probability and add to the expectation
490491
expectation += inner_mean_k * Π[j, k]
492+
491493
# Calculate updated c_{ij} values
492494
new_c_vals[i, j] = u_prime_inv(β * R * expectation)
493495
@@ -776,12 +778,15 @@ def y(z, η):
776778
return jnp.exp(a_y * η + z * b_y)
777779
778780
def y_bar(k):
779-
"""Expected labor income conditional on current state z_grid[k]"""
780-
# Compute mean of y(z', η) for each future state z'
781+
"""
782+
Taking z = z_grid[k], compute an approximation to
783+
784+
E_z Y' = Σ_{z'} ∫ y(z', η') φ(η') dη' Π[z, z']
785+
"""
786+
# Approximate ∫ y(z', η') φ(η') dη' at given z'
781787
def mean_y_at_z(z_prime):
782788
return jnp.mean(y(z_prime, η_draws))
783-
784-
# Vectorize over all future states z'
789+
# Evaluate this integral across all z'
785790
y_means = jax.vmap(mean_y_at_z)(z_grid)
786791
# Weight by transition probabilities and sum
787792
return jnp.sum(y_means * Π[k, :])

0 commit comments

Comments
 (0)