Skip to content

Commit d5aac7a

Browse files
committed
add geometric
1 parent 327c5d7 commit d5aac7a

File tree

1 file changed

+184
-0
lines changed

1 file changed

+184
-0
lines changed
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
"""
2+
Benchmark comparing parallel Numba vs optimized JAX for ex_mm1
3+
"""
4+
5+
import time
6+
import numpy as np
7+
import numba
8+
import jax
9+
import jax.numpy as jnp
10+
from functools import partial
11+
import quantecon as qe
12+
from typing import NamedTuple
13+
14+
# Try CPU JAX backend
15+
jax.config.update("jax_platform_name", "cpu")
16+
jax.config.update("jax_enable_x64", True)
17+
18+
19+
# Setup model parameters
20+
class McCallModel(NamedTuple):
21+
c: float = 25.0 # unemployment compensation
22+
β: float = 0.99 # discount factor
23+
w: jnp.ndarray = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0, 60.0], dtype=jnp.float64)
24+
q: jnp.ndarray = jnp.array([0.1, 0.15, 0.2, 0.25, 0.2, 0.1], dtype=jnp.float64)
25+
26+
# Default values
27+
q_default = jnp.array([0.1, 0.15, 0.2, 0.25, 0.2, 0.1], dtype=jnp.float64)
28+
w_default = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0, 60.0], dtype=jnp.float64)
29+
30+
# ============================================================================
31+
# PARALLEL NUMBA VERSION
32+
# ============================================================================
33+
34+
q_default_np = np.array(q_default)
35+
w_default_np = np.array(w_default)
36+
cdf_np = np.cumsum(q_default_np)
37+
38+
@numba.jit
39+
def compute_stopping_time_numba(w_bar, seed=1234):
40+
np.random.seed(seed)
41+
t = 1
42+
while True:
43+
w = w_default_np[qe.random.draw(cdf_np)]
44+
if w >= w_bar:
45+
stopping_time = t
46+
break
47+
else:
48+
t += 1
49+
return stopping_time
50+
51+
@numba.jit(parallel=True)
52+
def compute_mean_stopping_time_numba(w_bar, num_reps=100000):
53+
obs = np.empty(num_reps)
54+
for i in numba.prange(num_reps):
55+
obs[i] = compute_stopping_time_numba(w_bar, seed=i)
56+
return obs.mean()
57+
58+
# ============================================================================
59+
# OPTIMIZED JAX VERSION
60+
# ============================================================================
61+
62+
@jax.jit
63+
def _acceptance_probability(w_bar):
64+
"""
65+
Compute probability that an offer exceeds the reservation wage.
66+
"""
67+
accept_mass = jnp.where(w_default >= w_bar, q_default, 0.0)
68+
return jnp.sum(accept_mass)
69+
70+
@jax.jit
71+
def compute_stopping_time_jax(w_bar, key):
72+
"""
73+
Draw a stopping time by sampling directly from the geometric
74+
distribution implied by the acceptance probability.
75+
"""
76+
prob = _acceptance_probability(w_bar)
77+
def _sample(k):
78+
draw = jax.random.geometric(k, prob, dtype=jnp.int64)
79+
return jnp.asarray(draw, dtype=jnp.float64)
80+
return jax.lax.cond(
81+
prob <= 0.0,
82+
lambda _: jnp.array(jnp.inf, dtype=jnp.float64),
83+
_sample,
84+
operand=key
85+
)
86+
87+
@partial(jax.jit, static_argnames=('num_reps',))
88+
def compute_mean_stopping_time_jax(w_bar, num_reps=100000, seed=1234):
89+
"""
90+
Generate a mean stopping time over `num_reps` repetitions by repeatedly
91+
drawing from `compute_stopping_time`.
92+
"""
93+
key = jax.random.PRNGKey(seed)
94+
keys = jax.random.split(key, num_reps)
95+
# Vectorize compute_stopping_time and evaluate across keys
96+
compute_fn = jax.vmap(compute_stopping_time_jax, in_axes=(None, 0))
97+
obs = compute_fn(w_bar, keys)
98+
return jnp.mean(obs, dtype=jnp.float64)
99+
100+
# ============================================================================
101+
# BENCHMARK
102+
# ============================================================================
103+
104+
def benchmark(num_trials=5, num_reps=100000):
105+
"""
106+
Benchmark parallel Numba vs optimized JAX.
107+
"""
108+
w_bar = 35.0
109+
110+
print("="*70)
111+
print("Benchmark: Parallel Numba vs Optimized JAX (ex_mm1)")
112+
print("="*70)
113+
print(f"Number of MC replications: {num_reps:,}")
114+
print(f"Number of benchmark trials: {num_trials}")
115+
print(f"Reservation wage: {w_bar}")
116+
print(f"Number of CPU threads: {numba.config.NUMBA_NUM_THREADS}")
117+
print()
118+
119+
# Warm-up runs
120+
print("Warming up...")
121+
_ = compute_mean_stopping_time_numba(w_bar, num_reps=num_reps)
122+
_ = compute_mean_stopping_time_jax(w_bar, num_reps=num_reps).block_until_ready()
123+
print("Warm-up complete.\n")
124+
125+
results = {}
126+
127+
# Benchmark Numba (Parallel)
128+
print("Benchmarking Numba (Parallel)...")
129+
numba_times = []
130+
for i in range(num_trials):
131+
start = time.perf_counter()
132+
result = compute_mean_stopping_time_numba(w_bar, num_reps=num_reps)
133+
elapsed = time.perf_counter() - start
134+
numba_times.append(elapsed)
135+
print(f" Trial {i+1}: {elapsed:.4f} seconds")
136+
137+
numba_mean = np.mean(numba_times)
138+
numba_std = np.std(numba_times)
139+
results['Numba (Parallel)'] = (numba_mean, numba_std, result)
140+
print(f" Mean: {numba_mean:.4f} ± {numba_std:.4f} seconds")
141+
print(f" Result: {result:.4f}\n")
142+
143+
# Benchmark JAX (Optimized)
144+
print("Benchmarking JAX (Optimized)...")
145+
jax_times = []
146+
for i in range(num_trials):
147+
start = time.perf_counter()
148+
result = compute_mean_stopping_time_jax(w_bar, num_reps=num_reps).block_until_ready()
149+
elapsed = time.perf_counter() - start
150+
jax_times.append(elapsed)
151+
print(f" Trial {i+1}: {elapsed:.4f} seconds")
152+
153+
jax_mean = np.mean(jax_times)
154+
jax_std = np.std(jax_times)
155+
results['JAX (Optimized)'] = (jax_mean, jax_std, float(result))
156+
print(f" Mean: {jax_mean:.4f} ± {jax_std:.4f} seconds")
157+
print(f" Result: {float(result):.4f}\n")
158+
159+
# Summary
160+
print("="*70)
161+
print("SUMMARY")
162+
print("="*70)
163+
print(f"{'Implementation':<25} {'Time (s)':<20} {'Relative Performance'}")
164+
print("-"*70)
165+
166+
for name, (mean_time, std_time, _) in results.items():
167+
print(f"{name:<25} {mean_time:>6.4f} ± {std_time:<6.4f}")
168+
169+
print("-"*70)
170+
171+
# Determine winner
172+
if numba_mean < jax_mean:
173+
speedup = jax_mean / numba_mean
174+
print(f"\n🏆 WINNER: Numba (Parallel)")
175+
print(f" Numba is {speedup:.2f}x faster than JAX")
176+
else:
177+
speedup = numba_mean / jax_mean
178+
print(f"\n🏆 WINNER: JAX (Optimized)")
179+
print(f" JAX is {speedup:.2f}x faster than Numba")
180+
181+
print("="*70)
182+
183+
if __name__ == "__main__":
184+
benchmark()

0 commit comments

Comments
 (0)