# MLX Attention Optimization

This example demonstrates using OpenEvolve to optimize attention mechanisms for Apple Silicon, similar to the Gemini kernel optimization described in the AlphaEvolve paper.

## Overview

The goal is to evolve the core attention computation in MLX (Apple's ML framework) to achieve better performance while maintaining numerical accuracy. This example focuses on optimizing the scaled dot-product attention mechanism that forms the heart of transformer models.

## What Gets Optimized

The example evolves the core attention computation within the `OptimizedAttention` class:

```python
# EVOLVE-BLOCK-START
# This section contains the attention computation that gets evolved
scores = mx.matmul(queries, keys.transpose(0, 1, 3, 2))
scores = scores * self.scale
if mask is not None:
    scores = scores + mask
attn_weights = mx.softmax(scores, axis=-1)
output = mx.matmul(attn_weights, values)
# EVOLVE-BLOCK-END
```

**What remains fixed:**
- Query, Key, Value projections
- RMSNorm layers
- RoPE (Rotary Position Embedding)
- Output projection
- Input/output shapes and interfaces

**What can evolve:**
- Attention computation patterns (chunked, sparse, etc.; see the sketch below)
- Memory access strategies
- Optimized implementations for Apple Silicon
- Alternative attention mechanisms
- Memory tiling strategies
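
For a concrete picture of what "attention computation patterns" can mean here, below is a minimal sketch of a chunked variant that a candidate program might substitute into the evolve block. It is illustrative only, not an actual evolution result: the function name and `chunk_size` parameter are ours, and the tensor layout follows the `[batch, num_heads, seq_len, head_dim]` shapes used by the evolve block.

```python
import mlx.core as mx

def chunked_attention(queries, keys, values, scale, mask=None, chunk_size=256):
    """Compute softmax(Q K^T * scale + mask) V one block of query rows at a time."""
    seq_len = queries.shape[2]
    outputs = []
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        q_chunk = queries[:, :, start:end, :]
        # Scores for this chunk of queries against every key.
        scores = mx.matmul(q_chunk, keys.transpose(0, 1, 3, 2)) * scale
        if mask is not None:
            scores = scores + mask[..., start:end, :]
        weights = mx.softmax(scores, axis=-1)
        outputs.append(mx.matmul(weights, values))
    return mx.concatenate(outputs, axis=2)
```

Chunking bounds the size of the intermediate score matrix, which is the main memory consumer in naive attention.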

## Key Features

### Comprehensive Evaluation

The evaluator tests multiple aspects:

1. **Numerical Accuracy**: Compares outputs with the reference implementation using MLX-LM's `scaled_dot_product_attention`
2. **Performance**: Measures throughput (tokens/second) and compares it with the reference
3. **Memory Efficiency**: Tracks memory usage during computation
4. **Stability**: Tests edge cases (small/large values, different input sizes)
5. **Robustness**: Tests across different configurations (batch sizes, sequence lengths, GQA)
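
The exact scoring lives in `evaluator.py`; conceptually, the numerical-accuracy check reduces to comparing a candidate's output against MLX's fused attention kernel on random inputs, roughly as in this sketch (the function name, candidate signature, shapes, and tolerance are illustrative, not the evaluator's actual code):

```python
import mlx.core as mx

def accuracy_check(evolved_attention, batch=1, num_heads=8, seq_len=128,
                   head_dim=64, tolerance=1e-3):
    """Compare an evolved attention function against MLX's reference kernel."""
    scale = head_dim ** -0.5
    q = mx.random.normal((batch, num_heads, seq_len, head_dim))
    k = mx.random.normal((batch, num_heads, seq_len, head_dim))
    v = mx.random.normal((batch, num_heads, seq_len, head_dim))

    reference = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
    candidate = evolved_attention(q, k, v, scale)  # assumed candidate signature

    max_err = mx.max(mx.abs(candidate - reference)).item()
    return {"max_abs_error": max_err, "passes": max_err < tolerance}
```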

### Test Cases

The evaluator covers diverse scenarios:

- Different sequence lengths (64 to 2048 tokens)
- Various model sizes (256 to 1024 hidden dimensions)
- Grouped Query Attention (GQA) with different `num_kv_heads` (see the sketch below)
- Multiple batch sizes
- Edge cases for numerical stability
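
To make the GQA cases concrete, here is an illustrative configuration (sizes and variable names are ours, not the evaluator's): queries carry more heads than the keys and values, so a naive kernel must repeat the KV heads before the score matmul.

```python
import mlx.core as mx

# One GQA-style configuration: 16 query heads sharing 4 KV heads.
batch, seq_len, num_heads, num_kv_heads, head_dim = 1, 512, 16, 4, 64

q = mx.random.normal((batch, num_heads, seq_len, head_dim))
k = mx.random.normal((batch, num_kv_heads, seq_len, head_dim))
v = mx.random.normal((batch, num_kv_heads, seq_len, head_dim))

# A straightforward approach repeats each KV head so the score matmul
# sees matching head counts; smarter evolved kernels may avoid this copy.
repeat = num_heads // num_kv_heads
k_full = mx.repeat(k, repeat, axis=1)
v_full = mx.repeat(v, repeat, axis=1)

scores = mx.matmul(q, k_full.transpose(0, 1, 3, 2)) * head_dim ** -0.5
out = mx.matmul(mx.softmax(scores, axis=-1), v_full)
print(out.shape)  # -> batch, num_heads, seq_len, head_dim
```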

### Apple Silicon Optimization Opportunities

The evolution process can discover optimizations specific to Apple Silicon:

- Leveraging the unified memory architecture
- Cache-friendly memory access patterns
- Vectorized operations optimized for ARM
- Efficient use of Apple's matrix units (AMX)

## Running the Example

### Prerequisites

```bash
pip install -r requirements.txt
# Or manually:
pip install mlx mlx-lm psutil numpy pyyaml
export OPENAI_API_KEY="your-api-key"  # For Gemini models
```

### Basic Usage

```bash
cd examples/mlx_attention_optimization
python ../../openevolve-run.py initial_program.py evaluator.py --config config.yaml --iterations 200
```

### Testing Initial Implementation

```bash
python initial_program.py  # Test basic functionality
python evaluator.py        # Run full evaluation
```

## Configuration

The example uses stronger LLM models (Gemini 2.0 Flash/Pro) given the complexity of attention optimization:

```yaml
llm:
  primary_model: "gemini-2.0-flash"
  secondary_model: "gemini-2.0-pro"
  temperature: 0.8
  max_tokens: 8192
```

Key configuration choices:

- **200 iterations**: More iterations for complex optimization
- **Cascade evaluation**: Quick accuracy check before expensive performance tests
- **Larger population**: 100 programs to explore diverse optimization strategies
- **Higher temperature**: More creative exploration for novel optimizations

## Expected Optimizations

OpenEvolve might discover:

### Memory Optimizations
- **Chunked Attention**: Process attention in memory-efficient chunks
- **Tiled Computation**: Optimize memory access patterns for Apple Silicon
- **Unified Memory Exploitation**: Leverage shared CPU/GPU memory

### Algorithmic Improvements
- **Sparse Attention**: Skip computation for irrelevant token pairs
- **Local Attention**: Focus on nearby tokens for efficiency (see the mask sketch below)
- **Fused Operations**: Combine multiple operations to reduce memory bandwidth
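
As a sketch of the local-attention idea (the window size and helper name are ours), an additive band mask can restrict each query to nearby keys while reusing the standard softmax path of the evolve block:

```python
import mlx.core as mx

def local_attention_mask(seq_len, window=128):
    """Additive mask keeping only keys within `window` positions of each query."""
    idx = mx.arange(seq_len)
    dist = mx.abs(idx[None, :] - idx[:, None])            # [seq_len, seq_len] distances
    return mx.where(dist <= window, 0.0, float("-inf"))    # 0 inside the band, -inf outside

# Illustrative use inside the evolve block:
#   scores = scores + local_attention_mask(scores.shape[-1])
```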

### Apple Silicon Specific
- **AMX Optimization**: Efficient use of Apple's matrix units
- **Cache-Friendly Patterns**: Optimize for Apple Silicon's cache hierarchy
- **Vectorization**: Better use of NEON/Advanced SIMD instructions

## Success Metrics

A successful optimization should achieve:

- **High accuracy score** (>0.95): Maintains numerical equivalence with the reference
- **Performance improvement** (>1.2x): Meaningful speedup over the reference implementation
- **Memory efficiency**: A better tokens/MB ratio (see the measurement sketch below)
- **Stability**: Robust behavior across different input configurations
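
The precise metric definitions live in `evaluator.py`; as a rough sketch, throughput and memory cost can be measured along these lines with `time` and `psutil` (already in the requirements). The function below is illustrative, not the evaluator's actual code:

```python
import time
import psutil
import mlx.core as mx

def measure_throughput(attention_fn, q, k, v, scale, runs=10):
    """Return tokens/second and resident-memory growth for one attention call."""
    process = psutil.Process()
    mem_before = process.memory_info().rss

    mx.eval(attention_fn(q, k, v, scale))  # warm-up before timing
    start = time.perf_counter()
    for _ in range(runs):
        mx.eval(attention_fn(q, k, v, scale))
    elapsed = (time.perf_counter() - start) / runs

    tokens = q.shape[0] * q.shape[2]  # batch * seq_len
    mem_mb = (process.memory_info().rss - mem_before) / 1e6
    return {"tokens_per_s": tokens / elapsed, "memory_delta_mb": mem_mb}
```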

## Comparison to AlphaEvolve Results

The original AlphaEvolve achieved:

- **23% speedup** in Gemini kernel optimization (Pallas/TPU)
- **1% overall training time reduction** for large models

Our goals for MLX/Apple Silicon:

- **15-30% attention speedup**: Similar in spirit to the original results
- **Better memory efficiency**: Exploit unified memory advantages
- **Cross-model benefits**: Optimizations that transfer across different transformer architectures

## Using Your Optimized Attention

After evolution completes, you'll have an optimized attention implementation. Here's how to use it:

### Quick Start (3 lines of code)

```python
from attention_integration import load_and_patch_model
from mlx_lm import generate

# Load any MLX-LM model with evolved attention
model, tokenizer = load_and_patch_model(
    model_path="mlx-community/Qwen3-0.6B-bf16",
    evolved_program_path="openevolve_output/best/best_program.py"
)

# Use exactly like any other MLX-LM model - but faster!
response = generate(model, tokenizer, "Write a Python function:", max_tokens=100)
```

### Testing Your Implementation

```bash
# Quick demo
python use_evolved_attention.py demo

# Comprehensive benchmarking
python test_workloads.py --model mlx-community/Qwen3-0.6B-bf16 --evolved-program openevolve_output/best/best_program.py
```

### Recommended Test Workloads

- **Text generation**: Stories, articles, reports (15-30% speedup expected)
- **Code generation**: Functions, classes, APIs (20-40% speedup expected)
- **Long-form content**: 1024+ tokens (30-50% speedup expected)
- **Question answering**: Complex reasoning tasks (10-25% speedup expected)

📖 **See [USAGE.md](USAGE.md) for the complete integration guide and benchmarking instructions.**

## Advanced Usage

### Custom Test Cases

Modify `create_test_cases()` in `evaluator.py` to test specific configurations:

```python
def create_test_cases():
    return [
        {"batch_size": 1, "seq_len": 4096, "hidden_size": 2048, "num_heads": 32, "num_kv_heads": 8},
        # Add your custom test cases
    ]
```

### Different Tolerance Levels

Adjust accuracy requirements in `compare_outputs()`:

```python
comparison = compare_outputs(evolved_output, reference_output, tolerance=1e-4)
```

### Integration Testing

Test the evolved attention with real models by replacing the attention module in an mlx-lm implementation, as sketched below.
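
The sketch below shows roughly what such a replacement can look like for a Llama/Qwen-style mlx-lm model. The `model.model.layers[*].self_attn` layout is an assumption that holds for common mlx-lm architectures but may differ elsewhere, and `EvolvedAttention.from_existing` is a hypothetical helper; `load_and_patch_model` from `attention_integration.py` automates this for you.

```python
from mlx_lm import load

# Hypothetical wrapper around the evolved attention code; in practice
# attention_integration.py builds this from best_program.py for you.
from my_evolved_attention import EvolvedAttention  # hypothetical module

model, tokenizer = load("mlx-community/Qwen3-0.6B-bf16")

# Llama/Qwen-style mlx-lm models expose their transformer blocks at
# model.model.layers, each with a `self_attn` submodule; other
# architectures may use different attribute names.
for layer in model.model.layers:
    # Hypothetical helper that reuses the original projection weights.
    layer.self_attn = EvolvedAttention.from_existing(layer.self_attn)
```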

## Troubleshooting

### Common Issues

1. **Low accuracy scores**: Check tensor shapes and ensure proper masking
2. **Memory errors**: Reduce batch sizes or sequence lengths in test cases
3. **Slow evaluation**: Reduce the number of test cases or performance benchmark runs

### Debugging

Run the evaluator standalone for detailed output:

```bash
python evaluator.py  # Run standalone evaluation
```

Check specific test cases:

```bash
python -c "
from evaluator import evaluate_stage1
print(evaluate_stage1('initial_program.py'))
"
```

## Future Extensions

- **Multi-Head Attention Variants**: Optimize different attention patterns
- **KV Caching**: Optimize for inference with key-value caching
- **Mixed Precision**: Automatic precision optimization
- **Cross-Platform**: Extend optimizations to other Apple Silicon variants (A-series, etc.)
