
Commit f3174f9

init
1 parent 2c2e0aa commit f3174f9

File tree

4 files changed: +1194 -0 lines changed

Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
# MLX Attention Optimization Example

This example implements **High-Level ML Kernel Optimization** inspired by AlphaEvolve's **Gemini kernel engineering** approach (Section 3.3.2), but adapted for **realistic Python/MLX optimization** on Apple Silicon.
## 🎯 Why Attention Optimization?

Unlike low-level matrix multiplication (where MLX's C++/Metal kernels are hard to beat), **attention mechanisms** offer genuine opportunities for optimization at the algorithm level:

- **Complex multi-step operations** with room for fusion and reordering
- **Memory access patterns** that can be optimized for Apple Silicon's unified memory
- **Numerical precision tradeoffs** that affect both speed and accuracy
- **Sequence length handling** strategies for different workloads
- **Multi-head computation** patterns that can be optimized
## 🔬 What We're Optimizing

### **Core Attention Parameters (Evolvable)**
```python
def get_attention_config():
    return {
        "attention_dtype": "float32",   # ← float32/float16/bfloat16
        "memory_layout": "standard",    # ← standard/transposed/blocked
        "chunking_strategy": "none",    # ← none/query_chunks/key_chunks/both
        "chunk_size": 512,              # ← 128/256/512/1024
        "softmax_precision": "high",    # ← high/medium/fast
        "scale_strategy": "sqrt_dk",    # ← sqrt_dk/learned/fixed
        "use_fused_qkv": True,          # ← fusion optimizations
        "kv_cache_optimized": False     # ← inference optimizations
    }
```
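To make the role of these parameters concrete, here is a minimal sketch of how a kernel might interpret `attention_dtype` and `scale_strategy` with MLX calls. The helper names are illustrative, not the example's actual API.

```python
import math
import mlx.core as mx

# Illustrative helpers (not the example's actual API) showing how two of the
# evolvable parameters could be interpreted inside an attention kernel.
_DTYPES = {"float32": mx.float32, "float16": mx.float16, "bfloat16": mx.bfloat16}

def cast_qkv(q, k, v, config):
    """Cast Q, K, V to the dtype selected by `attention_dtype`."""
    dtype = _DTYPES[config["attention_dtype"]]
    return q.astype(dtype), k.astype(dtype), v.astype(dtype)

def get_scale(config, head_dim):
    """Resolve `scale_strategy` into a scalar applied to the attention scores."""
    if config["scale_strategy"] == "sqrt_dk":
        return 1.0 / math.sqrt(head_dim)
    return 1.0  # "learned"/"fixed" would be handled by the evolved kernel
```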
### **Optimization Strategies**
1. **Memory Layout Optimization**: How Q, K, V matrices are arranged in memory
2. **Precision Strategies**: When to use float16 vs float32 for speed/accuracy balance
3. **Chunking Algorithms**: Breaking large sequences into cache-friendly chunks (see the sketch after this list)
4. **Fused Operations**: Combining multiple attention steps to reduce memory bandwidth
5. **Computation Ordering**: Optimizing the sequence of operations for Apple Silicon
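As a concrete illustration of strategy 3, the sketch below shows query chunking in plain MLX: queries are processed in blocks so the score matrix for each block stays small, while keys and values are reused. This is a minimal sketch that assumes `(batch, n_heads, seq_len, head_dim)` tensors, not the evolved kernel itself.

```python
import math
import mlx.core as mx

def query_chunked_attention(q, k, v, chunk_size=512):
    """Minimal query-chunked attention sketch.

    q, k, v are assumed to have shape (batch, n_heads, seq_len, head_dim).
    Each query chunk produces a (chunk, seq_len) score matrix, keeping the
    intermediate memory footprint roughly constant for long sequences.
    """
    scale = 1.0 / math.sqrt(q.shape[-1])
    k_t = mx.transpose(k, (0, 1, 3, 2))            # (batch, heads, head_dim, seq_len)
    outputs = []
    for start in range(0, q.shape[2], chunk_size):
        q_chunk = q[:, :, start:start + chunk_size, :]
        scores = (q_chunk @ k_t) * scale           # (batch, heads, chunk, seq_len)
        outputs.append(mx.softmax(scores, axis=-1) @ v)
    return mx.concatenate(outputs, axis=2)
```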
## 🏗️ Architecture

### **Initial Implementation (`initial_program.py`)**
- **Comprehensive attention kernel** with multiple optimization strategies
- **Configurable parameters** for all major attention optimizations
- **Memory layout options** (standard, transposed, blocked)
- **Chunking strategies** for long sequences
- **Precision control** for speed/accuracy tradeoffs
### **Evaluation Framework (`evaluator.py`)**
- **Correctness verification** against reference MLX attention (see the sketch below)
- **Performance benchmarking** on realistic model configurations
- **Full model inference testing** using simplified transformer blocks
- **Multi-objective optimization**: speed + accuracy + memory efficiency
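A correctness check of this kind can be sketched in a few lines: run a candidate kernel and a naive float32 reference on the same random inputs, then compare the maximum absolute error against the < 1e-3 target used in this example. The function names and the reference implementation here are illustrative, not the evaluator's actual code.

```python
import math
import mlx.core as mx

def reference_attention(q, k, v):
    """Naive float32 scaled dot-product attention used as the numerical baseline."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = (q @ mx.transpose(k, (0, 1, 3, 2))) * scale
    return mx.softmax(scores, axis=-1) @ v

def check_correctness(candidate_fn, shape=(1, 8, 256, 64), tol=1e-3):
    """Return (passed, max_abs_error) for a candidate attention kernel."""
    q, k, v = (mx.random.normal(shape) for _ in range(3))
    out = candidate_fn(q, k, v).astype(mx.float32)
    err = mx.max(mx.abs(out - reference_attention(q, k, v))).item()
    return err < tol, err
```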
### **Test Configurations**
Based on models like **Qwen3-0.6B-bf16**:
- **Batch sizes**: 1, 2, 4, 8 (typical inference/training)
- **Sequence lengths**: 128, 256, 512, 1024, 2048
- **Model dimensions**: 256, 512, 768, 1024 (small to medium models)
- **Number of heads**: 8, 12, 16
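One way to express this sweep programmatically is as a simple grid, filtered so the head dimension divides evenly. This is a sketch of the idea, not the evaluator's exact configuration list.

```python
from itertools import product

# Illustrative benchmark grid mirroring the ranges listed above.
batch_sizes = [1, 2, 4, 8]
seq_lens = [128, 256, 512, 1024, 2048]
d_models = [256, 512, 768, 1024]
head_counts = [8, 12, 16]

test_configs = [
    {"batch": b, "seq_len": s, "d_model": d, "n_heads": h}
    for b, s, d, h in product(batch_sizes, seq_lens, d_models, head_counts)
    if d % h == 0  # head_dim must be an integer
]
```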
## 📊 Expected Results

### **Realistic Performance Targets**
Based on attention complexity, we expect:
- **10-30% speedup** over standard MLX attention (realistic for Python optimization)
- **Memory efficiency gains** through better chunking and layout
- **Accuracy preservation** (numerical error < 1e-3)
- **Robust performance** across different model sizes
### **Key Optimizations We Expect Evolution to Discover**
1. **Float16 strategies** where accuracy allows (~20-30% speedup potential)
2. **Optimal chunk sizes** for Apple Silicon memory hierarchy (likely 256-512)
3. **Memory layout patterns** optimized for unified memory architecture
4. **Fused operation sequences** to reduce memory bandwidth
5. **Precision mixing** (high precision for critical steps, lower for others; see the sketch after this list)
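A minimal sketch of precision mixing, again assuming `(batch, n_heads, seq_len, head_dim)` inputs: the matrix multiplications run in float16 while the softmax is computed in float32, which is where reduced precision tends to hurt accuracy most. This illustrates the idea only and is not an evolved kernel.

```python
import math
import mlx.core as mx

def mixed_precision_attention(q, k, v):
    """Matmuls in float16, softmax in float32 for numerical stability."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    q16, k16, v16 = (x.astype(mx.float16) for x in (q, k, v))
    scores = (q16 @ mx.transpose(k16, (0, 1, 3, 2))) * scale
    weights = mx.softmax(scores.astype(mx.float32), axis=-1)
    return (weights.astype(mx.float16) @ v16).astype(q.dtype)
```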
## 🚀 Running the Example

### **Prerequisites**
```bash
# Install MLX (Apple Silicon only)
pip install mlx

# Ensure OpenEvolve is installed
pip install -e .
```
### **Quick Test**
Verify the setup works:
```bash
cd examples/mlx_attention_optimization
python initial_program.py
```

Expected output:
```
MLX Attention Optimization Example
Current configuration: {'attention_dtype': 'float32', 'memory_layout': 'standard', ...}

Running benchmark...
Results:
b1_s128_d256: 0.0045s, 12.34 GFLOPS
b1_s512_d512: 0.0234s, 23.45 GFLOPS
...
```
### **Run Evolution**
```bash
# Quick test (50 iterations, ~30 minutes)
python ../../openevolve-run.py initial_program.py evaluator.py --config config.yaml --iterations 50

# Standard run (150 iterations, ~2-3 hours)
python ../../openevolve-run.py initial_program.py evaluator.py --config config.yaml --iterations 150

# Full optimization (300 iterations, ~6-8 hours)
python ../../openevolve-run.py initial_program.py evaluator.py --config config.yaml --iterations 300
```
## 📈 Understanding the Results

### **Key Metrics**
- **`attention_efficiency`**: Primary optimization target (0-1 scale)
- **`model_efficiency`**: Speedup on full model inference (>1.0 is good)
- **`correctness_score`**: Numerical accuracy vs reference (should be ~1.0)
- **`avg_speedup`**: Average speedup across all model configurations
- **`avg_throughput_gflops`**: Raw attention throughput
### **Success Indicators**
- **Model efficiency > 1.1**: 10%+ speedup on real model inference
- **Correctness score > 0.99**: Maintains numerical accuracy
- **Attention efficiency > 0.7**: Good overall optimization
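These thresholds are easy to check programmatically against a program's reported metrics; the helper below is a small sketch that assumes the metric keys listed above.

```python
def meets_success_criteria(metrics: dict) -> bool:
    """Check an evolved program's metrics against the success indicators above."""
    return (metrics["model_efficiency"] > 1.1
            and metrics["correctness_score"] > 0.99
            and metrics["attention_efficiency"] > 0.7)
```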
### **Evolution Progress**
```
INFO - Iteration 75: Child abc123 from parent def456 in 45.67s.
Metrics: attention_efficiency=0.7234, model_efficiency=1.1456, correctness_score=0.9987
(Δ: attention_efficiency=+0.0234, model_efficiency=+0.0456)
```
## 🔍 Comparison to AlphaEvolve Paper

| **Aspect** | **AlphaEvolve (TPU)** | **Our Implementation (MLX)** |
|------------|----------------------|------------------------------|
| **Target** | Pallas kernel tiling | Attention algorithm optimization |
| **Hardware** | Google TPU | Apple Silicon GPU |
| **Scope** | Low-level kernel parameters | High-level algorithm strategies |
| **Language** | TPU assembly/Pallas | Python/MLX |
| **Optimization Space** | Tile shapes, memory patterns | Attention fusion, precision, chunking |
| **Expected Improvement** | 23% kernel speedup | 10-30% attention speedup |
| **Evaluation** | Real TPU performance | Real model inference on Apple Silicon |
## 🎯 Why This Approach Works

### **Realistic Optimization Scope**
- **Algorithm-level optimizations** rather than competing with optimized C++ kernels
- **Memory access pattern improvements** for Apple Silicon's architecture
- **Numerical precision strategies** that balance speed and accuracy
- **Computation fusion** at the Python/MLX level

### **Genuine Room for Improvement**
- **Standard MLX attention** is not necessarily optimized for all use cases
- **Memory layout choices** can significantly impact performance
- **Precision strategies** offer real speed/accuracy tradeoffs
- **Chunking algorithms** can improve memory efficiency for long sequences

### **Measurable Real-World Impact**
- **Full model inference testing** ensures practical relevance
- **Multiple model configurations** validate generalization
- **Correctness verification** ensures reliability
- **Performance comparison** provides clear improvement metrics
## 🔬 Advanced Usage

### **Custom Model Testing**
Modify `evaluator.py` to test on your specific model:
```python
# Add your model configuration
model_configs = [
    {"d_model": your_d_model, "n_heads": your_n_heads, "n_layers": 2, "seq_len": your_seq_len}
]
```
### **Production Integration**
Use evolved configurations in real models:
```python
import json
from functools import partial

# Load best configuration
with open("openevolve_output/best/best_program_info.json") as f:
    best_config = json.load(f)["metrics"]

# Apply to your model
optimized_attention = partial(optimized_attention_kernel, **best_config)
```
### **Comparative Analysis**
Compare different optimization strategies:
```python
# Test float16 vs float32
config_fp16 = {"attention_dtype": "float16", ...}
config_fp32 = {"attention_dtype": "float32", ...}
```
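To turn such a comparison into numbers, a small timing helper is enough. The sketch below assumes the attention implementation has already been bound to a config (for example via `partial(optimized_attention_kernel, **config)` as shown above) and uses `mx.eval` to force MLX's lazy computation graph to execute.

```python
import time
import mlx.core as mx

def time_attention(attention_fn, q, k, v, repeats=20):
    """Average wall-clock time per call; mx.eval forces the lazy graph to run."""
    mx.eval(attention_fn(q, k, v))          # warm-up call
    start = time.perf_counter()
    for _ in range(repeats):
        mx.eval(attention_fn(q, k, v))
    return (time.perf_counter() - start) / repeats
```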
## 🎓 Learning Outcomes

This example demonstrates:
- **Realistic scope** for Python-based ML optimization
- **Multi-objective optimization** balancing speed, accuracy, and memory
- **Real-world evaluation** on transformer model inference
- **Evolutionary discovery** of non-obvious optimization strategies

Unlike the matrix multiplication example, this has genuine potential to discover optimizations that outperform naive implementations while remaining practically implementable.
## 🔧 Troubleshooting

**Common Issues:**
- **MLX import errors**: Ensure you're on Apple Silicon and MLX is installed
- **Memory errors**: Reduce batch sizes or sequence lengths in config
- **Slow evaluation**: Reduce the number of test configurations
- **Correctness failures**: Check tolerance values in evaluator

**Performance Tips:**
- **Monitor memory usage** during evolution
- **Start with shorter sequences** for faster iteration
- **Use checkpointing** for long evolution runs
- **Analyze intermediate results** to understand optimization trends
This example represents a more realistic and achievable optimization target than competing with highly optimized BLAS libraries, while still demonstrating the power of evolutionary code optimization for real ML workloads.
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
# Configuration for MLX Attention Optimization
# Inspired by AlphaEvolve's Gemini kernel engineering approach
# Focused on optimizing real ML workloads for Apple Silicon

max_iterations: 100
checkpoint_interval: 10
log_level: "INFO"

# LLM configuration optimized for ML kernel development
llm:
  primary_model: "gemini-2.5-flash-preview-05-20"
  primary_model_weight: 0.8
  secondary_model: "gemini-2.5-pro-preview-05-06"
  secondary_model_weight: 0.2
  api_base: "https://generativelanguage.googleapis.com/v1beta/openai/"
  temperature: 0.7
  top_p: 0.95
  max_tokens: 24000  # thinking models require sufficient tokens, otherwise responses are truncated or empty
  timeout: 600

# Specialized prompt for attention optimization
prompt:
  system_message: |
    You are an expert ML systems engineer specializing in optimizing transformer attention mechanisms for Apple Silicon and MLX.
    Your task is to evolve high-performance attention implementations that can outperform standard MLX operations on real model inference and training.

    Focus on REALISTIC optimizations that can work in Python/MLX:

    **Memory and Computation Strategies:**
    - Fused operations to reduce memory bandwidth
    - Optimal data layouts for Apple Silicon's unified memory
    - Strategic use of float16/bfloat16 vs float32 for speed/accuracy tradeoffs
    - Chunking strategies for long sequences to fit in memory
    - Cache-friendly computation ordering

    **Apple Silicon Specific Optimizations:**
    - Leverage unified memory architecture (no GPU-CPU transfers)
    - Optimize for Apple's GPU compute units and memory hierarchy
    - Use MLX's optimized primitives as building blocks
    - Consider Metal Performance Shaders integration patterns

    **Attention-Specific Optimizations:**
    - Different scaling strategies (sqrt(d_k), learned, fixed)
    - Memory layout optimizations for Q, K, V matrices
    - Softmax approximations that maintain accuracy
    - Causal masking optimizations
    - Multi-head attention fusion strategies
    - KV-cache optimization for inference

    **Realistic Performance Targets:**
    - 10-30% speedup over standard MLX attention (realistic for Python optimizations)
    - Maintain numerical correctness (max error < 1e-3)
    - Support common model sizes (256-1024 d_model, 128-2048 seq_len)
    - Optimize for batch sizes 1-8 (typical inference/training)

    **Key Parameters to Evolve:**
    - attention_dtype: "float32", "float16", "bfloat16"
    - memory_layout: "standard", "transposed", "blocked"
    - chunking_strategy: "none", "query_chunks", "key_chunks", "both"
    - chunk_size: 128, 256, 512, 1024
    - softmax_precision: "high", "medium", "fast"
    - scale_strategy: "sqrt_dk", "learned", "fixed"

    Always ensure correctness while maximizing real-world performance on transformer models.

  num_top_programs: 4
  num_diverse_programs: 3
  use_template_stochasticity: true

# Database configuration for attention evolution
database:
  population_size: 150  # Moderate size for attention optimization
  archive_size: 40
  num_islands: 4
  elite_selection_ratio: 0.2  # Keep more elite solutions for complex optimization
  exploitation_ratio: 0.6
  exploration_ratio: 0.3

# Evaluator configuration for attention benchmarking
evaluator:
  timeout: 180  # Longer timeout for model inference testing
  cascade_evaluation: true
  cascade_thresholds: [0.4, 0.7]  # Lower thresholds since attention optimization is challenging
  parallel_evaluations: 2  # Conservative since we're testing full models
  use_llm_feedback: false

# Evolution settings for attention optimization
diff_based_evolution: true
allow_full_rewrites: true  # Allow full rewrites for significant attention improvements
max_code_length: 100000  # Larger for complex attention implementations
