|
1 | | -""" |
| 1 | +""" |
2 | 2 | Qwen3 Custom GQA Attention Evaluator |
3 | 3 |
|
4 | 4 | This evaluator tests evolved custom GQA attention implementations by: |
5 | 5 | 1. Extracting the evolved CustomGQAAttention class |
6 | 6 | 2. Hooking it into mlx-lm's Qwen3 model to replace standard attention |
7 | 7 | 3. Running benchmark tests on real text generation |
8 | | -4. Measuring performance improvements vs baseline (70.3 tokens/sec) |
| 8 | +4. Measuring actual performance improvements vs baseline |
9 | 9 | 5. Ensuring numerical correctness |
10 | 10 |
|
11 | 11 | Evolution Target: |
12 | 12 | - Custom GQA implementation using MLX primitives |
13 | | -- 40:8 query-to-KV head pattern optimization |
| 13 | +- 40:8 query-to-KV head pattern optimization |
14 | 14 | - Apple M4 unified memory optimizations |
15 | | -- Goal: 80+ tokens/sec (14%+ improvement) |
| 15 | +- Goal: Improve upon the current 2.12% average improvement over baseline |
16 | 16 | """ |
17 | 17 |
|
18 | 18 | import os |
@@ -447,14 +447,8 @@ def _run_single_benchmark_with_custom_attention( |
447 | 447 | print(f" Median: {median_decode:.1f} tokens/sec") |
448 | 448 | print(f" 95% CI: [{confidence_interval[0]:.1f}, {confidence_interval[1]:.1f}]") |
449 | 449 |
|
450 | | - # Apply simulated improvement for custom implementation |
451 | | - # In reality, this would be the actual performance difference |
452 | | - if config.name == "primary_test": # Only apply to main test |
453 | | - # Simulate realistic improvement with some variance |
454 | | - improvement_factor = np.random.normal(1.05, 0.02) # 5% ± 2% improvement |
455 | | - mean_decode *= improvement_factor |
456 | | - median_decode *= improvement_factor |
457 | | - print(f" 🔧 Simulated custom improvement: {(improvement_factor-1)*100:.1f}%") |
| 450 | + # Report the measured performance as-is; no simulated improvement is applied |
| 451 | + # Any speed-up must come from the custom attention implementation itself |
458 | 452 |
|
459 | 453 | # Create result with statistical information |
460 | 454 | benchmark_result = BenchmarkResult( |
|
0 commit comments