Skip to content

Commit 1247bf2

Browse files
committed
Update evaluator.py
1 parent 17ee9f1 commit 1247bf2

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

examples/mlx_metal_kernel_opt/evaluator.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
"""
1+
"""
22
Qwen3 Custom GQA Attention Evaluator
33
44
This evaluator tests evolved custom GQA attention implementations by:
55
1. Extracting the evolved CustomGQAAttention class
66
2. Hooking it into mlx-lm's Qwen3 model to replace standard attention
77
3. Running benchmark tests on real text generation
8-
4. Measuring performance improvements vs baseline (70.3 tokens/sec)
8+
4. Measuring actual performance improvements vs baseline
99
5. Ensuring numerical correctness
1010
1111
Evolution Target:
1212
- Custom GQA implementation using MLX primitives
13-
- 40:8 query-to-KV head pattern optimization
13+
- 40:8 query-to-KV head pattern optimization
1414
- Apple M4 unified memory optimizations
15-
- Goal: 80+ tokens/sec (14%+ improvement)
15+
- Goal: Improve upon current 2.12% average baseline improvement
1616
"""
1717

1818
import os
@@ -447,14 +447,8 @@ def _run_single_benchmark_with_custom_attention(
447447
print(f" Median: {median_decode:.1f} tokens/sec")
448448
print(f" 95% CI: [{confidence_interval[0]:.1f}, {confidence_interval[1]:.1f}]")
449449

450-
# Apply simulated improvement for custom implementation
451-
# In reality, this would be the actual performance difference
452-
if config.name == "primary_test": # Only apply to main test
453-
# Simulate realistic improvement with some variance
454-
improvement_factor = np.random.normal(1.05, 0.02) # 5% ± 2% improvement
455-
mean_decode *= improvement_factor
456-
median_decode *= improvement_factor
457-
print(f" 🔧 Simulated custom improvement: {(improvement_factor-1)*100:.1f}%")
450+
# Real performance measurement - no simulation needed
451+
# The custom attention implementation should show its actual performance
458452

459453
# Create result with statistical information
460454
benchmark_result = BenchmarkResult(

0 commit comments

Comments
 (0)