@@ -4,11 +4,11 @@ jupytext:
44 extension : .md
55 format_name : myst
66 format_version : 0.13
7- jupytext_version : 1.17.2
7+ jupytext_version : 1.17.1
88kernelspec :
9- display_name : Python 3
10- language : python
119 name : python3
10+ display_name : Python 3 (ipykernel)
11+ language : python
1212---
1313
1414(mccall)=
@@ -651,7 +651,10 @@ Plot mean unemployment duration as a function of $c$ in `c_vals`.
651651Here's a solution using Numba.
652652
653653``` {code-cell} ipython3
654- cdf = np.cumsum(q_default)
654+ # Convert JAX arrays to NumPy arrays for use with Numba
655+ q_default_np = np.array(q_default)
656+ w_default_np = np.array(w_default)
657+ cdf = np.cumsum(q_default_np)
655658
656659@numba.jit
657660def compute_stopping_time(w_bar, seed=1234):
@@ -660,7 +663,7 @@ def compute_stopping_time(w_bar, seed=1234):
660663 t = 1
661664 while True:
662665 # Generate a wage draw
663- w = w_default [qe.random.draw(cdf)]
666+ w = w_default_np [qe.random.draw(cdf)]
664667 # Stop when the draw is above the reservation wage
665668 if w >= w_bar:
666669 stopping_time = t
@@ -681,7 +684,8 @@ stop_times = np.empty_like(c_vals)
681684for i, c in enumerate(c_vals):
682685 mcm = McCallModel(c=c)
683686 w_bar = compute_reservation_wage_two(mcm)
684- stop_times[i] = compute_mean_stopping_time(w_bar)
687+ # Convert JAX scalar to Python float
688+ stop_times[i] = compute_mean_stopping_time(float(w_bar))
685689
686690fig, ax = plt.subplots()
687691
@@ -690,10 +694,8 @@ ax.set(xlabel="unemployment compensation", ylabel="months")
690694ax.legend()
691695
692696plt.show()
693-
694697```
695698
696-
697699And here's a solution using JAX.
698700
699701``` {code-cell} ipython3
0 commit comments