# MLX Attention Optimization Example

This example implements **High-Level ML Kernel Optimization** inspired by AlphaEvolve's **Gemini kernel engineering** approach (Section 3.3.2), but adapted for **realistic Python/MLX optimization** on Apple Silicon.

## 🎯 Why Attention Optimization?

Unlike low-level matrix multiplication (where MLX's C++/Metal kernels are hard to beat), **attention mechanisms** offer genuine opportunities for optimization at the algorithm level:

- **Complex multi-step operations** with room for fusion and reordering
- **Memory access patterns** that can be tuned for Apple Silicon's unified memory
- **Numerical precision tradeoffs** that affect both speed and accuracy
- **Sequence length handling** strategies for different workloads
- **Multi-head computation** patterns that can be restructured

## 🔬 What We're Optimizing

### **Core Attention Parameters (Evolvable)**
```python
def get_attention_config():
    return {
        "attention_dtype": "float32",  # ← float32/float16/bfloat16
        "memory_layout": "standard",   # ← standard/transposed/blocked
        "chunking_strategy": "none",   # ← none/query_chunks/key_chunks/both
        "chunk_size": 512,             # ← 128/256/512/1024
        "softmax_precision": "high",   # ← high/medium/fast
        "scale_strategy": "sqrt_dk",   # ← sqrt_dk/learned/fixed
        "use_fused_qkv": True,         # ← fusion optimizations
        "kv_cache_optimized": False    # ← inference optimizations
    }
```
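
For example, the `attention_dtype` field has to be mapped to an MLX dtype inside the kernel. The snippet below is a minimal sketch of that mapping; the `apply_precision` helper is illustrative and not part of `initial_program.py`:

```python
import mlx.core as mx

_DTYPES = {"float32": mx.float32, "float16": mx.float16, "bfloat16": mx.bfloat16}

def apply_precision(q, k, v, config):
    # Cast Q/K/V to the evolvable compute dtype; the kernel can cast the
    # result back to the caller's dtype afterwards.
    dtype = _DTYPES[config["attention_dtype"]]
    return q.astype(dtype), k.astype(dtype), v.astype(dtype)
```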

### **Optimization Strategies**
1. **Memory Layout Optimization**: How Q, K, V matrices are arranged in memory
2. **Precision Strategies**: When to use float16 vs. float32 for the speed/accuracy balance
3. **Chunking Algorithms**: Breaking long sequences into cache-friendly chunks (see the sketch after this list)
4. **Fused Operations**: Combining multiple attention steps to reduce memory bandwidth
5. **Computation Ordering**: Optimizing the sequence of operations for Apple Silicon
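
To make strategy 3 concrete, here is a minimal sketch of query chunking, assuming `(batch, heads, seq_len, head_dim)` inputs; it illustrates the idea rather than reproducing the evolvable kernel:

```python
import mlx.core as mx

def query_chunked_attention(q, k, v, chunk_size=512):
    # Process queries in blocks so only a (chunk_size x seq_len) score matrix
    # is materialized at a time, instead of the full (seq_len x seq_len) one.
    seq_len, d_k = q.shape[-2], q.shape[-1]
    scale = d_k ** -0.5
    outputs = []
    for start in range(0, seq_len, chunk_size):
        q_chunk = q[:, :, start:start + chunk_size, :]
        scores = (q_chunk @ k.transpose(0, 1, 3, 2)) * scale
        outputs.append(mx.softmax(scores, axis=-1) @ v)
    return mx.concatenate(outputs, axis=-2)
```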

## 🏗️ Architecture

### **Initial Implementation (`initial_program.py`)**
- **Comprehensive attention kernel** with multiple optimization strategies
- **Configurable parameters** for all major attention optimizations
- **Memory layout options** (standard, transposed, blocked)
- **Chunking strategies** for long sequences
- **Precision control** for speed/accuracy tradeoffs

### **Evaluation Framework (`evaluator.py`)**
- **Correctness verification** against reference MLX attention (illustrated in the sketch below)
- **Performance benchmarking** on realistic model configurations
- **Full model inference testing** using simplified transformer blocks
- **Multi-objective optimization**: speed + accuracy + memory efficiency
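
As a rough illustration of the correctness check, a scoring helper might look like the sketch below; the actual tolerances and scoring logic live in `evaluator.py`, so treat this as an assumption rather than its implementation:

```python
import mlx.core as mx

def correctness_score(candidate_out, reference_out, tol=1e-3):
    # Compare the evolved kernel's output against the reference attention
    # output; a score near 1.0 means the results are numerically equivalent.
    err = mx.max(mx.abs(candidate_out - reference_out)).item()
    return 1.0 if err < tol else max(0.0, 1.0 - err)
```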

### **Test Configurations**
Based on models like **Qwen3-0.6B-bf16** (the full grid is enumerated in the sketch below):
- **Batch sizes**: 1, 2, 4, 8 (typical inference/training)
- **Sequence lengths**: 128, 256, 512, 1024, 2048
- **Model dimensions**: 256, 512, 768, 1024 (small to medium models)
- **Number of heads**: 8, 12, 16
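
A sketch of how that grid could be enumerated is shown below; the actual evaluator may use a hand-picked subset rather than the full cross product:

```python
from itertools import product

# Sweep over the grid described above; head_dim = d_model / n_heads must divide evenly.
batch_sizes = [1, 2, 4, 8]
seq_lens = [128, 256, 512, 1024, 2048]
d_models = [256, 512, 768, 1024]
n_heads_options = [8, 12, 16]

test_configs = [
    {"batch": b, "seq_len": s, "d_model": d, "n_heads": h}
    for b, s, d, h in product(batch_sizes, seq_lens, d_models, n_heads_options)
    if d % h == 0
]
```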

## 📊 Expected Results

### **Realistic Performance Targets**
Given the complexity of attention, we expect:
- **10-30% speedup** over standard MLX attention (realistic for Python-level optimization)
- **Memory efficiency gains** through better chunking and layout
- **Accuracy preservation** (numerical error < 1e-3)
- **Robust performance** across different model sizes

### **Key Optimizations We Expect Evolution to Discover**
1. **Float16 strategies** where accuracy allows (~20-30% speedup potential)
2. **Optimal chunk sizes** for Apple Silicon's memory hierarchy (likely 256-512)
3. **Memory layout patterns** suited to the unified memory architecture
4. **Fused operation sequences** that reduce memory bandwidth
5. **Precision mixing**: high precision for critical steps, lower for others (see the sketch below)
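
As an example of item 5, precision mixing might look like the following sketch, which assumes 4D `(batch, heads, seq_len, head_dim)` inputs and that the score/softmax path is the accuracy-critical step:

```python
import mlx.core as mx

def mixed_precision_attention(q, k, v):
    # Keep the numerically sensitive score/softmax path in float32,
    # then drop to float16 for the value matmul.
    d_k = q.shape[-1]
    scores = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2)) * d_k ** -0.5
    weights = mx.softmax(scores, axis=-1)
    return (weights.astype(mx.float16) @ v.astype(mx.float16)).astype(q.dtype)
```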

## 🚀 Running the Example

### **Prerequisites**
```bash
# Install MLX (Apple Silicon only)
pip install mlx

# Ensure OpenEvolve is installed
pip install -e .
```

### **Quick Test**
Verify the setup works:
```bash
cd examples/mlx_attention_optimization
python initial_program.py
```

Expected output:
```
MLX Attention Optimization Example
Current configuration: {'attention_dtype': 'float32', 'memory_layout': 'standard', ...}

Running benchmark...
Results:
  b1_s128_d256: 0.0045s, 12.34 GFLOPS
  b1_s512_d512: 0.0234s, 23.45 GFLOPS
  ...
```

### **Run Evolution**
```bash
# Quick test (50 iterations, ~30 minutes)
python ../../openevolve-run.py initial_program.py evaluator.py --config config.yaml --iterations 50

# Standard run (150 iterations, ~2-3 hours)
python ../../openevolve-run.py initial_program.py evaluator.py --config config.yaml --iterations 150

# Full optimization (300 iterations, ~6-8 hours)
python ../../openevolve-run.py initial_program.py evaluator.py --config config.yaml --iterations 300
```

## 📈 Understanding the Results

### **Key Metrics**
- **`attention_efficiency`**: Primary optimization target (0-1 scale)
- **`model_efficiency`**: Speedup on full model inference (>1.0 is good)
- **`correctness_score`**: Numerical accuracy vs reference (should be ~1.0)
- **`avg_speedup`**: Average speedup across all model configurations
- **`avg_throughput_gflops`**: Raw attention throughput

### **Success Indicators**
- **Model efficiency > 1.1**: 10%+ speedup on real model inference
- **Correctness score > 0.99**: Maintains numerical accuracy
- **Attention efficiency > 0.7**: Good overall optimization
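
These thresholds can be checked programmatically; the helper below is a hypothetical convenience, not part of the evaluator:

```python
def meets_success_criteria(metrics: dict) -> bool:
    # Apply the success thresholds listed above to a run's metrics dict.
    return (
        metrics.get("model_efficiency", 0.0) > 1.1
        and metrics.get("correctness_score", 0.0) > 0.99
        and metrics.get("attention_efficiency", 0.0) > 0.7
    )
```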

### **Evolution Progress**
```
INFO - Iteration 75: Child abc123 from parent def456 in 45.67s.
Metrics: attention_efficiency=0.7234, model_efficiency=1.1456, correctness_score=0.9987
(Δ: attention_efficiency=+0.0234, model_efficiency=+0.0456)
```

## 🔍 Comparison to AlphaEvolve Paper

| **Aspect** | **AlphaEvolve (TPU)** | **Our Implementation (MLX)** |
|------------|----------------------|------------------------------|
| **Target** | Pallas kernel tiling | Attention algorithm optimization |
| **Hardware** | Google TPU | Apple Silicon GPU |
| **Scope** | Low-level kernel parameters | High-level algorithm strategies |
| **Language** | TPU assembly/Pallas | Python/MLX |
| **Optimization Space** | Tile shapes, memory patterns | Attention fusion, precision, chunking |
| **Expected Improvement** | 23% kernel speedup | 10-30% attention speedup |
| **Evaluation** | Real TPU performance | Real model inference on Apple Silicon |

## 🎯 Why This Approach Works

### **Realistic Optimization Scope**
- **Algorithm-level optimizations** rather than competing with optimized C++ kernels
- **Memory access pattern improvements** for Apple Silicon's architecture
- **Numerical precision strategies** that balance speed and accuracy
- **Computation fusion** at the Python/MLX level

### **Genuine Room for Improvement**
- **Standard MLX attention** is not necessarily optimized for every use case
- **Memory layout choices** can significantly impact performance
- **Precision strategies** offer real speed/accuracy tradeoffs
- **Chunking algorithms** can improve memory efficiency for long sequences

### **Measurable Real-World Impact**
- **Full model inference testing** ensures practical relevance
- **Multiple model configurations** validate generalization
- **Correctness verification** ensures reliability
- **Performance comparison** provides clear improvement metrics

## 🔬 Advanced Usage

### **Custom Model Testing**
Modify `evaluator.py` to test on your specific model:
```python
# Add your model configuration
model_configs = [
    {"d_model": your_d_model, "n_heads": your_n_heads, "n_layers": 2, "seq_len": your_seq_len}
]
```

### **Production Integration**
Use evolved configurations in real models:
```python
import json
from functools import partial

# Load the best configuration produced by the evolution run
with open("openevolve_output/best/best_program_info.json") as f:
    best_config = json.load(f)["metrics"]

# Apply it to your model; optimized_attention_kernel is the kernel from the evolved program
optimized_attention = partial(optimized_attention_kernel, **best_config)
```

### **Comparative Analysis**
Compare different optimization strategies:
```python
# Test float16 vs float32
config_fp16 = {"attention_dtype": "float16", ...}
config_fp32 = {"attention_dtype": "float32", ...}
```

## 🎓 Learning Outcomes

This example demonstrates:
- **Realistic scope** for Python-based ML optimization
- **Multi-objective optimization** balancing speed, accuracy, and memory
- **Real-world evaluation** on transformer model inference
- **Evolutionary discovery** of non-obvious optimization strategies

Unlike the matrix multiplication example, this one has genuine potential to discover optimizations that outperform naive implementations while remaining practical to implement.

## 🔧 Troubleshooting

**Common Issues:**
- **MLX import errors**: Ensure you're on Apple Silicon and MLX is installed
- **Memory errors**: Reduce batch sizes or sequence lengths in the config
- **Slow evaluation**: Reduce the number of test configurations
- **Correctness failures**: Check the tolerance values in the evaluator

**Performance Tips:**
- **Monitor memory usage** during evolution
- **Start with shorter sequences** for faster iteration
- **Use checkpointing** for long evolution runs
- **Analyze intermediate results** to understand optimization trends

This example targets a more realistic and achievable optimization problem than competing head-on with highly optimized BLAS libraries, while still demonstrating the power of evolutionary code optimization for real ML workloads.