@@ -127,7 +127,8 @@ class BenchResult:
127127 impl : str
128128 case : str
129129 status : str
130- load_time_ms : float
130+ cold_load_time_ms : float
131+ warm_load_time_ms : float
131132 tokens_produced : int
132133 bytes_processed : int
133134 avg_time_ms : float
@@ -136,6 +137,23 @@ class BenchResult:
136137 notes : str = ""
137138
138139
@dataclass
class BenchAggregate:
    """Per-(impl, case) summary statistics over repeated benchmark runs.

    Each ``*_mean`` / ``*_std`` pair summarizes the corresponding field of
    the OK results for one implementation/case group; ``*_std`` fields hold
    the sample standard deviation (n-1 divisor, 0.0 when n < 2).
    """

    impl: str   # tokenizer implementation name
    case: str   # benchmark case name
    n: int      # number of successful (status == "OK") runs aggregated
    tokens_per_sec_mean: float
    tokens_per_sec_std: float
    cold_load_time_ms_mean: float
    cold_load_time_ms_std: float
    warm_load_time_ms_mean: float
    warm_load_time_ms_std: float
    mb_per_sec_mean: float
    mb_per_sec_std: float
    tokens_produced_mean: float
    tokens_produced_std: float
156+
139157def _default_cases () -> List [BenchCase ]:
140158 english = (
141159 "The quick brown fox jumps over the lazy dog. "
@@ -177,7 +195,12 @@ def _run_single(
177195 try :
178196 t0 = time .perf_counter ()
179197 load_fn ()
180- load_ms = (time .perf_counter () - t0 ) * 1000.0
198+ cold_load_ms = (time .perf_counter () - t0 ) * 1000.0
199+
200+ # Warm load measurement: call load again after the cold mapping/parse.
201+ t1 = time .perf_counter ()
202+ load_fn ()
203+ warm_load_ms = (time .perf_counter () - t1 ) * 1000.0
181204
182205 payload = case .text * case .repeat
183206 payload_bytes = payload .encode ("utf-8" )
@@ -203,7 +226,8 @@ def _run_single(
203226 impl = impl_name ,
204227 case = case .name ,
205228 status = "OK" ,
206- load_time_ms = load_ms ,
229+ cold_load_time_ms = cold_load_ms ,
230+ warm_load_time_ms = warm_load_ms ,
207231 tokens_produced = avg_tokens ,
208232 bytes_processed = len (payload_bytes ),
209233 avg_time_ms = avg_t * 1000.0 ,
@@ -215,7 +239,8 @@ def _run_single(
215239 impl = impl_name ,
216240 case = case .name ,
217241 status = "FAIL" ,
218- load_time_ms = 0.0 ,
242+ cold_load_time_ms = 0.0 ,
243+ warm_load_time_ms = 0.0 ,
219244 tokens_produced = 0 ,
220245 bytes_processed = 0 ,
221246 avg_time_ms = 0.0 ,
@@ -308,6 +333,70 @@ def _write_outputs(results: List[BenchResult], out_dir: Path) -> None:
308333 w .writerow (r .__dict__ )
309334
310335
336+ def _std (values : List [float ], mean : float ) -> float :
337+ if not values :
338+ return 0.0
339+ if len (values ) == 1 :
340+ return 0.0
341+ var = sum ((v - mean ) ** 2 for v in values ) / float (len (values ) - 1 )
342+ return var ** 0.5
343+
344+
def _aggregate(results: List[BenchResult]) -> List[BenchAggregate]:
    """Group OK results by (impl, case) and summarize each metric.

    Returns one BenchAggregate per group, sorted by (impl, case), holding
    the mean and sample standard deviation of throughput, load times and
    token counts. FAIL results are excluded.
    """
    grouped: Dict[Tuple[str, str], List[BenchResult]] = {}
    for result in results:
        if result.status == "OK":
            grouped.setdefault((result.impl, result.case), []).append(result)

    def mean_of(series: List[float]) -> float:
        # Groups are never empty: every key gets at least one appended result.
        return sum(series) / float(len(series))

    summaries: List[BenchAggregate] = []
    for (impl_name, case_name), group in sorted(grouped.items()):
        tps_vals = [float(g.tokens_per_sec) for g in group]
        cold_vals = [float(g.cold_load_time_ms) for g in group]
        warm_vals = [float(g.warm_load_time_ms) for g in group]
        mbps_vals = [float(g.mb_per_sec) for g in group]
        tok_vals = [float(g.tokens_produced) for g in group]

        tps_mean = mean_of(tps_vals)
        cold_mean = mean_of(cold_vals)
        warm_mean = mean_of(warm_vals)
        mbps_mean = mean_of(mbps_vals)
        tok_mean = mean_of(tok_vals)

        summaries.append(
            BenchAggregate(
                impl=impl_name,
                case=case_name,
                n=len(group),
                tokens_per_sec_mean=tps_mean,
                tokens_per_sec_std=_std(tps_vals, tps_mean),
                cold_load_time_ms_mean=cold_mean,
                cold_load_time_ms_std=_std(cold_vals, cold_mean),
                warm_load_time_ms_mean=warm_mean,
                warm_load_time_ms_std=_std(warm_vals, warm_mean),
                mb_per_sec_mean=mbps_mean,
                mb_per_sec_std=_std(mbps_vals, mbps_mean),
                tokens_produced_mean=tok_mean,
                tokens_produced_std=_std(tok_vals, tok_mean),
            )
        )
    return summaries
383+
384+
def _write_summary(aggs: List[BenchAggregate], out_dir: Path) -> None:
    """Persist aggregates to benchmark_summary.json / benchmark_summary.csv.

    Creates *out_dir* (and parents) if needed. The CSV column order follows
    the declaration order of BenchAggregate's fields.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    rows = [a.__dict__ for a in aggs]

    with open(out_dir / "benchmark_summary.json", "w", encoding="utf-8") as fh:
        json.dump(rows, fh, ensure_ascii=False, indent=2)

    columns = list(BenchAggregate.__dataclass_fields__.keys())
    with open(out_dir / "benchmark_summary.csv", "w", encoding="utf-8", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=columns)
        writer.writeheader()
        writer.writerows(rows)
398+
399+
311400def _write_metadata (metadata : Dict [str , Any ], out_dir : Path ) -> None :
312401 out_dir .mkdir (parents = True , exist_ok = True )
313402 meta_path = out_dir / "metadata.json"
@@ -374,9 +463,10 @@ def main() -> int:
374463 ap = argparse .ArgumentParser (prog = "benchmark_suite" )
375464 ap .add_argument ("--device" , default = "cpu" , choices = ["cpu" , "auto" , "cuda" , "rocm" ])
376465 ap .add_argument ("--iterations" , type = int , default = 10 )
377- ap .add_argument ("--warmup" , type = int , default = 2 )
466+ ap .add_argument ("--warmup" , type = int , default = 5 )
378467 ap .add_argument ("--out" , default = str (Path ("benchmark_results" ) / _now_tag ()))
379468 ap .add_argument ("--include-hf" , action = "store_true" )
469+ ap .add_argument ("--repeats" , type = int , default = 10 )
380470 args = ap .parse_args ()
381471
382472 cases = _default_cases ()
@@ -437,38 +527,53 @@ def main() -> int:
437527 print (f" - { c .name } : ~{ approx_mb :.2f} MB" )
438528 print ("-" * 90 )
439529
440- for impl_name , load_fn , tok_fn in impls :
441- for case in cases :
442- r = _run_single (
443- impl_name = impl_name ,
444- case = case ,
445- load_fn = load_fn ,
446- tokenize_fn = tok_fn ,
447- iterations = args . iterations ,
448- warmup = args . warmup ,
449- )
450- results . append ( r )
451- if r . status == "OK" :
452- print (
453- f"[OK] { r . impl :<22 } { r . case :<8 } "
454- f"load= { r . load_time_ms :>8.2f } ms "
455- f"avg= { r . avg_time_ms :>8.2f } ms "
456- f"tok= { r . tokens_produced :>8 } "
457- f"tps= { r . tokens_per_sec :>12.0f } "
458- f"mbps= { r . mb_per_sec :>8.2f } "
530+ repeats = int ( args . repeats )
531+ if repeats < 1 :
532+ repeats = 1
533+
534+ print ( f"Repeats: { repeats } " )
535+ print ( "-" * 90 )
536+
537+ for rep in range ( repeats ):
538+ if repeats > 1 :
539+ print ( f"REPEAT { rep + 1 } / { repeats } " )
540+ for impl_name , load_fn , tok_fn in impls :
541+ for case in cases :
542+ r = _run_single (
543+ impl_name = impl_name ,
544+ case = case ,
545+ load_fn = load_fn ,
546+ tokenize_fn = tok_fn ,
547+ iterations = args . iterations ,
548+ warmup = args . warmup ,
459549 )
460- else :
461- print (f"[FAIL] { r .impl :<22} { r .case :<8} { r .notes } " )
550+ results .append (r )
551+ if r .status == "OK" :
552+ print (
553+ f"[OK] { r .impl :<22} { r .case :<8} "
554+ f"cold_load={ r .cold_load_time_ms :>8.2f} ms "
555+ f"warm_load={ r .warm_load_time_ms :>8.2f} ms "
556+ f"avg={ r .avg_time_ms :>8.2f} ms "
557+ f"tok={ r .tokens_produced :>8} "
558+ f"tps={ r .tokens_per_sec :>12.0f} "
559+ f"mbps={ r .mb_per_sec :>8.2f} "
560+ )
561+ else :
562+ print (f"[FAIL] { r .impl :<22} { r .case :<8} { r .notes } " )
462563
463564 out_dir = Path (args .out )
464565 _write_outputs (results , out_dir )
465566 _write_metadata (metadata , out_dir )
567+ aggs = _aggregate (results )
568+ _write_summary (aggs , out_dir )
466569 _plot (results , out_dir )
467570
468571 print ("-" * 90 )
469572 print ("WROTE:" )
470573 print (f" - { out_dir / 'benchmark_results.json' } " )
471574 print (f" - { out_dir / 'benchmark_results.csv' } " )
575+ print (f" - { out_dir / 'benchmark_summary.json' } " )
576+ print (f" - { out_dir / 'benchmark_summary.csv' } " )
472577 print (f" - { out_dir / 'metadata.json' } " )
473578 print (f" - { out_dir / 'tokens_per_sec.png' } (if matplotlib installed)" )
474579 print (f" - { out_dir / 'mb_per_sec.png' } (if matplotlib installed)" )
0 commit comments