@@ -225,7 +225,7 @@ def plot_value_function_seq(mcm, ax, num_plots=8):
225225
226226```{code-cell} ipython3
227227mcm = create_mccall_model()
228- valfunc_VFI , converged = vfi(mcm)
228+ valfunc_vfi , converged = vfi(mcm)
229229
230230fig, ax = plt.subplots(figsize=(10,6))
231231ax.set_xlabel('wage')
@@ -241,7 +241,7 @@ This is the approximation to the McCall worker's value function that is produce
241241We'll use this value function as a benchmark later after we have done some Q-learning.
242242
243243```{code-cell} ipython3
244- print(valfunc_VFI )
244+ print(valfunc_vfi )
245245```
246246
247247## Implied quality function $Q$
@@ -616,8 +616,8 @@ def valfunc_from_qtable(qtable):
616616 return jnp.max(qtable, axis=1)
617617
618618
619- def compute_error(valfunc, valfunc_VFI ):
620- return jnp.mean(jnp.abs(valfunc - valfunc_VFI ))
619+ def compute_error(valfunc, valfunc_vfi ):
620+ return jnp.mean(jnp.abs(valfunc - valfunc_vfi ))
621621```
622622
623623```{code-cell} ipython3
@@ -644,7 +644,7 @@ print(valfunc_qlr)
644644```{code-cell} ipython3
645645# plot
646646fig, ax = plt.subplots(figsize=(10,6))
647- ax.plot(w_default, valfunc_VFI , '-o', label='VFI')
647+ ax.plot(w_default, valfunc_vfi , '-o', label='VFI')
648648ax.plot(w_default, valfunc_qlr, '-o', label='QL')
649649ax.set_xlabel('wages')
650650ax.set_ylabel('optimal value')
@@ -675,8 +675,8 @@ plt.show()
675675```{code-cell} ipython3
676676# vfi
677677mcm = create_mccall_model(w=w_new, q=q_new)
678- valfunc_VFI , converged = vfi(mcm)
679- valfunc_VFI
678+ valfunc_vfi , converged = vfi(mcm)
679+ valfunc_vfi
680680```
681681
682682```{code-cell} ipython3
@@ -688,7 +688,7 @@ def plot_epochs(epochs_to_plot, quit_allowed=1, key=key):
688688 epochs_to_plot = jnp.asarray(epochs_to_plot)
689689 # plot
690690 fig, ax = plt.subplots(figsize=(10,6))
691- ax.plot(w_new, valfunc_VFI , '-o', label='VFI')
691+ ax.plot(w_new, valfunc_vfi , '-o', label='VFI')
692692
693693 max_epochs = int(jnp.max(epochs_to_plot)) # Convert to Python int
694694 # iterate on epoch numbers
@@ -697,7 +697,7 @@ def plot_epochs(epochs_to_plot, quit_allowed=1, key=key):
697697 print(f"Progress: EPOCHs = {n}")
698698 if n in epochs_to_plot:
699699 valfunc_qlr = valfunc_from_qtable(qtable)
700- error = compute_error(valfunc_qlr, valfunc_VFI )
700+ error = compute_error(valfunc_qlr, valfunc_vfi )
701701
702702 ax.plot(w_new, valfunc_qlr, '-o', label=f'QL:epochs={n}, mean error={error}')
703703
0 commit comments