
Commit 17ee9f1

1 parent dd688f8 commit 17ee9f1

2 files changed: +203 additions, -92 deletions

examples/mlx_metal_kernel_opt/quick_benchmark_test.py

Lines changed: 34 additions & 8 deletions
@@ -12,7 +12,7 @@


 def run_quick_test():
-    """Run a quick test with just a few key benchmarks"""
+    """Run a quick test with just a few key benchmarks with proper warmup"""

     # Test configs - subset of full suite
     test_configs = [

@@ -50,17 +50,39 @@ def run_quick_test():

     # Use mlx-lm as installed package (no need to change directories)
     try:
+        # Import mlx for cache clearing
+        import mlx.core as mx
+        import numpy as np
+
         benchmark_suite = Qwen3BenchmarkSuite()

         print(f"\n{'='*80}")
         print(f"Quick Benchmark Test - Qwen3-0.6B")
-        print(f"Testing {len(test_configs)} key scenarios")
+        print(f"Testing {len(test_configs)} key scenarios with warmup")
         print(f"{'='*80}")
+
+        # Global warmup - run one quick test to warm up the system
+        print(f"🔥 Running global warmup to initialize MLX and model...")
+        try:
+            mx.clear_cache()
+            warmup_config = BenchmarkConfig(
+                name="warmup",
+                prompt="Hello",
+                max_tokens=5,
+                description="Warmup run"
+            )
+            print(f"   Global warmup in progress...")
+            warmup_result = benchmark_suite.run_single_benchmark(warmup_config)
+            print(f"   ✅ Global warmup completed")
+        except Exception as e:
+            print(f"   ⚠️ Global warmup failed: {e}")
+            print(f"   Continuing with individual tests...")

         results = []
         for i, config in enumerate(test_configs, 1):
             print(f"\n[{i}/{len(test_configs)}] Running: {config.name}")
             try:
+                # The benchmark_suite.run_single_benchmark already has warmup built-in
                 result = benchmark_suite.run_single_benchmark(config)
                 results.append(result)
             except Exception as e:

@@ -72,32 +94,36 @@ def run_quick_test():
         print(f"\n{'='*80}")
         print(f"Quick Test Results Summary")
         print(f"{'='*80}")
-        print(f"{'Name':<20} {'Gen Tokens':<12} {'Decode Speed':<12} {'Memory':<10}")
+        print(f"{'Name':<25} {'Gen Tokens':<12} {'Decode Speed':<15} {'Memory':<10} {'CV%':<8}")
         print(f"{'-'*80}")

         for result in results:
+            # Extract standard deviation from the result display if available
+            cv_display = "N/A"
             print(
-                f"{result.name:<20} "
+                f"{result.name:<25} "
                 f"{result.generated_tokens:<12} "
-                f"{result.decode_tokens_per_sec:<12.1f} "
-                f"{result.peak_memory_gb:<10.2f}"
+                f"{result.decode_tokens_per_sec:<15.1f} "
+                f"{result.peak_memory_gb:<10.2f} "
+                f"{cv_display:<8}"
             )

         print(f"{'-'*80}")
         decode_speeds = [
             r.decode_tokens_per_sec for r in results if r.decode_tokens_per_sec > 0
         ]
         if decode_speeds:
-            import numpy as np
-
             print(f"Average decode speed: {np.mean(decode_speeds):.1f} tokens/sec")
             print(
                 f"Speed range: {np.min(decode_speeds):.1f} - {np.max(decode_speeds):.1f} tokens/sec"
             )
+            print(f"Performance std dev: {np.std(decode_speeds):.1f} tokens/sec")
+            print(f"Overall consistency: {np.std(decode_speeds)/np.mean(decode_speeds)*100:.1f}% CV")

         print(f"\n{'='*80}")
         print("Quick test complete! If this looks good, run the full benchmark suite.")
         print("python qwen3_benchmark_suite.py")
+        print(f"✅ All tests included proper warmup for reliable results")
         print(f"{'='*80}")

         return results
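The new "CV%" column and the "Overall consistency" line both report the coefficient of variation (standard deviation divided by mean) of the measured decode speeds. A minimal sketch of that calculation, not part of the commit and assuming only NumPy, for reproducing the number from your own run logs:

```python
import numpy as np

def coefficient_of_variation(samples):
    """Return std/mean as a percentage; lower means more consistent runs."""
    samples = np.asarray(samples, dtype=float)
    if samples.size < 2 or samples.mean() == 0:
        return float("nan")
    return float(samples.std() / samples.mean() * 100)

# e.g. decode speeds (tokens/sec) from three runs of one config
print(f"{coefficient_of_variation([68.9, 70.1, 69.5]):.1f}% CV")
```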

examples/mlx_metal_kernel_opt/qwen3_benchmark_suite.py

Lines changed: 169 additions & 84 deletions
@@ -547,13 +547,17 @@ def _create_maximum_context_prompt(self) -> str:
         return extended_context

     def run_single_benchmark(self, config: BenchmarkConfig) -> BenchmarkResult:
-        """Run a single benchmark configuration"""
+        """Run a single benchmark configuration with proper warmup"""
         print(f"\n{'='*60}")
         print(f"Running: {config.name}")
         print(f"Description: {config.description}")
         print(f"Max tokens: {config.max_tokens}")
         print(f"{'='*60}")

+        # Performance measurement parameters
+        WARMUP_RUNS = 2  # Warmup runs to eliminate cold start effects
+        MEASUREMENT_RUNS = 3  # Multiple measurement runs for reliability
+
         # Create temporary prompt file
         with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
             f.write(config.prompt)
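The WARMUP_RUNS and MEASUREMENT_RUNS constants drive the pattern implemented in the rest of this diff: discard a few cold-start runs, measure several more, and keep the median. A stripped-down sketch of that pattern, with a hypothetical `run_once` callable standing in for a benchmark invocation (not part of this commit):

```python
import numpy as np

def measure(run_once, warmup_runs=2, measurement_runs=3):
    """Warm up, then return the median speed of several measured runs."""
    for _ in range(warmup_runs):
        run_once()  # results discarded: eliminates cold-start effects
    speeds = [run_once() for _ in range(measurement_runs)]
    return float(np.median(speeds))  # median is robust to outlier runs
```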
@@ -571,102 +575,183 @@ def run_single_benchmark(self, config: BenchmarkConfig) -> BenchmarkResult:
                 config.prompt,
                 "--max-tokens",
                 str(config.max_tokens),
-                # Remove --verbose flag as it requires an argument in newer mlx-lm
             ]

-            # Record memory before
+            # Clear MLX cache before starting
+            print(f"🧹 Clearing MLX cache...")
             mx.clear_cache()
-            initial_memory = mx.get_active_memory()
-
-            # Run benchmark
-            start_time = time.perf_counter()
-            result = subprocess.run(
-                cmd, capture_output=True, text=True, timeout=300  # 5 minute timeout
-            )
-            end_time = time.perf_counter()
-
-            if result.returncode != 0:
-                print(f"Error running benchmark: {result.stderr}")
-                raise RuntimeError(f"Benchmark failed: {result.stderr}")
-
-            # Parse output
-            output_lines = result.stdout.strip().split("\n")
-
-            # Find the generated text (between ========== markers)
-            generated_text = ""
-            in_generation = False
-            prompt_tokens = 0
-            generation_tokens = 0
-            prompt_speed = 0.0
-            generation_speed = 0.0
-            peak_memory_str = ""
-
-            for line in output_lines:
-                if line.strip() == "==========":
-                    in_generation = not in_generation
-                elif in_generation:
-                    generated_text += line + "\n"
-                elif "Prompt:" in line and "tokens-per-sec" in line:
-                    # Parse: "Prompt: 13 tokens, 310.367 tokens-per-sec"
-                    parts = line.split(",")
-                    prompt_tokens = int(parts[0].split(":")[1].strip().split()[0])
-                    prompt_speed = float(parts[1].strip().split()[0])
-                elif "Generation:" in line and "tokens-per-sec" in line:
-                    # Parse: "Generation: 468 tokens, 69.860 tokens-per-sec"
-                    parts = line.split(",")
-                    generation_tokens = int(parts[0].split(":")[1].strip().split()[0])
-                    generation_speed = float(parts[1].strip().split()[0])
-                elif "Peak memory:" in line:
-                    peak_memory_str = line.split(":")[1].strip()
-
-            # Parse peak memory
-            peak_memory_gb = 0.0
-            if peak_memory_str:
-                if "GB" in peak_memory_str:
-                    peak_memory_gb = float(peak_memory_str.replace("GB", "").strip())
-                elif "MB" in peak_memory_str:
-                    peak_memory_gb = float(peak_memory_str.replace("MB", "").strip()) / 1024
-
-            # Calculate overall tokens per second
-            total_tokens = generation_tokens
-            total_time = end_time - start_time
-            total_tokens_per_sec = total_tokens / total_time if total_time > 0 else 0
-
-            # Create result
-            benchmark_result = BenchmarkResult(
+
+            # Warmup runs - don't measure these
+            print(f"🔥 Running {WARMUP_RUNS} warmup runs to eliminate cold start effects...")
+            for i in range(WARMUP_RUNS):
+                try:
+                    print(f"   Warmup run {i+1}/{WARMUP_RUNS}...")
+                    warmup_result = subprocess.run(
+                        cmd, capture_output=True, text=True, timeout=300
+                    )
+                    if warmup_result.returncode != 0:
+                        print(f"   ⚠️ Warmup run {i+1} failed: {warmup_result.stderr[:100]}...")
+                    else:
+                        print(f"   ✅ Warmup run {i+1} completed")
+
+                    # Clear cache between warmup runs
+                    mx.clear_cache()
+
+                except subprocess.TimeoutExpired:
+                    print(f"   ⏰ Warmup run {i+1} timed out")
+                except Exception as e:
+                    print(f"   ❌ Warmup run {i+1} error: {e}")
+
+            print(f"📊 Running {MEASUREMENT_RUNS} measurement runs...")
+
+            # Measurement runs
+            successful_results = []
+            for run_idx in range(MEASUREMENT_RUNS):
+                try:
+                    print(f"   Measurement run {run_idx+1}/{MEASUREMENT_RUNS}...")
+
+                    # Clear cache before each measurement run for consistency
+                    mx.clear_cache()
+                    initial_memory = mx.get_active_memory()
+
+                    # Run benchmark
+                    start_time = time.perf_counter()
+                    result = subprocess.run(
+                        cmd, capture_output=True, text=True, timeout=300
+                    )
+                    end_time = time.perf_counter()
+
+                    if result.returncode != 0:
+                        print(f"   ❌ Measurement run {run_idx+1} failed: {result.stderr[:100]}...")
+                        continue
+
+                    # Parse output
+                    parsed_result = self._parse_benchmark_output(
+                        result.stdout, config, end_time - start_time
+                    )
+
+                    if parsed_result:
+                        successful_results.append(parsed_result)
+                        print(f"   ✅ Run {run_idx+1}: {parsed_result.decode_tokens_per_sec:.1f} tokens/sec")
+                    else:
+                        print(f"   ❌ Run {run_idx+1}: Failed to parse output")
+
+                except subprocess.TimeoutExpired:
+                    print(f"   ⏰ Measurement run {run_idx+1} timed out")
+                except Exception as e:
+                    print(f"   ❌ Measurement run {run_idx+1} error: {e}")
+
+            # Require at least 2 successful runs for reliable results
+            if len(successful_results) < 2:
+                print(f"❌ Only {len(successful_results)}/{MEASUREMENT_RUNS} measurement runs succeeded")
+                print(f"❌ Need at least 2 successful runs for reliable results")
+                raise RuntimeError(f"Insufficient successful runs: {len(successful_results)}/{MEASUREMENT_RUNS}")
+
+            # Calculate statistics from multiple runs
+            decode_speeds = [r.decode_tokens_per_sec for r in successful_results]
+            prefill_speeds = [r.prefill_tokens_per_sec for r in successful_results]
+            memories = [r.peak_memory_gb for r in successful_results]
+            times = [r.total_time_sec for r in successful_results]
+
+            # Use median for more robust results (less sensitive to outliers)
+            final_result = BenchmarkResult(
                 name=config.name,
-                prompt_tokens=prompt_tokens,
-                generated_tokens=generation_tokens,
-                prefill_tokens_per_sec=prompt_speed,
-                decode_tokens_per_sec=generation_speed,
-                total_tokens_per_sec=total_tokens_per_sec,
-                peak_memory_gb=peak_memory_gb,
-                total_time_sec=total_time,
+                prompt_tokens=int(np.median([r.prompt_tokens for r in successful_results])),
+                generated_tokens=int(np.median([r.generated_tokens for r in successful_results])),
+                prefill_tokens_per_sec=float(np.median(prefill_speeds)),
+                decode_tokens_per_sec=float(np.median(decode_speeds)),
+                total_tokens_per_sec=float(np.median([r.total_tokens_per_sec for r in successful_results])),
+                peak_memory_gb=float(np.median(memories)),
+                total_time_sec=float(np.median(times)),
                 prompt=config.prompt[:200] + "..." if len(config.prompt) > 200 else config.prompt,
-                generated_text=(
-                    generated_text.strip()[:200] + "..."
-                    if len(generated_text.strip()) > 200
-                    else generated_text.strip()
-                ),
+                generated_text=successful_results[0].generated_text,  # Use first result's text
             )

-            # Print results
-            print(f"\nResults:")
-            print(f"  Prompt tokens: {prompt_tokens}")
-            print(f"  Generated tokens: {generation_tokens}")
-            print(f"  Prefill speed: {prompt_speed:.2f} tokens/sec")
-            print(f"  Decode speed: {generation_speed:.2f} tokens/sec")
-            print(f"  Overall speed: {total_tokens_per_sec:.2f} tokens/sec")
-            print(f"  Peak memory: {peak_memory_gb:.3f} GB")
-            print(f"  Total time: {total_time:.2f} seconds")
-
-            return benchmark_result
+            # Print final results with statistics
+            print(f"\n📈 Final Results (median of {len(successful_results)} runs):")
+            print(f"   Prompt tokens: {final_result.prompt_tokens}")
+            print(f"   Generated tokens: {final_result.generated_tokens}")
+            print(f"   Prefill speed: {final_result.prefill_tokens_per_sec:.2f} tokens/sec")
+            print(f"   Decode speed: {final_result.decode_tokens_per_sec:.2f} tokens/sec (σ={np.std(decode_speeds):.2f})")
+            print(f"   Overall speed: {final_result.total_tokens_per_sec:.2f} tokens/sec")
+            print(f"   Peak memory: {final_result.peak_memory_gb:.3f} GB")
+            print(f"   Total time: {final_result.total_time_sec:.2f} seconds")
+
+            if len(decode_speeds) > 1:
+                print(f"   Performance consistency: {np.std(decode_speeds)/np.mean(decode_speeds)*100:.1f}% CV")
+
+            return final_result

         finally:
             # Clean up
             if os.path.exists(prompt_file):
                 os.unlink(prompt_file)

+    def _parse_benchmark_output(
+        self, stdout: str, config: BenchmarkConfig, total_time: float
+    ) -> Optional[BenchmarkResult]:
+        """Parse mlx-lm output to extract performance metrics"""
+        output_lines = stdout.strip().split("\n")
+
+        # Find the generated text (between ========== markers)
+        generated_text = ""
+        in_generation = False
+        prompt_tokens = 0
+        generation_tokens = 0
+        prompt_speed = 0.0
+        generation_speed = 0.0
+        peak_memory_str = ""
+
+        for line in output_lines:
+            if line.strip() == "==========":
+                in_generation = not in_generation
+            elif in_generation:
+                generated_text += line + "\n"
+            elif "Prompt:" in line and "tokens-per-sec" in line:
+                # Parse: "Prompt: 13 tokens, 310.367 tokens-per-sec"
+                parts = line.split(",")
+                prompt_tokens = int(parts[0].split(":")[1].strip().split()[0])
+                prompt_speed = float(parts[1].strip().split()[0])
+            elif "Generation:" in line and "tokens-per-sec" in line:
+                # Parse: "Generation: 468 tokens, 69.860 tokens-per-sec"
+                parts = line.split(",")
+                generation_tokens = int(parts[0].split(":")[1].strip().split()[0])
+                generation_speed = float(parts[1].strip().split()[0])
+            elif "Peak memory:" in line:
+                peak_memory_str = line.split(":")[1].strip()
+
+        # Parse peak memory
+        peak_memory_gb = 0.0
+        if peak_memory_str:
+            if "GB" in peak_memory_str:
+                peak_memory_gb = float(peak_memory_str.replace("GB", "").strip())
+            elif "MB" in peak_memory_str:
+                peak_memory_gb = float(peak_memory_str.replace("MB", "").strip()) / 1024
+
+        # Validate we got meaningful results
+        if generation_tokens == 0 or generation_speed == 0:
+            return None
+
+        # Calculate overall tokens per second
+        total_tokens_per_sec = generation_tokens / total_time if total_time > 0 else 0
+
+        return BenchmarkResult(
+            name=config.name,
+            prompt_tokens=prompt_tokens,
+            generated_tokens=generation_tokens,
+            prefill_tokens_per_sec=prompt_speed,
+            decode_tokens_per_sec=generation_speed,
+            total_tokens_per_sec=total_tokens_per_sec,
+            peak_memory_gb=peak_memory_gb,
+            total_time_sec=total_time,
+            prompt=config.prompt[:200] + "..." if len(config.prompt) > 200 else config.prompt,
+            generated_text=(
+                generated_text.strip()[:200] + "..."
+                if len(generated_text.strip()) > 200
+                else generated_text.strip()
+            ),
+        )
+
     def run_full_benchmark_suite(self) -> Dict:
         """Run the complete benchmark suite"""
         print(f"\n{'='*80}")
