@@ -547,13 +547,17 @@ def _create_maximum_context_prompt(self) -> str:
         return extended_context
 
     def run_single_benchmark(self, config: BenchmarkConfig) -> BenchmarkResult:
-        """Run a single benchmark configuration"""
+        """Run a single benchmark configuration with proper warmup"""
         print(f"\n{'=' * 60}")
         print(f"Running: {config.name}")
         print(f"Description: {config.description}")
         print(f"Max tokens: {config.max_tokens}")
         print(f"{'=' * 60}")
 
+        # Performance measurement parameters
+        WARMUP_RUNS = 2  # Warmup runs to eliminate cold start effects
+        MEASUREMENT_RUNS = 3  # Multiple measurement runs for reliability
+
         # Create temporary prompt file
         with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
            f.write(config.prompt)
@@ -571,102 +575,183 @@ def run_single_benchmark(self, config: BenchmarkConfig) -> BenchmarkResult:
                 config.prompt,
                 "--max-tokens",
                 str(config.max_tokens),
-                # Remove --verbose flag as it requires an argument in newer mlx-lm
             ]
 
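The hunk enters the `cmd` list partway through, so only its tail is visible. For orientation, a hedged sketch of the invocation this list presumably assembles is below; the head of the list sits above the hunk, so the model id and variable names here are hypothetical placeholders, not the author's actual values:

```python
# Sketch under assumptions: mlx-lm exposes `python -m mlx_lm.generate`
# with --model/--prompt/--max-tokens flags; these values are made up.
model_path = "mlx-community/Example-Model-4bit"  # hypothetical model id
prompt = "Hello"
max_tokens = 64

cmd = [
    "python", "-m", "mlx_lm.generate",
    "--model", model_path,
    "--prompt", prompt,
    "--max-tokens", str(max_tokens),
]
```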
-            # Record memory before
+            # Clear MLX cache before starting
+            print(f"🧹 Clearing MLX cache...")
             mx.clear_cache()
-            initial_memory = mx.get_active_memory()
-
-            # Run benchmark
-            start_time = time.perf_counter()
-            result = subprocess.run(
-                cmd, capture_output=True, text=True, timeout=300  # 5 minute timeout
-            )
-            end_time = time.perf_counter()
-
-            if result.returncode != 0:
-                print(f"Error running benchmark: {result.stderr}")
-                raise RuntimeError(f"Benchmark failed: {result.stderr}")
-
-            # Parse output
-            output_lines = result.stdout.strip().split("\n")
-
-            # Find the generated text (between ========== markers)
-            generated_text = ""
-            in_generation = False
-            prompt_tokens = 0
-            generation_tokens = 0
-            prompt_speed = 0.0
-            generation_speed = 0.0
-            peak_memory_str = ""
-
-            for line in output_lines:
-                if line.strip() == "==========":
-                    in_generation = not in_generation
-                elif in_generation:
-                    generated_text += line + "\n"
-                elif "Prompt:" in line and "tokens-per-sec" in line:
-                    # Parse: "Prompt: 13 tokens, 310.367 tokens-per-sec"
-                    parts = line.split(",")
-                    prompt_tokens = int(parts[0].split(":")[1].strip().split()[0])
-                    prompt_speed = float(parts[1].strip().split()[0])
-                elif "Generation:" in line and "tokens-per-sec" in line:
-                    # Parse: "Generation: 468 tokens, 69.860 tokens-per-sec"
-                    parts = line.split(",")
-                    generation_tokens = int(parts[0].split(":")[1].strip().split()[0])
-                    generation_speed = float(parts[1].strip().split()[0])
-                elif "Peak memory:" in line:
-                    peak_memory_str = line.split(":")[1].strip()
-
-            # Parse peak memory
-            peak_memory_gb = 0.0
-            if peak_memory_str:
-                if "GB" in peak_memory_str:
-                    peak_memory_gb = float(peak_memory_str.replace("GB", "").strip())
-                elif "MB" in peak_memory_str:
-                    peak_memory_gb = float(peak_memory_str.replace("MB", "").strip()) / 1024
-
-            # Calculate overall tokens per second
-            total_tokens = generation_tokens
-            total_time = end_time - start_time
-            total_tokens_per_sec = total_tokens / total_time if total_time > 0 else 0
-
-            # Create result
-            benchmark_result = BenchmarkResult(
+
+            # Warmup runs - don't measure these
+            print(f"🔥 Running {WARMUP_RUNS} warmup runs to eliminate cold start effects...")
+            for i in range(WARMUP_RUNS):
+                try:
+                    print(f"  Warmup run {i + 1}/{WARMUP_RUNS}...")
+                    warmup_result = subprocess.run(
+                        cmd, capture_output=True, text=True, timeout=300
+                    )
+                    if warmup_result.returncode != 0:
+                        print(f"  ⚠️ Warmup run {i + 1} failed: {warmup_result.stderr[:100]}...")
+                    else:
+                        print(f"  ✅ Warmup run {i + 1} completed")
+
+                    # Clear cache between warmup runs
+                    mx.clear_cache()
+
+                except subprocess.TimeoutExpired:
+                    print(f"  ⏰ Warmup run {i + 1} timed out")
+                except Exception as e:
+                    print(f"  ❌ Warmup run {i + 1} error: {e}")
+
+            print(f"📊 Running {MEASUREMENT_RUNS} measurement runs...")
+
+            # Measurement runs
+            successful_results = []
+            for run_idx in range(MEASUREMENT_RUNS):
+                try:
+                    print(f"  Measurement run {run_idx + 1}/{MEASUREMENT_RUNS}...")
+
+                    # Clear cache before each measurement run for consistency
+                    mx.clear_cache()
+                    initial_memory = mx.get_active_memory()
+
+                    # Run benchmark
+                    start_time = time.perf_counter()
+                    result = subprocess.run(
+                        cmd, capture_output=True, text=True, timeout=300
+                    )
+                    end_time = time.perf_counter()
+
+                    if result.returncode != 0:
+                        print(f"  ❌ Measurement run {run_idx + 1} failed: {result.stderr[:100]}...")
+                        continue
+
+                    # Parse output
+                    parsed_result = self._parse_benchmark_output(
+                        result.stdout, config, end_time - start_time
+                    )
+
+                    if parsed_result:
+                        successful_results.append(parsed_result)
+                        print(f"  ✅ Run {run_idx + 1}: {parsed_result.decode_tokens_per_sec:.1f} tokens/sec")
+                    else:
+                        print(f"  ❌ Run {run_idx + 1}: Failed to parse output")
+
+                except subprocess.TimeoutExpired:
+                    print(f"  ⏰ Measurement run {run_idx + 1} timed out")
+                except Exception as e:
+                    print(f"  ❌ Measurement run {run_idx + 1} error: {e}")
+
+            # Require at least 2 successful runs for reliable results
+            if len(successful_results) < 2:
+                print(f"❌ Only {len(successful_results)}/{MEASUREMENT_RUNS} measurement runs succeeded")
+                print(f"❌ Need at least 2 successful runs for reliable results")
+                raise RuntimeError(f"Insufficient successful runs: {len(successful_results)}/{MEASUREMENT_RUNS}")
+
+            # Calculate statistics from multiple runs
+            decode_speeds = [r.decode_tokens_per_sec for r in successful_results]
+            prefill_speeds = [r.prefill_tokens_per_sec for r in successful_results]
+            memories = [r.peak_memory_gb for r in successful_results]
+            times = [r.total_time_sec for r in successful_results]
+
+            # Use median for more robust results (less sensitive to outliers)
+            final_result = BenchmarkResult(
                 name=config.name,
-                prompt_tokens=prompt_tokens,
-                generated_tokens=generation_tokens,
-                prefill_tokens_per_sec=prompt_speed,
-                decode_tokens_per_sec=generation_speed,
-                total_tokens_per_sec=total_tokens_per_sec,
-                peak_memory_gb=peak_memory_gb,
-                total_time_sec=total_time,
+                prompt_tokens=int(np.median([r.prompt_tokens for r in successful_results])),
+                generated_tokens=int(np.median([r.generated_tokens for r in successful_results])),
+                prefill_tokens_per_sec=float(np.median(prefill_speeds)),
+                decode_tokens_per_sec=float(np.median(decode_speeds)),
+                total_tokens_per_sec=float(np.median([r.total_tokens_per_sec for r in successful_results])),
+                peak_memory_gb=float(np.median(memories)),
+                total_time_sec=float(np.median(times)),
                 prompt=config.prompt[:200] + "..." if len(config.prompt) > 200 else config.prompt,
-                generated_text=(
-                    generated_text.strip()[:200] + "..."
-                    if len(generated_text.strip()) > 200
-                    else generated_text.strip()
-                ),
+                generated_text=successful_results[0].generated_text,  # Use first result's text
             )
 
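The added code leans on `np` (numpy), which is presumably imported at the top of the file and not visible in this hunk. A minimal standalone sketch of why the diff's comment prefers the median over the mean when one run is slowed by background load:

```python
import numpy as np

# Three clean runs plus one outlier dragged down by background load
decode_speeds = [70.0, 70.4, 68.8, 41.2]

print(np.mean(decode_speeds))    # ≈62.6 — pulled toward the outlier
print(np.median(decode_speeds))  # ≈69.4 — stays near a typical run
```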
-            # Print results
-            print(f"\nResults:")
-            print(f"  Prompt tokens: {prompt_tokens}")
-            print(f"  Generated tokens: {generation_tokens}")
-            print(f"  Prefill speed: {prompt_speed:.2f} tokens/sec")
-            print(f"  Decode speed: {generation_speed:.2f} tokens/sec")
-            print(f"  Overall speed: {total_tokens_per_sec:.2f} tokens/sec")
-            print(f"  Peak memory: {peak_memory_gb:.3f} GB")
-            print(f"  Total time: {total_time:.2f} seconds")
-
-            return benchmark_result
+            # Print final results with statistics
+            print(f"\n📈 Final Results (median of {len(successful_results)} runs):")
+            print(f"  Prompt tokens: {final_result.prompt_tokens}")
+            print(f"  Generated tokens: {final_result.generated_tokens}")
+            print(f"  Prefill speed: {final_result.prefill_tokens_per_sec:.2f} tokens/sec")
+            print(f"  Decode speed: {final_result.decode_tokens_per_sec:.2f} tokens/sec (σ={np.std(decode_speeds):.2f})")
+            print(f"  Overall speed: {final_result.total_tokens_per_sec:.2f} tokens/sec")
+            print(f"  Peak memory: {final_result.peak_memory_gb:.3f} GB")
+            print(f"  Total time: {final_result.total_time_sec:.2f} seconds")
+
+            if len(decode_speeds) > 1:
+                print(f"  Performance consistency: {np.std(decode_speeds)/np.mean(decode_speeds)*100:.1f}% CV")
+
+            return final_result
 
         finally:
             # Clean up
             if os.path.exists(prompt_file):
                 os.unlink(prompt_file)
 
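The `σ` and `% CV` figures printed at the end of `run_single_benchmark` are plain numpy statistics over the per-run decode speeds; the coefficient of variation is just the standard deviation normalized by the mean. A quick self-contained check (note numpy's `np.std` defaults to the population standard deviation, ddof=0):

```python
import numpy as np

decode_speeds = [69.9, 70.4, 68.7]         # per-run decode tokens/sec
sigma = np.std(decode_speeds)              # population std, numpy's default
cv = sigma / np.mean(decode_speeds) * 100  # coefficient of variation in %
print(f"σ={sigma:.2f}, {cv:.1f}% CV")      # ≈ σ=0.71, 1.0% CV
```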
+    def _parse_benchmark_output(
+        self, stdout: str, config: BenchmarkConfig, total_time: float
+    ) -> Optional[BenchmarkResult]:
+        """Parse mlx-lm output to extract performance metrics"""
+        output_lines = stdout.strip().split("\n")
+
+        # Find the generated text (between ========== markers)
+        generated_text = ""
+        in_generation = False
+        prompt_tokens = 0
+        generation_tokens = 0
+        prompt_speed = 0.0
+        generation_speed = 0.0
+        peak_memory_str = ""
+
+        for line in output_lines:
+            if line.strip() == "==========":
+                in_generation = not in_generation
+            elif in_generation:
+                generated_text += line + "\n"
+            elif "Prompt:" in line and "tokens-per-sec" in line:
+                # Parse: "Prompt: 13 tokens, 310.367 tokens-per-sec"
+                parts = line.split(",")
+                prompt_tokens = int(parts[0].split(":")[1].strip().split()[0])
+                prompt_speed = float(parts[1].strip().split()[0])
+            elif "Generation:" in line and "tokens-per-sec" in line:
+                # Parse: "Generation: 468 tokens, 69.860 tokens-per-sec"
+                parts = line.split(",")
+                generation_tokens = int(parts[0].split(":")[1].strip().split()[0])
+                generation_speed = float(parts[1].strip().split()[0])
+            elif "Peak memory:" in line:
+                peak_memory_str = line.split(":")[1].strip()
+
+        # Parse peak memory
+        peak_memory_gb = 0.0
+        if peak_memory_str:
+            if "GB" in peak_memory_str:
+                peak_memory_gb = float(peak_memory_str.replace("GB", "").strip())
+            elif "MB" in peak_memory_str:
+                peak_memory_gb = float(peak_memory_str.replace("MB", "").strip()) / 1024
+
+        # Validate we got meaningful results
+        if generation_tokens == 0 or generation_speed == 0:
+            return None
+
+        # Calculate overall tokens per second
+        total_tokens_per_sec = generation_tokens / total_time if total_time > 0 else 0
+
+        return BenchmarkResult(
+            name=config.name,
+            prompt_tokens=prompt_tokens,
+            generated_tokens=generation_tokens,
+            prefill_tokens_per_sec=prompt_speed,
+            decode_tokens_per_sec=generation_speed,
+            total_tokens_per_sec=total_tokens_per_sec,
+            peak_memory_gb=peak_memory_gb,
+            total_time_sec=total_time,
+            prompt=config.prompt[:200] + "..." if len(config.prompt) > 200 else config.prompt,
+            generated_text=(
+                generated_text.strip()[:200] + "..."
+                if len(generated_text.strip()) > 200
+                else generated_text.strip()
+            ),
+        )
+
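To make the split chains in `_parse_benchmark_output` concrete, here is the prompt-line parse traced on the sample quoted in its comment (the generation line parses identically). The method also assumes `Optional` is imported from `typing` at the top of the file, which this hunk does not show:

```python
line = "Prompt: 13 tokens, 310.367 tokens-per-sec"

parts = line.split(",")  # ['Prompt: 13 tokens', ' 310.367 tokens-per-sec']

prompt_tokens = int(parts[0].split(":")[1].strip().split()[0])  # '13' -> 13
prompt_speed = float(parts[1].strip().split()[0])               # '310.367' -> 310.367

print(prompt_tokens, prompt_speed)  # 13 310.367
```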
     def run_full_benchmark_suite(self) -> Dict:
         """Run the complete benchmark suite"""
         print(f"\n{'=' * 80}")