@@ -84,6 +84,8 @@ def cond(state):
8484 t_final , _ , _ = jax .lax .while_loop (cond , update , initial_state )
8585 return t_final
8686
87+ from functools import partial
88+ @partial (jax .jit , static_argnames = ['num_reps' ])
8789def compute_mean_stopping_time_jax (w_bar , num_reps = 100000 , seed = 1234 ):
8890 key = jax .random .PRNGKey (seed )
8991 keys = jax .random .split (key , num_reps )
@@ -99,7 +101,7 @@ def benchmark_numba():
99101 # Warmup
100102 mcm = McCallModel (c = 25.0 )
101103 w_bar = compute_reservation_wage_two (mcm )
102- _ = compute_mean_stopping_time_numba (float (w_bar ), num_reps = 1000 )
104+ _ = compute_mean_stopping_time_numba (float (w_bar ), num_reps = 10000 )
103105
104106 # Actual benchmark
105107 start = time .time ()
@@ -113,19 +115,22 @@ def benchmark_numba():
113115
114116def benchmark_jax ():
115117 c_vals = jnp .linspace (10 , 40 , 25 )
116- stop_times = np . empty_like (c_vals )
118+ stop_times = jnp . zeros_like (c_vals )
117119
118120 # Warmup - compile the functions
119121 model = McCallModel (c = 25.0 )
120122 w_bar = compute_reservation_wage_two (model )
121- _ = compute_mean_stopping_time_jax (w_bar , num_reps = 1000 ).block_until_ready ()
123+ _ = compute_mean_stopping_time_jax (
124+ w_bar , num_reps = 10000 ).block_until_ready ()
122125
123126 # Actual benchmark
124127 start = time .time ()
125128 for i , c in enumerate (c_vals ):
126129 model = McCallModel (c = c )
127130 w_bar = compute_reservation_wage_two (model )
128- stop_times [i ] = compute_mean_stopping_time_jax (w_bar ).block_until_ready ()
131+ stop_times = stop_times .at [i ].set (compute_mean_stopping_time_jax (
132+ w_bar , num_reps = 10000 ).block_until_ready ())
133+
129134 end = time .time ()
130135
131136 return end - start , stop_times
0 commit comments