@@ -443,6 +443,8 @@ Given guesses of prices and taxes, we can use backwards induction to solve for
443443
444444The function ` backwards_opt ` solve for optimal values by applying the discretized bellman operator backwards.
445445
446+ We use ` jax.lax.scan ` to facilitate sequential and recurrant computations efficiently.
447+
446448``` {code-cell} ipython3
447449@jax.jit
448450def backwards_opt(prices, taxes, household, Q):
@@ -457,6 +459,7 @@ def backwards_opt(prices, taxes, household, Q):
457459 num_action = a_grid.size
458460
459461 def bellman_operator_j(V_next, j):
462+ "Solve household optimization problem at age j given Vj+1"
460463
461464 Rj = populate_R(j, r, w, τ, δ, household)
462465 vals = Rj + β * Q.dot(V_next)
@@ -468,6 +471,7 @@ def backwards_opt(prices, taxes, household, Q):
468471 js = jnp.arange(J-1, -1, -1)
469472 init_V = VJ
470473
474+ # iterate from age J to 1
471475 _, outputs = jax.lax.scan(bellman_operator_j, init_V, js)
472476 V, σ = outputs
473477 V = V[::-1]
@@ -503,6 +507,7 @@ def popu_dist(σ, household, Q):
503507 num_state = hh.a_grid.size * hh.γ_grid.size
504508
505509 def update_popu_j(μ_j, j):
510+ "Update population distribution from age j to j+1"
506511
507512 Qσ = Q[jnp.arange(num_state), σ[j]]
508513 μ_next = μ_j @ Qσ
@@ -511,6 +516,7 @@ def popu_dist(σ, household, Q):
511516
512517 js = jnp.arange(J-1)
513518
519+ # iterate from age 1 to J
514520 _, μ = jax.lax.scan(update_popu_j, init_μ, js)
515521 μ = jnp.concatenate([init_μ[jnp.newaxis], μ], axis=0)
516522
0 commit comments