Skip to content

Commit dbcfd78

Browse files
committed
replace valfunc_VFI with valfunc_vfi throughout the lecture
1 parent 659554f commit dbcfd78

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

lectures/mccall_q.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def plot_value_function_seq(mcm, ax, num_plots=8):
225225
226226
```{code-cell} ipython3
227227
mcm = create_mccall_model()
228-
valfunc_VFI, converged = vfi(mcm)
228+
valfunc_vfi, converged = vfi(mcm)
229229
230230
fig, ax = plt.subplots(figsize=(10,6))
231231
ax.set_xlabel('wage')
@@ -241,7 +241,7 @@ This is the approximation to the McCall worker's value function that is produce
241241
We'll use this value function as a benchmark later after we have done some Q-learning.
242242
243243
```{code-cell} ipython3
244-
print(valfunc_VFI)
244+
print(valfunc_vfi)
245245
```
246246
247247
## Implied quality function $Q$
@@ -616,8 +616,8 @@ def valfunc_from_qtable(qtable):
616616
return jnp.max(qtable, axis=1)
617617
618618
619-
def compute_error(valfunc, valfunc_VFI):
620-
return jnp.mean(jnp.abs(valfunc - valfunc_VFI))
619+
def compute_error(valfunc, valfunc_vfi):
620+
return jnp.mean(jnp.abs(valfunc - valfunc_vfi))
621621
```
622622
623623
```{code-cell} ipython3
@@ -644,7 +644,7 @@ print(valfunc_qlr)
644644
```{code-cell} ipython3
645645
# plot
646646
fig, ax = plt.subplots(figsize=(10,6))
647-
ax.plot(w_default, valfunc_VFI, '-o', label='VFI')
647+
ax.plot(w_default, valfunc_vfi, '-o', label='VFI')
648648
ax.plot(w_default, valfunc_qlr, '-o', label='QL')
649649
ax.set_xlabel('wages')
650650
ax.set_ylabel('optimal value')
@@ -675,8 +675,8 @@ plt.show()
675675
```{code-cell} ipython3
676676
# vfi
677677
mcm = create_mccall_model(w=w_new, q=q_new)
678-
valfunc_VFI, converged = vfi(mcm)
679-
valfunc_VFI
678+
valfunc_vfi, converged = vfi(mcm)
679+
valfunc_vfi
680680
```
681681
682682
```{code-cell} ipython3
@@ -688,7 +688,7 @@ def plot_epochs(epochs_to_plot, quit_allowed=1, key=key):
688688
epochs_to_plot = jnp.asarray(epochs_to_plot)
689689
# plot
690690
fig, ax = plt.subplots(figsize=(10,6))
691-
ax.plot(w_new, valfunc_VFI, '-o', label='VFI')
691+
ax.plot(w_new, valfunc_vfi, '-o', label='VFI')
692692
693693
max_epochs = int(jnp.max(epochs_to_plot)) # Convert to Python int
694694
# iterate on epoch numbers
@@ -697,7 +697,7 @@ def plot_epochs(epochs_to_plot, quit_allowed=1, key=key):
697697
print(f"Progress: EPOCHs = {n}")
698698
if n in epochs_to_plot:
699699
valfunc_qlr = valfunc_from_qtable(qtable)
700-
error = compute_error(valfunc_qlr, valfunc_VFI)
700+
error = compute_error(valfunc_qlr, valfunc_vfi)
701701
702702
ax.plot(w_new, valfunc_qlr, '-o', label=f'QL:epochs={n}, mean error={error}')
703703

0 commit comments

Comments
 (0)