Skip to content

Commit d5d2565

Browse files
committed
add minor comments into the code
1 parent 06e07c0 commit d5d2565

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

lectures/mccall_q.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,12 @@ def update(model, v):
174174
n = model.w.shape[0]
175175
176176
def v_at_state(i):
177+
# compute state-action values
177178
sa = state_action_values(model, i, v)
179+
# apply max operator
178180
return jnp.max(sa)
179181
182+
# vectorize over all states
180183
indices = jnp.arange(n)
181184
v_new = jax.vmap(v_at_state)(indices)
182185
return v_new
@@ -196,6 +199,7 @@ def vfi(model, tol=1e-5, max_iter=500):
196199
_, i, err = state
197200
return (err > tol) & (i < max_iter)
198201
202+
# iterate until convergence
199203
init_state = (v0, 0, tol + 1.0)
200204
v_final, iters, err = jax.lax.while_loop(cond_fun, body_fun, init_state)
201205
converged = jnp.where(err <= tol, 1, 0)

0 commit comments

Comments
 (0)