|
| 1 | +""" |
| 2 | +Benchmark comparing parallel Numba vs optimized JAX for ex_mm1 |
| 3 | +""" |
| 4 | + |
| 5 | +import time |
| 6 | +import numpy as np |
| 7 | +import numba |
| 8 | +import jax |
| 9 | +import jax.numpy as jnp |
| 10 | +from functools import partial |
| 11 | +import quantecon as qe |
| 12 | +from typing import NamedTuple |
| 13 | + |
| 14 | +# Try CPU JAX backend |
| 15 | +jax.config.update("jax_platform_name", "cpu") |
| 16 | +jax.config.update("jax_enable_x64", True) |
| 17 | + |
| 18 | + |
| 19 | +# Setup model parameters |
| 20 | +class McCallModel(NamedTuple): |
| 21 | + c: float = 25.0 # unemployment compensation |
| 22 | + β: float = 0.99 # discount factor |
| 23 | + w: jnp.ndarray = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0, 60.0], dtype=jnp.float64) |
| 24 | + q: jnp.ndarray = jnp.array([0.1, 0.15, 0.2, 0.25, 0.2, 0.1], dtype=jnp.float64) |
| 25 | + |
| 26 | +# Default values |
| 27 | +q_default = jnp.array([0.1, 0.15, 0.2, 0.25, 0.2, 0.1], dtype=jnp.float64) |
| 28 | +w_default = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0, 60.0], dtype=jnp.float64) |
| 29 | + |
| 30 | +# ============================================================================ |
| 31 | +# PARALLEL NUMBA VERSION |
| 32 | +# ============================================================================ |
| 33 | + |
| 34 | +q_default_np = np.array(q_default) |
| 35 | +w_default_np = np.array(w_default) |
| 36 | +cdf_np = np.cumsum(q_default_np) |
| 37 | + |
| 38 | +@numba.jit |
| 39 | +def compute_stopping_time_numba(w_bar, seed=1234): |
| 40 | + np.random.seed(seed) |
| 41 | + t = 1 |
| 42 | + while True: |
| 43 | + w = w_default_np[qe.random.draw(cdf_np)] |
| 44 | + if w >= w_bar: |
| 45 | + stopping_time = t |
| 46 | + break |
| 47 | + else: |
| 48 | + t += 1 |
| 49 | + return stopping_time |
| 50 | + |
| 51 | +@numba.jit(parallel=True) |
| 52 | +def compute_mean_stopping_time_numba(w_bar, num_reps=100000): |
| 53 | + obs = np.empty(num_reps) |
| 54 | + for i in numba.prange(num_reps): |
| 55 | + obs[i] = compute_stopping_time_numba(w_bar, seed=i) |
| 56 | + return obs.mean() |
| 57 | + |
| 58 | +# ============================================================================ |
| 59 | +# OPTIMIZED JAX VERSION |
| 60 | +# ============================================================================ |
| 61 | + |
| 62 | +@jax.jit |
| 63 | +def _acceptance_probability(w_bar): |
| 64 | + """ |
| 65 | + Compute probability that an offer exceeds the reservation wage. |
| 66 | + """ |
| 67 | + accept_mass = jnp.where(w_default >= w_bar, q_default, 0.0) |
| 68 | + return jnp.sum(accept_mass) |
| 69 | + |
| 70 | +@jax.jit |
| 71 | +def compute_stopping_time_jax(w_bar, key): |
| 72 | + """ |
| 73 | + Draw a stopping time by sampling directly from the geometric |
| 74 | + distribution implied by the acceptance probability. |
| 75 | + """ |
| 76 | + prob = _acceptance_probability(w_bar) |
| 77 | + def _sample(k): |
| 78 | + draw = jax.random.geometric(k, prob, dtype=jnp.int64) |
| 79 | + return jnp.asarray(draw, dtype=jnp.float64) |
| 80 | + return jax.lax.cond( |
| 81 | + prob <= 0.0, |
| 82 | + lambda _: jnp.array(jnp.inf, dtype=jnp.float64), |
| 83 | + _sample, |
| 84 | + operand=key |
| 85 | + ) |
| 86 | + |
| 87 | +@partial(jax.jit, static_argnames=('num_reps',)) |
| 88 | +def compute_mean_stopping_time_jax(w_bar, num_reps=100000, seed=1234): |
| 89 | + """ |
| 90 | + Generate a mean stopping time over `num_reps` repetitions by repeatedly |
| 91 | + drawing from `compute_stopping_time`. |
| 92 | + """ |
| 93 | + key = jax.random.PRNGKey(seed) |
| 94 | + keys = jax.random.split(key, num_reps) |
| 95 | + # Vectorize compute_stopping_time and evaluate across keys |
| 96 | + compute_fn = jax.vmap(compute_stopping_time_jax, in_axes=(None, 0)) |
| 97 | + obs = compute_fn(w_bar, keys) |
| 98 | + return jnp.mean(obs, dtype=jnp.float64) |
| 99 | + |
| 100 | +# ============================================================================ |
| 101 | +# BENCHMARK |
| 102 | +# ============================================================================ |
| 103 | + |
| 104 | +def benchmark(num_trials=5, num_reps=100000): |
| 105 | + """ |
| 106 | + Benchmark parallel Numba vs optimized JAX. |
| 107 | + """ |
| 108 | + w_bar = 35.0 |
| 109 | + |
| 110 | + print("="*70) |
| 111 | + print("Benchmark: Parallel Numba vs Optimized JAX (ex_mm1)") |
| 112 | + print("="*70) |
| 113 | + print(f"Number of MC replications: {num_reps:,}") |
| 114 | + print(f"Number of benchmark trials: {num_trials}") |
| 115 | + print(f"Reservation wage: {w_bar}") |
| 116 | + print(f"Number of CPU threads: {numba.config.NUMBA_NUM_THREADS}") |
| 117 | + print() |
| 118 | + |
| 119 | + # Warm-up runs |
| 120 | + print("Warming up...") |
| 121 | + _ = compute_mean_stopping_time_numba(w_bar, num_reps=num_reps) |
| 122 | + _ = compute_mean_stopping_time_jax(w_bar, num_reps=num_reps).block_until_ready() |
| 123 | + print("Warm-up complete.\n") |
| 124 | + |
| 125 | + results = {} |
| 126 | + |
| 127 | + # Benchmark Numba (Parallel) |
| 128 | + print("Benchmarking Numba (Parallel)...") |
| 129 | + numba_times = [] |
| 130 | + for i in range(num_trials): |
| 131 | + start = time.perf_counter() |
| 132 | + result = compute_mean_stopping_time_numba(w_bar, num_reps=num_reps) |
| 133 | + elapsed = time.perf_counter() - start |
| 134 | + numba_times.append(elapsed) |
| 135 | + print(f" Trial {i+1}: {elapsed:.4f} seconds") |
| 136 | + |
| 137 | + numba_mean = np.mean(numba_times) |
| 138 | + numba_std = np.std(numba_times) |
| 139 | + results['Numba (Parallel)'] = (numba_mean, numba_std, result) |
| 140 | + print(f" Mean: {numba_mean:.4f} ± {numba_std:.4f} seconds") |
| 141 | + print(f" Result: {result:.4f}\n") |
| 142 | + |
| 143 | + # Benchmark JAX (Optimized) |
| 144 | + print("Benchmarking JAX (Optimized)...") |
| 145 | + jax_times = [] |
| 146 | + for i in range(num_trials): |
| 147 | + start = time.perf_counter() |
| 148 | + result = compute_mean_stopping_time_jax(w_bar, num_reps=num_reps).block_until_ready() |
| 149 | + elapsed = time.perf_counter() - start |
| 150 | + jax_times.append(elapsed) |
| 151 | + print(f" Trial {i+1}: {elapsed:.4f} seconds") |
| 152 | + |
| 153 | + jax_mean = np.mean(jax_times) |
| 154 | + jax_std = np.std(jax_times) |
| 155 | + results['JAX (Optimized)'] = (jax_mean, jax_std, float(result)) |
| 156 | + print(f" Mean: {jax_mean:.4f} ± {jax_std:.4f} seconds") |
| 157 | + print(f" Result: {float(result):.4f}\n") |
| 158 | + |
| 159 | + # Summary |
| 160 | + print("="*70) |
| 161 | + print("SUMMARY") |
| 162 | + print("="*70) |
| 163 | + print(f"{'Implementation':<25} {'Time (s)':<20} {'Relative Performance'}") |
| 164 | + print("-"*70) |
| 165 | + |
| 166 | + for name, (mean_time, std_time, _) in results.items(): |
| 167 | + print(f"{name:<25} {mean_time:>6.4f} ± {std_time:<6.4f}") |
| 168 | + |
| 169 | + print("-"*70) |
| 170 | + |
| 171 | + # Determine winner |
| 172 | + if numba_mean < jax_mean: |
| 173 | + speedup = jax_mean / numba_mean |
| 174 | + print(f"\n🏆 WINNER: Numba (Parallel)") |
| 175 | + print(f" Numba is {speedup:.2f}x faster than JAX") |
| 176 | + else: |
| 177 | + speedup = numba_mean / jax_mean |
| 178 | + print(f"\n🏆 WINNER: JAX (Optimized)") |
| 179 | + print(f" JAX is {speedup:.2f}x faster than Numba") |
| 180 | + |
| 181 | + print("="*70) |
| 182 | + |
| 183 | +if __name__ == "__main__": |
| 184 | + benchmark() |
0 commit comments