Skip to content

Commit bbc6bde

Browse files
.
1 parent 6b2b2f2 commit bbc6bde

File tree

1 file changed

+131
-26
lines changed

1 file changed

+131
-26
lines changed

benchmark_suite.py

Lines changed: 131 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ class BenchResult:
127127
impl: str
128128
case: str
129129
status: str
130-
load_time_ms: float
130+
cold_load_time_ms: float
131+
warm_load_time_ms: float
131132
tokens_produced: int
132133
bytes_processed: int
133134
avg_time_ms: float
@@ -136,6 +137,23 @@ class BenchResult:
136137
notes: str = ""
137138

138139

140+
@dataclass
class BenchAggregate:
    """Aggregate statistics for one (impl, case) pair across repeated runs.

    Built by _aggregate() from the BenchResult records whose status is
    "OK"; each metric is reported as a mean plus a sample standard
    deviation over the n contributing runs.
    """

    impl: str  # implementation name, copied from BenchResult.impl
    case: str  # benchmark case name, copied from BenchResult.case
    n: int  # number of OK results folded into this aggregate
    # Mean / sample-std pairs for each per-run metric:
    tokens_per_sec_mean: float
    tokens_per_sec_std: float
    cold_load_time_ms_mean: float
    cold_load_time_ms_std: float
    warm_load_time_ms_mean: float
    warm_load_time_ms_std: float
    mb_per_sec_mean: float
    mb_per_sec_std: float
    tokens_produced_mean: float
    tokens_produced_std: float
