Skip to content

Commit 844165b

Browse files
committed
[ak_aiyagari] polish lecture
1 parent 370ecb1 commit 844165b

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

lectures/ak_aiyagari.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,8 @@ Given guesses of prices and taxes, we can use backwards induction to solve for
443443

444444
The function `backwards_opt` solve for optimal values by applying the discretized bellman operator backwards.
445445

446+
We use `jax.lax.scan` to facilitate sequential and recurrant computations efficiently.
447+
446448
```{code-cell} ipython3
447449
@jax.jit
448450
def backwards_opt(prices, taxes, household, Q):
@@ -457,6 +459,7 @@ def backwards_opt(prices, taxes, household, Q):
457459
num_action = a_grid.size
458460
459461
def bellman_operator_j(V_next, j):
462+
"Solve household optimization problem at age j given Vj+1"
460463
461464
Rj = populate_R(j, r, w, τ, δ, household)
462465
vals = Rj + β * Q.dot(V_next)
@@ -468,6 +471,7 @@ def backwards_opt(prices, taxes, household, Q):
468471
js = jnp.arange(J-1, -1, -1)
469472
init_V = VJ
470473
474+
# iterate from age J to 1
471475
_, outputs = jax.lax.scan(bellman_operator_j, init_V, js)
472476
V, σ = outputs
473477
V = V[::-1]
@@ -503,6 +507,7 @@ def popu_dist(σ, household, Q):
503507
num_state = hh.a_grid.size * hh.γ_grid.size
504508
505509
def update_popu_j(μ_j, j):
510+
"Update population distribution from age j to j+1"
506511
507512
Qσ = Q[jnp.arange(num_state), σ[j]]
508513
μ_next = μ_j @ Qσ
@@ -511,6 +516,7 @@ def popu_dist(σ, household, Q):
511516
512517
js = jnp.arange(J-1)
513518
519+
# iterate from age 1 to J
514520
_, μ = jax.lax.scan(update_popu_j, init_μ, js)
515521
μ = jnp.concatenate([init_μ[jnp.newaxis], μ], axis=0)
516522

0 commit comments

Comments
 (0)