Skip to content

Commit 54e7a02

Browse files
committed
fix lecture
1 parent 2591f55 commit 54e7a02

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

lectures/mccall_model.md

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ jupytext:
44
extension: .md
55
format_name: myst
66
format_version: 0.13
7-
jupytext_version: 1.17.2
7+
jupytext_version: 1.17.1
88
kernelspec:
9-
display_name: Python 3
10-
language: python
119
name: python3
10+
display_name: Python 3 (ipykernel)
11+
language: python
1212
---
1313

1414
(mccall)=
@@ -651,7 +651,10 @@ Plot mean unemployment duration as a function of $c$ in `c_vals`.
651651
Here's a solution using Numba.
652652

653653
```{code-cell} ipython3
654-
cdf = np.cumsum(q_default)
654+
# Convert JAX arrays to NumPy arrays for use with Numba
655+
q_default_np = np.array(q_default)
656+
w_default_np = np.array(w_default)
657+
cdf = np.cumsum(q_default_np)
655658
656659
@numba.jit
657660
def compute_stopping_time(w_bar, seed=1234):
@@ -660,7 +663,7 @@ def compute_stopping_time(w_bar, seed=1234):
660663
t = 1
661664
while True:
662665
# Generate a wage draw
663-
w = w_default[qe.random.draw(cdf)]
666+
w = w_default_np[qe.random.draw(cdf)]
664667
# Stop when the draw is above the reservation wage
665668
if w >= w_bar:
666669
stopping_time = t
@@ -681,7 +684,8 @@ stop_times = np.empty_like(c_vals)
681684
for i, c in enumerate(c_vals):
682685
mcm = McCallModel(c=c)
683686
w_bar = compute_reservation_wage_two(mcm)
684-
stop_times[i] = compute_mean_stopping_time(w_bar)
687+
# Convert JAX scalar to Python float
688+
stop_times[i] = compute_mean_stopping_time(float(w_bar))
685689
686690
fig, ax = plt.subplots()
687691
@@ -690,10 +694,8 @@ ax.set(xlabel="unemployment compensation", ylabel="months")
690694
ax.legend()
691695
692696
plt.show()
693-
694697
```
695698

696-
697699
And here's a solution using JAX.
698700

699701
```{code-cell} ipython3

0 commit comments

Comments
 (0)