# MLX SPDA Custom Metal Kernel Optimization - OpenEvolve Example

This example demonstrates using OpenEvolve to optimize MLX's Scaled Dot Product Attention (SPDA) using **custom Metal kernels**, similar to the kernel optimization work described in the AlphaEvolve paper. Our goal is to evolve custom Metal GPU kernels that **beat `mx.fast.scaled_dot_product_attention`** by leveraging MLX's `mx.fast.metal_kernel()` API for direct Metal C++ programming.

## Overview

### The Challenge

Modern transformer models spend most of their compute time in attention operations. Apple's MLX framework provides `mx.fast.scaled_dot_product_attention`, a highly optimized implementation that leverages Apple Silicon's unified memory and compute units. However, the AlphaEvolve paper showed that even highly optimized kernels can be improved through automated discovery.

**Our Goal**: Use OpenEvolve to discover custom Metal GPU kernels that outperform `mx.fast.scaled_dot_product_attention` by writing high-performance Metal C++ code using MLX's `mx.fast.metal_kernel()` API.

### Why This Matters

- **Real Impact**: Attention speedups directly improve transformer inference/training speed
- **Apple Silicon Optimization**: Discover patterns optimized for unified memory and ARM architecture
- **Algorithmic Discovery**: Find novel attention patterns beyond standard implementations
- **Reproducible AlphaEvolve**: Demonstrate the paper's kernel optimization approach on an open platform

## What Gets Optimized

The evolution process optimizes custom Metal GPU kernels in the `evolved_scaled_dot_product_attention` function using MLX's `mx.fast.metal_kernel()` API:

```python
# EVOLVE-BLOCK-START
# This is what gets evolved - custom Metal C++ kernels
source = """
template <typename T>
[[kernel]] void fused_attention_kernel(
    const device T* q [[buffer(0)]],
    const device T* k [[buffer(1)]],
    const device T* v [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint3 thread_position_in_grid [[thread_position_in_grid]]
) {
    // Custom optimized attention computation
    // Fuse QK^T, scaling, masking, softmax, and final matmul
    // Optimize memory access patterns for Apple Silicon
    // Use threadgroup memory and vectorization
}
"""
kernel = mx.fast.metal_kernel(name="attention", source=source, ...)
out = kernel(inputs=[q, k, v], ...)
# EVOLVE-BLOCK-END
```
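
For reference, here is a minimal sketch of the `mx.fast.metal_kernel()` calling convention, following the pattern in MLX's custom Metal kernel documentation. The kernel body below is a trivial templated copy, not an attention kernel; note that the `source` string holds only the kernel *body*, with the `[[kernel]]` signature generated by MLX from `input_names`/`output_names`:

```python
import mlx.core as mx

# Kernel body only; MLX generates the signature from input_names/output_names.
source = """
    uint i = thread_position_in_grid.x;
    T val = q[i];
    out[i] = val;  // a real evolved kernel would fuse the attention math here
"""

kernel = mx.fast.metal_kernel(
    name="toy_copy",
    input_names=["q"],
    output_names=["out"],
    source=source,
)

q = mx.random.normal((1, 8, 128, 64)).astype(mx.float16)
outputs = kernel(
    inputs=[q],
    template=[("T", mx.float16)],   # exposed as `typename T` inside the kernel
    grid=(q.size, 1, 1),            # one thread per element
    threadgroup=(256, 1, 1),
    output_shapes=[q.shape],
    output_dtypes=[q.dtype],
)
out = outputs[0]  # same shape/dtype as q
```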

**Available Metal C++ Techniques**:
- **Kernel Fusion**: Combine QK^T + scale + mask + softmax + output in a single kernel
- **Memory Optimization**: Coalesced reads, vectorized operations (float4, half4)
- **Threadgroup Memory**: Shared memory for cache optimization
- **Template Programming**: Type specialization for float16/float32
- **SIMD Operations**: Metal's built-in vectorization capabilities
- **Atomic Operations**: For complex reductions and synchronized updates
- **Tiled Computation**: Cache-friendly access patterns for large sequences

**Optimization Targets**:
- Direct Metal C++ GPU kernel programming
- Fused attention operations for reduced memory bandwidth
- Apple Silicon unified memory exploitation
- Threadgroup dispatch and synchronization optimization

**Forbidden Operations**:
- `mx.fast.*` functions (that's what we're trying to beat!)
- Implementations that use only basic MLX operations without a custom kernel

## Benchmark Framework

We use the provided `spda_benchmark.py`, which tests across:

- **Sequence lengths**: 32 to 4096 tokens
- **Head dimensions**: 64, 80, 128
- **Grouped Query Attention (GQA)**: Various num_kv_heads ratios
- **Mask types**: None, boolean, causal
- **Multiple configurations**: Standard and transposed layouts

The benchmark measures both **correctness** (against a reference implementation) and **performance** (against the fused implementation).
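
For intuition, the reference is just standard scaled dot product attention expressed with basic MLX operations. A minimal sketch follows; the shapes, mask handling, and GQA head-repetition here are assumptions, so see `evaluator.py` / `spda_benchmark.py` for the exact reference:

```python
import mlx.core as mx

def reference_attention(q, k, v, scale, mask=None):
    # q: (B, n_q_heads, L_q, D); k, v: (B, n_kv_heads, L_k, D)
    # GQA: repeat K/V heads so they match the number of query heads.
    n_rep = q.shape[1] // k.shape[1]
    if n_rep > 1:
        k = mx.repeat(k, n_rep, axis=1)
        v = mx.repeat(v, n_rep, axis=1)
    scores = (q * scale) @ k.transpose(0, 1, 3, 2)
    if mask == "causal":
        L_q, L_k = scores.shape[-2], scores.shape[-1]
        allowed = mx.tril(mx.ones((L_q, L_k), dtype=mx.bool_), k=L_k - L_q)
        scores = mx.where(allowed, scores, mx.array(-float("inf"), dtype=scores.dtype))
    elif mask is not None:
        scores = scores + mask  # additive mask (e.g. boolean masks converted to 0 / -inf)
    return mx.softmax(scores, axis=-1) @ v
```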

## Expected Custom Metal Kernel Optimizations

OpenEvolve might discover:

### High-Performance Metal Kernels
- **Fused Attention Kernels**: Single kernel combining QK^T, scale, mask, softmax, and output
- **Tiled Computation**: Process attention in cache-friendly tiles using threadgroup memory
- **Vectorized Operations**: Use Metal's float4/half4 vector types for maximum throughput
- **Memory Coalescing**: Optimize memory access patterns for the Apple Silicon GPU

### Apple Silicon GPU Optimizations
- **Threadgroup Strategies**: Optimal thread dispatch and synchronization patterns
- **Unified Memory Exploitation**: Leverage zero-copy sharing between CPU and GPU
- **SIMD Utilization**: Maximum use of Apple Silicon's SIMD capabilities
- **Cache Optimization**: Metal-specific cache hierarchy utilization

### Specialized Kernel Variants
- **GQA-Optimized Kernels**: Custom kernels for grouped query attention patterns
- **Causal Mask Kernels**: Triangular computation patterns for autoregressive models
- **Sequence-Length Specialization**: Different kernels optimized for different sizes
- **Mixed Precision Kernels**: Automatic float16/float32 optimization

## Usage

### Prerequisites

```bash
# Install requirements
pip install mlx numpy pyyaml psutil

# Set up API key for LLM access (example for Gemini)
export OPENAI_API_KEY="your-api-key"  # Or appropriate API key
```
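
As a quick sanity check that MLX is installed and sees the GPU (assumes an Apple Silicon machine):

```python
import mlx.core as mx

print(mx.__version__)       # installed MLX version
print(mx.default_device())  # should report a GPU device on Apple Silicon
```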

### Basic Evolution

```bash
cd examples/mlx_spda_optimization

# Run the evolution process
python ../../../openevolve-run.py initial_program.py evaluator.py --config config.yaml --iterations 150
```

### Test Initial Implementation

```bash
# Test that the initial program works
python initial_program.py

# Run evaluator on initial program
python evaluator.py
```

### Test Evolved Results

After evolution completes, test the best program against the full benchmark:

```bash
# Quick test on subset of configurations
python test_evolved.py openevolve_output/best/best_program.py --subset

# Full benchmark suite (takes longer)
python test_evolved.py openevolve_output/best/best_program.py

# Save results to file
python test_evolved.py openevolve_output/best/best_program.py --output results.txt
```

## Configuration Details

The `config.yaml` is tuned for kernel optimization:

```yaml
evolution:
  max_iterations: 150      # More iterations for complex optimization
  population_size: 80      # Large population for diverse exploration

llm:
  primary_model: "gemini-2.0-flash"   # Fast model for bulk generation
  secondary_model: "gemini-2.0-pro"   # Stronger model for difficult cases
  temperature: 0.9                    # Higher temp for creative optimization

evaluation:
  strategy: "cascade"                 # Quick filter + thorough evaluation
```

## Expected Results

Based on AlphaEvolve's results (a 23% speedup on a Gemini kernel), we target:

### Success Metrics
- **15-30% speedup** over `mx.fast.scaled_dot_product_attention`
- **High accuracy** (>99% numerical agreement with the reference)
- **Robustness** across different configurations (GQA, masks, sizes)
- **Consistent gains** across most benchmark configurations

### Realistic Outcomes
- **Moderate success**: 10-20% average speedup on some configurations
- **Specialized optimizations**: Large gains on specific patterns (e.g., long sequences)
- **Novel approaches**: Discovery of new attention variants
- **Negative results**: Learning what doesn't work is also valuable!

## Example Output

When successful, you'll see results like:

```
Running benchmark with evolved attention vs fused attention...
1, 128, 128, 64, 16, 16, 0, float16, None, 0.045, 0.052, -13.46% (speedup: 1.16x)
1, 256, 256, 64, 16, 16, 0, float16, causal, 0.089, 0.108, -17.59% (speedup: 1.21x)
1, 512, 512, 64, 32, 8, 0, float16, None, 0.178, 0.205, -13.17% (speedup: 1.15x)

Benchmark Summary:
Average speedup: 1.18x
Tests with speedup > 1.1x: 78%
🎉 SUCCESS: Evolved attention achieves 1.18x average speedup!
```

In each row, the two timings are the evolved and fused implementations respectively; the final figures give the relative time difference, (t_evolved - t_fused) / t_fused, and the corresponding speedup, t_fused / t_evolved.

## Comparison to AlphaEvolve

| Aspect | AlphaEvolve (Gemini/TPU) | This Example (MLX/Apple Silicon) |
|--------|--------------------------|----------------------------------|
| **Target** | Pallas kernel optimization | Custom Metal kernel optimization |
| **Platform** | TPU (specialized) | Apple Silicon (unified memory) |
| **Result** | 23% speedup | Target: 15-30% speedup |
| **Impact** | 1% overall training time reduction | Direct attention speedup |
| **Constraints** | Pallas/XLA operations | Metal C++ kernel programming |
| **Method** | Evolution of tiling heuristics | Evolution of custom GPU kernels |

## Troubleshooting

### Common Issues

1. **Low accuracy scores**:
   - Check tensor shapes and masking logic
   - Verify GQA (grouped query attention) handling
   - Test with simple configurations first

2. **Performance regressions**:
   - Start with small sequence lengths
   - Profile memory usage patterns
   - Check for unnecessary operations

3. **Evolution not converging**:
   - Increase iterations or population size
   - Adjust temperature or mutation rate
   - Check that the evaluation pipeline works correctly

### Debugging

```bash
# Test specific components
python -c "from evaluator import evaluate_stage1; print(evaluate_stage1('initial_program.py'))"

# Run evaluation standalone
python evaluator.py

# Test basic functionality
python initial_program.py
```
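
To chase down a performance regression on a single small configuration, a standalone timing sketch can also help. The import path and call signature of `evolved_scaled_dot_product_attention` below are assumptions; adjust them to match `initial_program.py`:

```python
import time
import mlx.core as mx
from initial_program import evolved_scaled_dot_product_attention  # assumed import path

B, H, L, D = 1, 16, 256, 64
q = mx.random.normal((B, H, L, D)).astype(mx.float16)
k = mx.random.normal((B, H, L, D)).astype(mx.float16)
v = mx.random.normal((B, H, L, D)).astype(mx.float16)
scale = D ** -0.5

def bench(fn, iters=50):
    mx.eval(fn())  # warm-up (also triggers kernel compilation)
    start = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn())  # force evaluation each iteration (MLX is lazy)
    return (time.perf_counter() - start) / iters

t_evolved = bench(lambda: evolved_scaled_dot_product_attention(q, k, v, scale=scale))
t_fused = bench(lambda: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale))
print(f"evolved: {t_evolved * 1e3:.3f} ms   fused: {t_fused * 1e3:.3f} ms")
```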

## Advanced Usage

### Custom Test Configurations

Modify `create_test_configurations()` in `evaluator.py`:

```python
def create_test_configurations():
    return [
        # Add your custom test cases
        {"B": 1, "qsl": 2048, "ksl": 2048, "head_dim": 64,
         "n_q_heads": 32, "n_kv_heads": 8, "dtype": "float16", "mask": "causal"},
    ]
```

### Different Tolerance Levels

Adjust accuracy requirements in `compare_attention_outputs()`:

```python
comparison = compare_attention_outputs(evolved_output, reference_output, tolerance=1e-4)
```
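
For intuition, here is a generic sketch of the kind of check such a comparison performs; it is not the evaluator's actual `compare_attention_outputs` implementation:

```python
import mlx.core as mx

def max_errors(evolved_output, reference_output):
    # Compare in float32 to avoid half-precision artifacts in the check itself.
    a = evolved_output.astype(mx.float32)
    b = reference_output.astype(mx.float32)
    abs_err = mx.abs(a - b)
    rel_err = abs_err / (mx.abs(b) + 1e-6)
    return mx.max(abs_err).item(), mx.max(rel_err).item()
```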

### Integration with Real Models

The evolved attention can potentially be integrated into MLX-based transformer implementations by replacing the attention computation while keeping the same interface.
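
A hedged sketch of what that could look like in an MLX attention module; the module layout, the import path, and the evolved function's signature are assumptions, not this example's actual code:

```python
import mlx.core as mx
import mlx.nn as nn

# Assumed import: the evolved function from openevolve_output/best/best_program.py,
# made importable (e.g. copied next to your model code).
from best_program import evolved_scaled_dot_product_attention

class Attention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dims // num_heads
        self.q_proj = nn.Linear(dims, dims, bias=False)
        self.k_proj = nn.Linear(dims, dims, bias=False)
        self.v_proj = nn.Linear(dims, dims, bias=False)
        self.out_proj = nn.Linear(dims, dims, bias=False)

    def __call__(self, x, mask=None):
        B, L, _ = x.shape
        # Project and reshape to (B, num_heads, L, head_dim).
        q = self.q_proj(x).reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        # Drop-in replacement for mx.fast.scaled_dot_product_attention.
        out = evolved_scaled_dot_product_attention(
            q, k, v, scale=self.head_dim ** -0.5, mask=mask
        )
        return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
```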

## Scientific Value

This example demonstrates:

1. **Reproducible Research**: Open implementation of AlphaEvolve's kernel optimization approach
2. **Platform Exploration**: Understanding optimization opportunities on Apple Silicon
3. **Algorithmic Discovery**: Potential discovery of novel attention patterns
4. **Benchmarking Framework**: Systematic evaluation of attention implementations

Even negative results provide valuable insights into the limits of basic-operation optimization compared to low-level kernel optimization.

## Future Extensions

- **Mixed Precision**: Automatic precision optimization for accuracy/speed tradeoffs
- **KV Caching**: Optimize for inference patterns with key-value caching
- **Multi-Head Variants**: Explore different attention architectures
- **Cross-Platform**: Extend discoveries to other Apple Silicon variants

---

## Quick Start Summary

```bash
# 1. Install dependencies
pip install mlx numpy pyyaml psutil

# 2. Run evolution
cd examples/mlx_spda_optimization
python ../../../openevolve-run.py initial_program.py evaluator.py --config config.yaml

# 3. Test results
python test_evolved.py openevolve_output/best/best_program.py --subset
```

This example provides a complete framework for kernel optimization research using OpenEvolve, bringing the power of AlphaEvolve's approach to the open-source community.
