# MLX Attention Optimization

This example demonstrates using OpenEvolve to optimize attention mechanisms for Apple Silicon, similar to the Gemini kernel optimization described in the AlphaEvolve paper.

## Overview

The goal is to evolve the core attention computation in MLX (Apple's ML framework) to achieve better performance while maintaining numerical accuracy. This example focuses on optimizing the scaled dot-product attention mechanism that forms the heart of transformer models.

## What Gets Optimized

The example evolves the core attention computation within the `OptimizedAttention` class:

```python
# EVOLVE-BLOCK-START
# This section contains the attention computation that gets evolved
scores = mx.matmul(queries, keys.transpose(0, 1, 3, 2))
scores = scores * self.scale
if mask is not None:
    scores = scores + mask
attn_weights = mx.softmax(scores, axis=-1)
output = mx.matmul(attn_weights, values)
# EVOLVE-BLOCK-END
```

**What remains fixed:**
- Query, Key, Value projections
- RMSNorm layers
- RoPE (Rotary Position Embedding)
- Output projection
- Input/output shapes and interfaces

**What can evolve:**
- Attention computation patterns (chunked, sparse, etc.; a chunked variant is sketched after this list)
- Memory access strategies
- Optimized implementations for Apple Silicon
- Alternative attention mechanisms
- Memory tiling strategies
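
To make this concrete, here is a minimal sketch of one shape an evolved block could take: a chunked variant that processes queries in fixed-size blocks. It reuses the names from the evolve block above (`queries`, `keys`, `values`, `mask`, `self.scale`); `CHUNK` is an illustrative constant, and the mask is assumed to broadcast with its second-to-last axis indexing query positions:

```python
# Sketch of a chunked replacement for the evolve block: queries are
# processed in fixed-size blocks, so the score matrix peaks at
# (B, H, CHUNK, L) instead of (B, H, L, L).
CHUNK = 256
seq_len = queries.shape[2]
keys_t = keys.transpose(0, 1, 3, 2)
chunks = []
for start in range(0, seq_len, CHUNK):
    end = min(start + CHUNK, seq_len)
    scores = mx.matmul(queries[:, :, start:end, :], keys_t) * self.scale
    if mask is not None:
        # Assumes the mask's second-to-last axis is the query dimension
        scores = scores + mask[..., start:end, :]
    attn_weights = mx.softmax(scores, axis=-1)
    chunks.append(mx.matmul(attn_weights, values))
output = mx.concatenate(chunks, axis=2)
```

Chunking trades some loop overhead for a much smaller peak score matrix, which matters most at long sequence lengths.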

## Key Features

### Comprehensive Evaluation
The evaluator tests multiple aspects (a minimal accuracy check is sketched after this list):

1. **Numerical Accuracy**: Compares outputs with a reference implementation using MLX-LM's `scaled_dot_product_attention`
2. **Performance**: Measures throughput (tokens/second) and compares with the reference
3. **Memory Efficiency**: Tracks memory usage during computation
4. **Stability**: Tests with edge cases (small/large values, different input sizes)
5. **Robustness**: Tests across different configurations (batch sizes, sequence lengths, GQA)
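
For intuition, here is a minimal sketch of such an accuracy check, assuming MLX's fused `mx.fast.scaled_dot_product_attention` as the reference (the helper name here is illustrative; `compare_outputs()` in `evaluator.py` is the real check):

```python
import mlx.core as mx

def max_abs_error(evolved_output, queries, keys, values, scale, mask=None):
    # Reference output from MLX's fused attention kernel
    reference = mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=scale, mask=mask
    )
    # Largest element-wise deviation, compared against a tolerance such as 1e-4
    return mx.max(mx.abs(evolved_output - reference)).item()
```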

### Test Cases
Evaluates across diverse scenarios:
- Different sequence lengths (64 to 2048 tokens)
- Various model sizes (256 to 1024 hidden dimensions)
- Grouped Query Attention (GQA) with different `num_kv_heads` (see the sketch after this list)
- Multiple batch sizes
- Edge cases for numerical stability
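
For the GQA cases, an evolved block has to handle fewer key/value heads than query heads. A simple (unoptimized) way to do this is to repeat each KV head so that plain attention applies; the variable names here are illustrative:

```python
# Expand KV heads so standard attention applies: with num_heads=32 and
# num_kv_heads=8, each KV head is shared by 4 consecutive query heads.
n_rep = num_heads // num_kv_heads
keys = mx.repeat(keys, n_rep, axis=1)      # (B, 8, L, D) -> (B, 32, L, D)
values = mx.repeat(values, n_rep, axis=1)
```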

### Apple Silicon Optimization Opportunities
The evolution process can discover optimizations specific to Apple Silicon:
- Leveraging the unified memory architecture
- Cache-friendly memory access patterns
- Vectorized operations optimized for ARM
- Efficient use of Apple's matrix units (AMX)

## Running the Example

### Prerequisites
```bash
pip install -r requirements.txt
# Or manually:
pip install mlx mlx-lm psutil numpy pyyaml
export OPENAI_API_KEY="your-api-key"  # For Gemini models via the OpenAI-compatible API
```

### Basic Usage
```bash
cd examples/mlx_attention_optimization
python ../../openevolve-run.py initial_program.py evaluator.py --config config.yaml --iterations 200
```

### Testing Initial Implementation
```bash
python initial_program.py  # Test basic functionality
python evaluator.py        # Run full evaluation
```

## Configuration

The example uses stronger LLM models (Gemini 2.0 Flash/Pro) given the complexity of attention optimization:

```yaml
llm:
  primary_model: "gemini-2.0-flash"
  secondary_model: "gemini-2.0-pro"
  temperature: 0.8
  max_tokens: 8192
```

Key configuration choices (a sketch of where these live in `config.yaml` follows this list):
- **200 iterations**: More iterations for this complex optimization
- **Cascade evaluation**: Quick accuracy check before expensive performance tests
- **Larger population**: 100 programs to explore diverse optimization strategies
- **Higher temperature**: More creative exploration for novel optimizations
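
As a sketch, these choices might map onto `config.yaml` roughly as follows; the key names below are assumptions, so treat the actual file in this directory as authoritative:

```yaml
max_iterations: 200          # more iterations for complex optimization
database:
  population_size: 100       # larger population for diverse strategies
evaluator:
  cascade_evaluation: true   # cheap accuracy gate before performance tests
```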

## Expected Optimizations

OpenEvolve might discover:

### Memory Optimizations
- **Chunked Attention**: Process attention in memory-efficient chunks
- **Tiled Computation**: Optimize memory access patterns for Apple Silicon
- **Unified Memory Exploitation**: Leverage shared CPU/GPU memory

### Algorithmic Improvements
- **Sparse Attention**: Skip computation for irrelevant token pairs
- **Local Attention**: Focus on nearby tokens for efficiency (a sliding-window mask is sketched after this list)
- **Fused Operations**: Combine multiple operations to reduce memory bandwidth
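
As a toy illustration of the local-attention idea, here is a causal sliding-window mask built with MLX; `WINDOW` and `seq_len` are illustrative values, and the resulting additive mask would simply replace (or be added to) the mask used in the evolve block:

```python
import mlx.core as mx

WINDOW = 128
seq_len = 1024
idx = mx.arange(seq_len)
# Query position i may attend to positions j with i - WINDOW < j <= i
allowed = (idx[None, :] <= idx[:, None]) & (idx[:, None] - idx[None, :] < WINDOW)
local_mask = mx.where(allowed, 0.0, float("-inf"))  # additive: 0 keeps, -inf drops
# Used as: scores = scores + local_mask
```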

### Apple Silicon Specific
- **AMX Optimization**: Efficient use of Apple's matrix units
- **Cache-Friendly Patterns**: Optimize for Apple Silicon's cache hierarchy
- **Vectorization**: Better use of NEON/Advanced SIMD instructions

## Success Metrics

A successful optimization should achieve:
- **High accuracy score** (>0.95): Maintains numerical equivalence with the reference
- **Performance improvement** (>1.2x): Meaningful speedup over the reference implementation
- **Memory efficiency**: Better tokens/MB ratio
- **Stability**: Robust across different input configurations

## Comparison to AlphaEvolve Results

The original AlphaEvolve achieved:
- **23% speedup** in Gemini kernel optimization (Pallas/TPU)
- **1% overall training time reduction** for large models

Our goals for MLX/Apple Silicon:
- **15-30% attention speedup**: Similar to the original results
- **Better memory efficiency**: Exploit unified memory advantages
- **Cross-model benefits**: Optimizations that work across different transformer architectures

## Using Your Optimized Attention

After evolution completes, you'll have an optimized attention implementation. Here's how to use it:

### Quick Start (3 lines of code)
```python
from attention_integration import load_and_patch_model
from mlx_lm import generate

# Load any MLX-LM model with evolved attention
model, tokenizer = load_and_patch_model(
    model_path="mlx-community/Qwen3-0.6B-bf16",
    evolved_program_path="openevolve_output/best/best_program.py",
)

# Use exactly like any other MLX-LM model - but faster!
response = generate(model, tokenizer, "Write a Python function:", max_tokens=100)
```

### Testing Your Implementation
```bash
# Quick demo
python use_evolved_attention.py demo

# Comprehensive benchmarking
python test_workloads.py --model mlx-community/Qwen3-0.6B-bf16 --evolved-program openevolve_output/best/best_program.py
```

### Recommended Test Workloads
- **Text generation**: Stories, articles, reports (15-30% speedup expected)
- **Code generation**: Functions, classes, APIs (20-40% speedup expected)
- **Long-form content**: 1024+ tokens (30-50% speedup expected)
- **Question answering**: Complex reasoning tasks (10-25% speedup expected)

📖 **See [USAGE.md](USAGE.md) for the complete integration guide and benchmarking instructions.**

## Advanced Usage

### Custom Test Cases
Modify `create_test_cases()` in `evaluator.py` to test specific configurations:

```python
def create_test_cases():
    return [
        {"batch_size": 1, "seq_len": 4096, "hidden_size": 2048, "num_heads": 32, "num_kv_heads": 8},
        # Add your custom test cases
    ]
```

### Different Tolerance Levels
Adjust accuracy requirements in `compare_outputs()`:

```python
comparison = compare_outputs(evolved_output, reference_output, tolerance=1e-4)
```

### Integration Testing
Test evolved attention with real models by replacing the attention modules in mlx-lm model implementations; a hypothetical manual patch is sketched below.
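
This sketch assumes mlx-lm's usual layer layout (`model.model.layers` with a `self_attn` attribute) and that the evolved program exposes an `OptimizedAttention` class with a compatible interface; all of these names should be checked against your actual files:

```python
from mlx_lm import load
from best_program import OptimizedAttention  # hypothetical import of the evolved module

model, tokenizer = load("mlx-community/Qwen3-0.6B-bf16")
for layer in model.model.layers:
    # Swap in the evolved attention, layer by layer (interfaces must match)
    layer.self_attn = OptimizedAttention(model.args)
```

In practice, `attention_integration.load_and_patch_model` from the Quick Start above handles this for you.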

## Troubleshooting

### Common Issues
1. **Low accuracy scores**: Check tensor shapes and ensure proper masking
2. **Memory errors**: Reduce batch sizes or sequence lengths in the test cases
3. **Slow evaluation**: Reduce the number of test cases or performance benchmark runs
### Debugging
Run the evaluator standalone for detailed output:
```bash
python evaluator.py
```

Check specific test cases:
```bash
python -c "
from evaluator import evaluate_stage1
print(evaluate_stage1('initial_program.py'))
"
```

## Future Extensions

- **Multi-Head Attention Variants**: Optimize different attention patterns
- **KV Caching**: Optimize for inference with key-value caching
- **Mixed Precision**: Automatic precision optimization
- **Cross-Platform**: Extend optimizations to other Apple Silicon variants (A-series, etc.)