155+
156+
139157
def _default_cases() -> List[BenchCase]:
140158
english = (
141159
"The quick brown fox jumps over the lazy dog. "
@@ -177,7 +195,12 @@ def _run_single(
177195
try:
178196
t0 = time.perf_counter()
179197
load_fn()
180-
load_ms = (time.perf_counter() - t0) * 1000.0
198+
cold_load_ms = (time.perf_counter() - t0) * 1000.0
199+
200+
# Warm load measurement: call load again after the cold mapping/parse.
201+
t1 = time.perf_counter()
202+
load_fn()
203+
warm_load_ms = (time.perf_counter() - t1) * 1000.0
181204

182205
payload = case.text * case.repeat
183206
payload_bytes = payload.encode("utf-8")
@@ -203,7 +226,8 @@ def _run_single(
203226
impl=impl_name,
204227
case=case.name,
205228
status="OK",
206-
load_time_ms=load_ms,
229+
cold_load_time_ms=cold_load_ms,
230+
warm_load_time_ms=warm_load_ms,
207231
tokens_produced=avg_tokens,
208232
bytes_processed=len(payload_bytes),
209233
avg_time_ms=avg_t * 1000.0,
@@ -215,7 +239,8 @@ def _run_single(
215239
impl=impl_name,
216240
case=case.name,
217241
status="FAIL",
218-
load_time_ms=0.0,
242+
cold_load_time_ms=0.0,
243+
warm_load_time_ms=0.0,
219244
tokens_produced=0,
220245
bytes_processed=0,
221246
avg_time_ms=0.0,
@@ -308,6 +333,70 @@ def _write_outputs(results: List[BenchResult], out_dir: Path) -> None:
308333
w.writerow(r.__dict__)
309334

310335

336+
def _std(values: List[float], mean: float) -> float:
337+
if not values:
338+
return 0.0
339+
if len(values) == 1:
340+
return 0.0
341+
var = sum((v - mean) ** 2 for v in values) / float(len(values) - 1)
342+
return var ** 0.5
343+
344+
345+
def _aggregate(results: List[BenchResult]) -> List[BenchAggregate]:
    """Collapse repeated OK results into per-(impl, case) mean/std summaries."""
    grouped: Dict[Tuple[str, str], List[BenchResult]] = {}
    for res in results:
        # Failed runs carry zeroed metrics; keep them out of the statistics.
        if res.status == "OK":
            grouped.setdefault((res.impl, res.case), []).append(res)

    def stats(samples: List[float]) -> Tuple[float, float]:
        # Mean plus sample standard deviation (via the shared _std helper).
        m = sum(samples) / float(len(samples))
        return m, _std(samples, m)

    summaries: List[BenchAggregate] = []
    # Sorted iteration keeps the summary output deterministic across runs.
    for (impl_name, case_name), bucket in sorted(grouped.items()):
        tps_mean, tps_std = stats([float(b.tokens_per_sec) for b in bucket])
        cold_mean, cold_std = stats([float(b.cold_load_time_ms) for b in bucket])
        warm_mean, warm_std = stats([float(b.warm_load_time_ms) for b in bucket])
        mb_mean, mb_std = stats([float(b.mb_per_sec) for b in bucket])
        tok_mean, tok_std = stats([float(b.tokens_produced) for b in bucket])

        summaries.append(
            BenchAggregate(
                impl=impl_name,
                case=case_name,
                n=len(bucket),
                tokens_per_sec_mean=tps_mean,
                tokens_per_sec_std=tps_std,
                cold_load_time_ms_mean=cold_mean,
                cold_load_time_ms_std=cold_std,
                warm_load_time_ms_mean=warm_mean,
                warm_load_time_ms_std=warm_std,
                mb_per_sec_mean=mb_mean,
                mb_per_sec_std=mb_std,
                tokens_produced_mean=tok_mean,
                tokens_produced_std=tok_std,
            )
        )
    return summaries
383+
384+
385+
def _write_summary(aggs: List[BenchAggregate], out_dir: Path) -> None:
    """Write aggregated results to benchmark_summary.{json,csv} in *out_dir*."""
    out_dir.mkdir(parents=True, exist_ok=True)
    rows = [a.__dict__ for a in aggs]

    with open(out_dir / "benchmark_summary.json", "w", encoding="utf-8") as fh:
        json.dump(rows, fh, ensure_ascii=False, indent=2)

    # CSV column order mirrors the dataclass field declaration order.
    columns = list(BenchAggregate.__dataclass_fields__.keys())
    with open(out_dir / "benchmark_summary.csv", "w", encoding="utf-8", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=columns)
        writer.writeheader()
        writer.writerows(rows)
398+
399+
311400
def _write_metadata(metadata: Dict[str, Any], out_dir: Path) -> None:
312401
out_dir.mkdir(parents=True, exist_ok=True)
313402
meta_path = out_dir / "metadata.json"
@@ -374,9 +463,10 @@ def main() -> int:
374463
ap = argparse.ArgumentParser(prog="benchmark_suite")
375464
ap.add_argument("--device", default="cpu", choices=["cpu", "auto", "cuda", "rocm"])
376465
ap.add_argument("--iterations", type=int, default=10)
377-
ap.add_argument("--warmup", type=int, default=2)
466+
ap.add_argument("--warmup", type=int, default=5)
378467
ap.add_argument("--out", default=str(Path("benchmark_results") / _now_tag()))
379468
ap.add_argument("--include-hf", action="store_true")
469+
ap.add_argument("--repeats", type=int, default=10)
380470
args = ap.parse_args()
381471

382472
cases = _default_cases()
@@ -437,38 +527,53 @@ def main() -> int:
437527
print(f" - {c.name}: ~{approx_mb:.2f} MB")
438528
print("-" * 90)
439529

440-
for impl_name, load_fn, tok_fn in impls:
441-
for case in cases:
442-
r = _run_single(
443-
impl_name=impl_name,
444-
case=case,
445-
load_fn=load_fn,
446-
tokenize_fn=tok_fn,
447-
iterations=args.iterations,
448-
warmup=args.warmup,
449-
)
450-
results.append(r)
451-
if r.status == "OK":
452-
print(
453-
f"[OK] {r.impl:<22} {r.case:<8} "
454-
f"load={r.load_time_ms:>8.2f}ms "
455-
f"avg={r.avg_time_ms:>8.2f}ms "
456-
f"tok={r.tokens_produced:>8} "
457-
f"tps={r.tokens_per_sec:>12.0f} "
458-
f"mbps={r.mb_per_sec:>8.2f}"
530+
repeats = int(args.repeats)
531+
if repeats < 1:
532+
repeats = 1
533+
534+
print(f"Repeats: {repeats}")
535+
print("-" * 90)
536+
537+
for rep in range(repeats):
538+
if repeats > 1:
539+
print(f"REPEAT {rep + 1}/{repeats}")
540+
for impl_name, load_fn, tok_fn in impls:
541+
for case in cases:
542+
r = _run_single(
543+
impl_name=impl_name,
544+
case=case,
545+
load_fn=load_fn,
546+
tokenize_fn=tok_fn,
547+
iterations=args.iterations,
548+
warmup=args.warmup,
459549
)
460-
else:
461-
print(f"[FAIL] {r.impl:<22} {r.case:<8} {r.notes}")
550+
results.append(r)
551+
if r.status == "OK":
552+
print(
553+
f"[OK] {r.impl:<22} {r.case:<8} "
554+
f"cold_load={r.cold_load_time_ms:>8.2f}ms "
555+
f"warm_load={r.warm_load_time_ms:>8.2f}ms "
556+
f"avg={r.avg_time_ms:>8.2f}ms "
557+
f"tok={r.tokens_produced:>8} "
558+
f"tps={r.tokens_per_sec:>12.0f} "
559+
f"mbps={r.mb_per_sec:>8.2f}"
560+
)
561+
else:
562+
print(f"[FAIL] {r.impl:<22} {r.case:<8} {r.notes}")
462563

463564
out_dir = Path(args.out)
464565
_write_outputs(results, out_dir)
465566
_write_metadata(metadata, out_dir)
567+
aggs = _aggregate(results)
568+
_write_summary(aggs, out_dir)
466569
_plot(results, out_dir)
467570

468571
print("-" * 90)
469572
print("WROTE:")
470573
print(f" - {out_dir / 'benchmark_results.json'}")
471574
print(f" - {out_dir / 'benchmark_results.csv'}")
575+
print(f" - {out_dir / 'benchmark_summary.json'}")
576+
print(f" - {out_dir / 'benchmark_summary.csv'}")
472577
print(f" - {out_dir / 'metadata.json'}")
473578
print(f" - {out_dir / 'tokens_per_sec.png'} (if matplotlib installed)")
474579
print(f" - {out_dir / 'mb_per_sec.png'} (if matplotlib installed)")

0 commit comments

Comments
 (0)