Skip to content

Commit 2591f55

Browse files
committed
update benchmark code
1 parent d0ccfb1 commit 2591f55

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

lectures/benchmark_mccall.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def cond(state):
8484
t_final, _, _ = jax.lax.while_loop(cond, update, initial_state)
8585
return t_final
8686

87+
from functools import partial
88+
@partial(jax.jit, static_argnames=['num_reps'])
8789
def compute_mean_stopping_time_jax(w_bar, num_reps=100000, seed=1234):
8890
key = jax.random.PRNGKey(seed)
8991
keys = jax.random.split(key, num_reps)
@@ -99,7 +101,7 @@ def benchmark_numba():
99101
# Warmup
100102
mcm = McCallModel(c=25.0)
101103
w_bar = compute_reservation_wage_two(mcm)
102-
_ = compute_mean_stopping_time_numba(float(w_bar), num_reps=1000)
104+
_ = compute_mean_stopping_time_numba(float(w_bar), num_reps=10000)
103105

104106
# Actual benchmark
105107
start = time.time()
@@ -113,19 +115,22 @@ def benchmark_numba():
113115

114116
def benchmark_jax():
115117
c_vals = jnp.linspace(10, 40, 25)
116-
stop_times = np.empty_like(c_vals)
118+
stop_times = jnp.zeros_like(c_vals)
117119

118120
# Warmup - compile the functions
119121
model = McCallModel(c=25.0)
120122
w_bar = compute_reservation_wage_two(model)
121-
_ = compute_mean_stopping_time_jax(w_bar, num_reps=1000).block_until_ready()
123+
_ = compute_mean_stopping_time_jax(
124+
w_bar, num_reps=10000).block_until_ready()
122125

123126
# Actual benchmark
124127
start = time.time()
125128
for i, c in enumerate(c_vals):
126129
model = McCallModel(c=c)
127130
w_bar = compute_reservation_wage_two(model)
128-
stop_times[i] = compute_mean_stopping_time_jax(w_bar).block_until_ready()
131+
stop_times = stop_times.at[i].set(compute_mean_stopping_time_jax(
132+
w_bar, num_reps=10000).block_until_ready())
133+
129134
end = time.time()
130135

131136
return end - start, stop_times

0 commit comments

Comments
 (0)