# MLX SPDA Custom Metal Kernel Optimization - OpenEvolve Example

This example demonstrates using OpenEvolve to optimize MLX's Scaled Dot Product Attention (SPDA) using **custom Metal kernels**, similar to the kernel optimization work described in the AlphaEvolve paper. Our goal is to evolve custom Metal GPU kernels that **beat `mx.fast.scaled_dot_product_attention`** by leveraging MLX's `mx.fast.metal_kernel()` API for direct Metal C++ programming.

## Overview

### The Challenge

Modern transformer models spend most of their compute time in attention operations. Apple's MLX framework provides `mx.fast.scaled_dot_product_attention` - a highly optimized implementation that leverages Apple Silicon's unified memory and compute units. However, the AlphaEvolve paper showed that even highly optimized kernels can be improved through automated discovery.

**Our Goal**: Use OpenEvolve to discover custom Metal GPU kernels that outperform `mx.fast.scaled_dot_product_attention` by writing high-performance Metal C++ code using MLX's `mx.fast.metal_kernel()` API.

### Why This Matters

- **Real Impact**: Attention speedups directly improve transformer inference/training speed
- **Apple Silicon Optimization**: Discover patterns optimized for unified memory and ARM architecture
- **Algorithmic Discovery**: Find novel attention patterns beyond standard implementations
- **Reproducible AlphaEvolve**: Demonstrate the paper's kernel optimization approach on an open platform

## What Gets Optimized

The evolution process optimizes custom Metal GPU kernels in the `evolved_scaled_dot_product_attention` function using MLX's `mx.fast.metal_kernel()` API:

```python
# EVOLVE-BLOCK-START
# This is what gets evolved - custom Metal C++ kernels
source = """
    template <typename T>
    [[kernel]] void fused_attention_kernel(
        const device T* q [[buffer(0)]],
        const device T* k [[buffer(1)]],
        const device T* v [[buffer(2)]],
        device T* out [[buffer(3)]],
        uint3 thread_position_in_grid [[thread_position_in_grid]]
    ) {
        // Custom optimized attention computation
        // Fuse QK^T, scaling, masking, softmax, and final matmul
        // Optimize memory access patterns for Apple Silicon
        // Use threadgroup memory and vectorization
    }
"""
kernel = mx.fast.metal_kernel(name="attention", source=source, ...)
out = kernel(inputs=[q, k, v], ...)
# EVOLVE-BLOCK-END
```
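
The snippet above is schematic. In practice, `mx.fast.metal_kernel()` generates the `[[kernel]]` signature for you from `input_names`/`output_names`, so the `source` string contains only the kernel body, and template parameters such as `T` are bound at call time. Below is a minimal, self-contained sketch of that calling pattern (assuming a recent MLX release; the kernel name, buffer names, and the trivial scaling operation are illustrative, not the evolved attention kernel):

```python
import mlx.core as mx

# Body-only source: MLX adds the [[kernel]] signature, typed buffers for
# q/scale/out, and the thread_position_in_grid attribute used in the body.
source = """
    uint elem = thread_position_in_grid.x;
    T val = static_cast<T>(q[elem]);
    out[elem] = val * static_cast<T>(scale[0]);
"""

kernel = mx.fast.metal_kernel(
    name="scale_q",
    input_names=["q", "scale"],
    output_names=["out"],
    source=source,
)

q = mx.random.normal(shape=(16, 128)).astype(mx.float16)
scale = mx.array([64 ** -0.5], dtype=mx.float16)

(out,) = kernel(
    inputs=[q, scale],
    template=[("T", mx.float16)],
    grid=(q.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes=[q.shape],
    output_dtypes=[q.dtype],
)
mx.eval(out)
```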

**Available Metal C++ Techniques** (a brief kernel-body fragment illustrating a few of these follows this list):
- **Kernel Fusion**: Combine QK^T + scale + mask + softmax + output in a single kernel
- **Memory Optimization**: Coalesced reads, vectorized operations (float4, half4)
- **Threadgroup Memory**: Shared memory for cache optimization
- **Template Programming**: Type specialization for float16/float32
- **SIMD Operations**: Metal's built-in vectorization capabilities
- **Atomic Operations**: For complex reductions and synchronized updates
- **Tiled Computation**: Cache-friendly access patterns for large sequences
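
To make the threadgroup-memory, vectorization, and SIMD items concrete, here is a hedged fragment of the kind of Metal C++ that could appear inside a kernel body (a sketch only, not dispatched here; the buffer `k` and the `partial_score` variable are assumed for illustration):

```python
# Kernel-body fragment, written as a Python string just like the evolve block.
tile_fragment = """
    // Stage a tile of K into fast threadgroup (shared) memory.
    threadgroup half4 k_tile[32];                 // 32 * 4 = 128 half values
    uint lane = thread_position_in_threadgroup.x;
    if (lane < 32) {
        // Vectorized load: one half4 read brings in four K values at once.
        k_tile[lane] = ((const device half4*)k)[lane];
    }
    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);

    // ... compute partial QK^T scores against the cached tile ...

    // Reduce partial sums across the SIMD group without touching shared memory.
    float score = metal::simd_sum(partial_score);
"""
```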

**Optimization Targets**:
- Direct Metal C++ GPU kernel programming
- Fused attention operations for reduced memory bandwidth
- Apple Silicon unified memory exploitation
- Threadgroup dispatch and synchronization optimization

**Forbidden Operations**:
- Pre-fused `mx.fast.*` convenience ops such as `mx.fast.scaled_dot_product_attention` (that's what we're trying to beat!); `mx.fast.metal_kernel()` itself is allowed, since it is how the custom kernels are written
- Relying only on basic MLX operations without writing a custom kernel

## Benchmark Framework

We use the provided `spda_benchmark.py` which tests across:

- **Sequence lengths**: 32 to 4096 tokens
- **Head dimensions**: 64, 80, 128
- **Grouped Query Attention (GQA)**: Various num_kv_heads ratios
- **Mask types**: None, boolean, causal
- **Multiple configurations**: Standard and transpose layouts

The benchmark measures both **correctness** (numerical agreement with a reference implementation) and **performance** (wall-clock time versus the fused `mx.fast.scaled_dot_product_attention`).
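
The sketch below shows, in simplified form, the kind of check this involves (it is not the actual code in `spda_benchmark.py` or `evaluator.py`; the reference implementation and timing loop here are illustrative):

```python
import time
import mlx.core as mx

def reference_attention(q, k, v, scale, mask=None):
    # Plain-MLX reference used only for correctness checking.
    scores = (q * scale) @ k.transpose(0, 1, 3, 2)
    if mask is not None:
        scores = scores + mask
    return mx.softmax(scores, axis=-1) @ v

def time_it(fn, repeats=20):
    mx.eval(fn())  # warm-up; MLX evaluates lazily
    tic = time.perf_counter()
    for _ in range(repeats):
        mx.eval(fn())
    return (time.perf_counter() - tic) / repeats

B, H, L, D = 1, 16, 512, 64
q, k, v = (mx.random.normal(shape=(B, H, L, D)).astype(mx.float16) for _ in range(3))
scale = D ** -0.5

ref = reference_attention(q, k, v, scale)
fused = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
print("max abs error vs reference:", mx.abs(ref - fused).max().item())
print("fused time per call (s):",
      time_it(lambda: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)))
```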

## Expected Custom Metal Kernel Optimizations

OpenEvolve might discover:

### High-Performance Metal Kernels
- **Fused Attention Kernels**: Single kernel combining QK^T, scale, mask, softmax, and output
- **Tiled Computation**: Process attention in cache-friendly tiles using threadgroup memory
- **Vectorized Operations**: Use Metal's float4/half4 vector types for maximum throughput
- **Memory Coalescing**: Optimize memory access patterns for Apple Silicon GPU

### Apple Silicon GPU Optimizations
- **Threadgroup Strategies**: Optimal thread dispatch and synchronization patterns
- **Unified Memory Exploitation**: Leverage zero-copy between CPU and GPU
- **SIMD Utilization**: Maximum use of Apple Silicon's SIMD capabilities
- **Cache Optimization**: Metal-specific cache hierarchy utilization

### Specialized Kernel Variants
- **GQA-Optimized Kernels**: Custom kernels for grouped query attention patterns
- **Causal Mask Kernels**: Triangular computation patterns for autoregressive models (see the fragment after this list)
- **Sequence-Length Specialization**: Different kernels optimized for different sizes
- **Mixed Precision Kernels**: Automatic float16/float32 optimization
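
For the causal case, the triangular pattern typically amounts to skipping (or masking to negative infinity) every key position beyond the current query position. A hedged kernel-body fragment with illustrative names (`q_pos`, `kv_pos`, `seq_len`, and `compute_score` are not real identifiers in this example's code):

```python
# Kernel-body fragment only, not dispatched here.
causal_fragment = """
    for (uint kv_pos = 0; kv_pos < seq_len; ++kv_pos) {
        if (kv_pos > q_pos) {
            break;  // upper triangle: no need to compute or mask explicitly
        }
        float score = compute_score(q_pos, kv_pos) * scale;
        // ... accumulate score into the running softmax / output ...
    }
"""
```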

## Usage

### Prerequisites

```bash
# Install requirements
pip install mlx numpy pyyaml psutil

# Set up API key for LLM access (example for Gemini)
export OPENAI_API_KEY="your-api-key" # Or appropriate API key
```

### Basic Evolution

```bash
cd examples/mlx_spda_optimization

# Run the evolution process
python ../../../openevolve-run.py initial_program.py evaluator.py --config config.yaml --iterations 150
```

### Test Initial Implementation

```bash
# Test that the initial program works
python initial_program.py

# Run evaluator on initial program
python evaluator.py
```

### Test Evolved Results

After evolution completes, test the best program against the full benchmark:

```bash
# Quick test on subset of configurations
python test_evolved.py openevolve_output/best/best_program.py --subset

# Full benchmark suite (takes longer)
python test_evolved.py openevolve_output/best/best_program.py

# Save results to file
python test_evolved.py openevolve_output/best/best_program.py --output results.txt
```

## Configuration Details

The `config.yaml` is tuned for kernel optimization:

```yaml
evolution:
  max_iterations: 150 # More iterations for complex optimization
  population_size: 80 # Large population for diverse exploration

llm:
  primary_model: "gemini-2.0-flash" # Fast model for bulk generation
  secondary_model: "gemini-2.0-pro" # Stronger model for difficult cases
  temperature: 0.9 # Higher temp for creative optimization

evaluation:
  strategy: "cascade" # Quick filter + thorough evaluation
```

## Expected Results

Based on AlphaEvolve's results (23% Gemini kernel speedup), we target:

### Success Metrics
- **15-30% speedup** over `mx.fast.scaled_dot_product_attention`
- **High accuracy** (>99% numerical agreement with reference)
- **Robustness** across different configurations (GQA, masks, sizes)
- **Consistent gains** across most benchmark configurations

### Realistic Outcomes
- **Moderate success**: 10-20% average speedup on some configurations
- **Specialized optimizations**: Large gains on specific patterns (e.g., long sequences)
- **Novel approaches**: Discovery of new attention variants
- **Negative results**: Learning what doesn't work is also valuable!

## Example Output

When successful, you'll see results like:

```
Running benchmark with evolved attention vs fused attention...
  1, 128, 128, 64, 16, 16, 0, float16, None, 0.045, 0.052, -13.46% (speedup: 1.16x)
  1, 256, 256, 64, 16, 16, 0, float16, causal, 0.089, 0.108, -17.59% (speedup: 1.21x)
  1, 512, 512, 64, 32, 8, 0, float16, None, 0.178, 0.205, -13.17% (speedup: 1.15x)

Benchmark Summary:
  Average speedup: 1.18x
  Tests with speedup > 1.1x: 78%
  🎉 SUCCESS: Evolved attention achieves 1.18x average speedup!
```

## Comparison to AlphaEvolve

| Aspect | AlphaEvolve (Gemini/TPU) | This Example (MLX/Apple Silicon) |
|--------|--------------------------|-----------------------------------|
| **Target** | Pallas kernel optimization | Custom Metal kernel optimization |
| **Platform** | TPU (specialized) | Apple Silicon (unified memory) |
| **Result** | 23% speedup | Target: 15-30% speedup |
| **Impact** | 1% overall training time reduction | Direct attention speedup |
| **Constraints** | Pallas/XLA operations | Metal C++ kernel programming |
| **Method** | Evolution of tiling heuristics | Evolution of custom GPU kernels |

## Troubleshooting

### Common Issues

1. **Low accuracy scores**:
   - Check tensor shapes and masking logic
   - Verify GQA (grouped query attention) handling
   - Test with simple configurations first

2. **Performance regressions**:
   - Start with small sequence lengths
   - Profile memory usage patterns
   - Check for unnecessary operations

3. **Evolution not converging**:
   - Increase iterations or population size
   - Adjust temperature or mutation rate
   - Check that the evaluation pipeline works correctly

### Debugging

```bash
# Test specific components
python -c "from evaluator import evaluate_stage1; print(evaluate_stage1('initial_program.py'))"

# Run evaluation standalone
python evaluator.py

# Test basic functionality
python initial_program.py
```

## Advanced Usage

### Custom Test Configurations

Modify `create_test_configurations()` in `evaluator.py`:

```python
def create_test_configurations():
    return [
        # Add your custom test cases
        {"B": 1, "qsl": 2048, "ksl": 2048, "head_dim": 64,
         "n_q_heads": 32, "n_kv_heads": 8, "dtype": "float16", "mask": "causal"},
    ]
```

### Different Tolerance Levels

Adjust accuracy requirements in `compare_attention_outputs()`:

```python
comparison = compare_attention_outputs(evolved_output, reference_output, tolerance=1e-4)
```

### Integration with Real Models

The evolved attention can potentially be integrated into MLX-based transformer implementations by replacing the attention computation while keeping the same interface.
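
A hedged sketch of what that swap could look like inside a standard multi-head attention module (the module, the import path, and the evolved function's exact signature are assumptions for illustration, not part of this example's code):

```python
import mlx.nn as nn

# Hypothetical import: wherever the evolved kernel function ends up living.
from best_program import evolved_scaled_dot_product_attention

class EvolvedAttention(nn.Module):
    """Standard multi-head attention with the evolved kernel swapped in."""

    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.q_proj = nn.Linear(dims, dims, bias=False)
        self.k_proj = nn.Linear(dims, dims, bias=False)
        self.v_proj = nn.Linear(dims, dims, bias=False)
        self.out_proj = nn.Linear(dims, dims, bias=False)

    def __call__(self, x, mask=None):
        B, L, D = x.shape
        H = self.num_heads
        # Project and reshape to (B, H, L, head_dim).
        q = self.q_proj(x).reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        scale = (D // H) ** -0.5
        # The swap: evolved kernel instead of mx.fast.scaled_dot_product_attention,
        # assuming it keeps the same (q, k, v, scale=..., mask=...) interface.
        out = evolved_scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
        return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, D))
```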

## Scientific Value

This example demonstrates:

1. **Reproducible Research**: Open implementation of AlphaEvolve's kernel optimization approach
2. **Platform Exploration**: Understanding optimization opportunities on Apple Silicon
3. **Algorithmic Discovery**: Potential discovery of novel attention patterns
4. **Benchmarking Framework**: Systematic evaluation of attention implementations

Even negative results provide valuable insights into the limits of basic-operation optimization compared to low-level kernel optimization.

## Future Extensions

- **Mixed Precision**: Automatic precision optimization for accuracy/speed tradeoffs
- **KV Caching**: Optimize for inference patterns with key-value caching
- **Multi-Head Variants**: Explore different attention architectures
- **Cross-Platform**: Extend discoveries to other Apple Silicon variants

---

## Quick Start Summary

```bash
# 1. Install dependencies
pip install mlx numpy pyyaml psutil

# 2. Run evolution
cd examples/mlx_spda_optimization
python ../../../openevolve-run.py initial_program.py evaluator.py --config config.yaml

# 3. Test results
python test_evolved.py openevolve_output/best/best_program.py --subset
```

This example provides a complete framework for kernel optimization research using OpenEvolve, bringing the power of AlphaEvolve's approach to the open-source community.