Commit 1dd7f4e

1 parent 8c6aaf6 commit 1dd7f4e
7 files changed, +2061 -0 lines changed
Lines changed: 189 additions & 0 deletions
# 🎯 Qwen3-0.6B Custom GQA Attention Optimization

**Evolving custom Grouped Query Attention kernels using MLX primitives for Qwen3-0.6B on Apple M4**

This example demonstrates AlphaEvolve's kernel optimization approach by implementing and evolving custom GQA attention computation using MLX primitives, targeting the specific 40:8 query-to-KV head pattern in Qwen3-0.6B.

## 🔄 **Updated Approach: Custom Kernel Implementation**

### **Why We Changed Strategy:**

**Previous Approach (High-level orchestration):**
- ❌ Only optimized around `mx.fast.scaled_dot_product_attention`
- ❌ Limited optimization opportunities
- ❌ Multiple EVOLVE-BLOCKs (a violation of the OpenEvolve format)

**Current Approach (Custom kernel implementation):**
- ✅ **Custom GQA implementation** using MLX primitives
- ✅ **Real optimization opportunities** at the computation level
- ✅ **Single EVOLVE-BLOCK** containing the core attention computation
- ✅ **Follows the AlphaEvolve methodology** of optimizing actual kernels
## 🎯 **Optimization Target**

- **Model**: mlx-community/Qwen3-0.6B-bf16
- **Architecture**: 40 query heads : 8 key/value heads (5:1 GQA ratio)
- **Hardware**: Apple M4, 24 GB unified memory
- **Baseline Performance**: 70.3 tokens/sec average decode speed
- **Goal**: 80+ tokens/sec (14%+ improvement)
## 🔧 **Custom GQA Implementation**

### **Core Evolution Area (Single EVOLVE-BLOCK):**

```python
def __call__(self, x, mask=None, cache=None):
    # Standard preprocessing...
    queries = self.q_proj(x)  # [B, L, 40*128]
    keys = self.k_proj(x)     # [B, L, 8*128]
    values = self.v_proj(x)   # [B, L, 8*128]
    # (reshape/transpose to [B, n_heads, L, 128], RoPE, and KV-cache handling elided)

    # EVOLVE-BLOCK-START
    # Custom GQA attention implementation using MLX primitives.
    # This replaces mx.fast.scaled_dot_product_attention entirely.

    # Current baseline: manual broadcasting + standard computation
    keys_expanded = mx.repeat(keys, self.gqa_ratio, axis=1)      # [B, 40, L, 128]
    values_expanded = mx.repeat(values, self.gqa_ratio, axis=1)  # [B, 40, L, 128]

    scores = mx.matmul(queries, keys_expanded.transpose(0, 1, 3, 2)) * self.scale
    attn_weights = mx.softmax(scores, axis=-1, precise=True)
    output = mx.matmul(attn_weights, values_expanded)

    # EVOLUTION OPPORTUNITIES:
    # 1. Better GQA broadcasting strategies (chunked computation)
    # 2. Fused operations (combined matmul + softmax)
    # 3. Memory layout optimization for Apple Silicon
    # 4. Optimized causal masking
    # EVOLVE-BLOCK-END
```
## 🚀 **Key Optimization Opportunities**

### **1. GQA Broadcasting Strategies:**
```python
# Current: Explicit broadcasting with mx.repeat
keys_expanded = mx.repeat(keys, 5, axis=1)  # Creates 5x memory usage

# Evolution options:
# - Chunked computation (process 5 query heads per KV head)
# - On-demand broadcasting (avoid materialized copies)
# - Strided access patterns (direct indexing)
```
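As a concrete illustration of the "on-demand broadcasting" option, here is a minimal sketch that builds the expanded KV tensors from a broadcast view instead of `mx.repeat`. The helper name `expand_kv` is illustrative, and whether it actually saves bandwidth depends on how MLX materializes the final reshape:

```python
import mlx.core as mx

def expand_kv(kv: mx.array, gqa_ratio: int) -> mx.array:
    """Broadcast [B, n_kv, L, D] -> [B, n_kv * gqa_ratio, L, D] without mx.repeat.

    Keeps the same head grouping as the repeat-based baseline
    (query heads 5*i .. 5*i+4 share KV head i).
    """
    B, n_kv, L, D = kv.shape
    expanded = mx.broadcast_to(mx.expand_dims(kv, 2), (B, n_kv, gqa_ratio, L, D))
    return mx.reshape(expanded, (B, n_kv * gqa_ratio, L, D))
```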
### **2. Computation Fusion:**
```python
# Current: Separate operations
scores = mx.matmul(queries, keys_t) * scale
weights = mx.softmax(scores)
output = mx.matmul(weights, values)

# Evolution: Fused operations to reduce memory transfers
```
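One low-risk way to experiment with fusion, assuming the evolved block stays in pure MLX ops, is to wrap the attention math in `mx.compile` so MLX can merge the graph into fewer Metal kernel launches. This is a sketch of the idea (causal masking omitted for brevity), not the evolved kernel itself:

```python
import mlx.core as mx

@mx.compile
def fused_attention(queries, keys_expanded, values_expanded, scale):
    # Inside a compiled function, MLX can fuse the scaling, softmax, and
    # surrounding elementwise work into fewer kernel launches.
    # The Python float `scale` is treated as a compile-time constant.
    scores = mx.matmul(queries, mx.swapaxes(keys_expanded, -1, -2)) * scale
    weights = mx.softmax(scores, axis=-1, precise=True)
    return mx.matmul(weights, values_expanded)
```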
### **3. Apple Silicon Optimizations:**
- bfloat16 native operations
- Unified memory bandwidth optimization (see the measurement sketch below)
- Cache-friendly memory access patterns
- SIMD-friendly computation layouts
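The memory side of these optimizations can be observed directly while iterating. A minimal sketch using MLX's Metal memory counters; treat the exact `mx.metal.*` helper names as an assumption to verify against the installed MLX version:

```python
import mlx.core as mx

def peak_memory_mb(fn, *args):
    """Report peak Metal memory (MB) while evaluating fn(*args)."""
    mx.metal.reset_peak_memory()
    mx.eval(fn(*args))  # force the lazy computation
    return mx.metal.get_peak_memory() / 1024 ** 2
```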
## 📊 **Baseline vs Custom Implementation**

From the M4 benchmark runs:
```
Baseline Performance (mx.fast.scaled_dot_product_attention):
- Average decode: 70.3 tokens/sec
- Range: 65.0 - 80.7 tokens/sec
- Memory: 1.24-1.69 GB
- Context degradation: ~7%

Custom Implementation Target:
- Average decode: 80+ tokens/sec (14%+ improvement)
- Better memory efficiency
- Improved context scaling
- Maintained numerical accuracy
```
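The decode-speed figures above come from full generation runs; the attention kernel itself can also be timed in isolation before launching those. A minimal timing sketch on random bfloat16 tensors, with shapes assumed from the 40:8 configuration (these numbers will not match end-to-end tokens/sec):

```python
import time
import mlx.core as mx

B, N_Q, N_KV, L, D = 1, 40, 8, 512, 128
q = mx.random.normal((B, N_Q, L, D)).astype(mx.bfloat16)
k = mx.random.normal((B, N_KV, L, D)).astype(mx.bfloat16)
v = mx.random.normal((B, N_KV, L, D)).astype(mx.bfloat16)

def gqa_baseline(q, k, v, scale=D ** -0.5, gqa_ratio=N_Q // N_KV):
    # Baseline: explicit KV broadcast + standard attention (no mask, timing only)
    k_e = mx.repeat(k, gqa_ratio, axis=1)
    v_e = mx.repeat(v, gqa_ratio, axis=1)
    scores = mx.matmul(q, mx.swapaxes(k_e, -1, -2)) * scale
    return mx.matmul(mx.softmax(scores, axis=-1, precise=True), v_e)

mx.eval(gqa_baseline(q, k, v))  # warm-up (MLX evaluates lazily)
start = time.perf_counter()
for _ in range(20):
    mx.eval(gqa_baseline(q, k, v))
print(f"{(time.perf_counter() - start) / 20 * 1e3:.2f} ms per attention call")
```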
## 🧪 **Evaluation System**

### **Comprehensive Testing:**
1. **Correctness Verification**: Custom implementation produces results identical to the baseline (see the sketch below)
2. **Performance Benchmarking**: Real text generation on 5 key scenarios
3. **Memory Efficiency**: Track memory usage vs. the baseline
4. **Context Scaling**: Test performance across different sequence lengths
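A minimal sketch of the kind of correctness check the evaluator can run. The name `custom_gqa_attention` stands in for the evolved kernel (same [B, heads, L, 128] layout as above); `mx.fast.scaled_dot_product_attention` handles the 40:8 GQA broadcast natively and serves as the reference. Tolerances are loose because of bfloat16:

```python
import mlx.core as mx

def check_correctness(custom_gqa_attention, atol=1e-2, rtol=1e-2) -> bool:
    """Compare a candidate GQA kernel against MLX's fused reference."""
    B, N_Q, N_KV, L, D = 1, 40, 8, 256, 128
    q = mx.random.normal((B, N_Q, L, D)).astype(mx.bfloat16)
    k = mx.random.normal((B, N_KV, L, D)).astype(mx.bfloat16)
    v = mx.random.normal((B, N_KV, L, D)).astype(mx.bfloat16)
    scale = D ** -0.5

    ref = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
    out = custom_gqa_attention(q, k, v, scale)
    return mx.allclose(out.astype(mx.float32), ref.astype(mx.float32),
                       atol=atol, rtol=rtol).item()
```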
### **Success Metrics:**
- **Primary**: Average decode speed improvement (70.3 → 80+ tokens/sec)
- **Secondary**: Memory efficiency, context scaling
- **Critical**: Numerical correctness maintained
## 🚀 **Usage**

### **1. Test Initial Custom Implementation**
```bash
cd /Users/asankhaya/Documents/GitHub/openevolve/examples/mlx_metal_kernel_opt
python initial_program.py  # Test custom GQA implementation
```

### **2. Run Evaluator Test**
```bash
python evaluator.py  # Test evaluation system
```

### **3. Start Evolution**
```bash
cd /Users/asankhaya/Documents/GitHub/openevolve
python main.py --config examples/mlx_metal_kernel_opt/config.yaml
```
## 📈 **Expected Evolution Trajectory**

### **Generations 1-10: Broadcasting Optimizations**
- Chunked GQA computation strategies
- Memory-efficient broadcasting alternatives
- Target: 70.3 → 73-75 tokens/sec

### **Generations 11-20: Computation Fusion**
- Fused matmul + softmax operations
- Optimized causal masking integration
- Target: 75 → 78-82 tokens/sec

### **Generations 21-30: Apple Silicon Specialization**
- bfloat16 optimization
- Unified memory access patterns
- Advanced tensor layout optimization
- Target: 80+ tokens/sec (14%+ improvement)
## 🔍 **Key Advantages of the Custom Implementation**

### **Real Optimization Potential:**
- **Kernel-level optimizations** using MLX primitives
- **GQA-specific strategies** for the 40:8 pattern
- **Apple Silicon specialization** for the M4 architecture
- **Measurable improvements** on real workloads

### **Realistic Scope:**
- Uses MLX's optimized primitives (not raw Metal)
- Maintains compatibility with the mlx-lm ecosystem
- Achievable 14% improvement target
- Working baseline implementation

### **Evolution-Friendly:**
- Single EVOLVE-BLOCK containing the core computation
- Clear optimization opportunities
- Concrete performance targets
- Systematic testing framework
## 💡 **Why This Approach Will Work**

1. **Real baseline**: 70.3 tokens/sec from actual M4 measurements
2. **Custom implementation**: Full control over the GQA computation
3. **MLX primitives**: Optimized building blocks, not raw Metal
4. **Specific target**: Qwen3's exact 40:8 pattern, not generic attention
5. **Proven methodology**: Following AlphaEvolve's kernel optimization approach

This approach should evolve meaningful, measurable improvements for Qwen3-0.6B's specific GQA pattern while maintaining compatibility and correctness.

---

**🎯 Ready for custom kernel evolution!**
Lines changed: 162 additions & 0 deletions
# Qwen3-0.6B Custom GQA Attention Optimization Configuration
# Target: Evolve custom GQA implementation using MLX primitives
# Baseline: 70.3 tokens/sec average decode speed
# Goal: 80+ tokens/sec through custom kernel evolution

max_iterations: 30
checkpoint_interval: 5
log_level: "INFO"

# LLM configuration - proven models for kernel optimization
llm:
  primary_model: "gemini-2.5-flash-preview-05-20"
  primary_model_weight: 0.7
  secondary_model: "gemini-2.5-pro-preview-06-05"
  secondary_model_weight: 0.3
  api_base: "https://generativelanguage.googleapis.com/v1beta/openai/"
  temperature: 0.7
  top_p: 0.9
  max_tokens: 32000
  timeout: 300
# Focused prompt for custom GQA kernel evolution
prompt:
  system_message: |
    You are an expert in optimizing attention kernels using MLX primitives for Apple Silicon.

    # SPECIFIC TARGET: Custom GQA Attention Kernel Evolution
    # CURRENT PERFORMANCE: 70.3 tokens/sec average decode speed
    # GOAL: 80+ tokens/sec (14%+ improvement) through kernel-level optimizations
    # HARDWARE: Apple M4 24GB unified memory

    # ARCHITECTURE DETAILS:
    - Qwen3-0.6B: 40 query heads : 8 key/value heads (5:1 GQA ratio)
    - Head dimension: 128, Hidden size: 5120
    - Sequence lengths: 128-2048 tokens, Precision: bfloat16

    # CURRENT CUSTOM IMPLEMENTATION (Baseline to Evolve):
    ```python
    # Manual GQA broadcasting approach (can be optimized)
    keys_expanded = mx.repeat(keys, self.gqa_ratio, axis=1)      # [B, 40, L, 128]
    values_expanded = mx.repeat(values, self.gqa_ratio, axis=1)  # [B, 40, L, 128]

    # Standard attention computation (room for optimization)
    scores = mx.matmul(queries, keys_expanded.transpose(0, 1, 3, 2)) * self.scale
    attn_weights = mx.softmax(scores, axis=-1, precise=True)
    output = mx.matmul(attn_weights, values_expanded)
    ```
    # KEY OPTIMIZATION OPPORTUNITIES:

    **1. GQA Broadcasting Strategies:**
    Current: `mx.repeat` creates explicit copies of the KV tensors
    Alternatives:
    - Chunked computation: Process the 5 query heads per KV head separately
    - On-demand broadcasting: Avoid materialized copies
    - Strided access patterns: Direct indexing instead of repeat
    - Memory-efficient reshaping: Better tensor layouts

    **2. Computation Fusion:**
    Current: Separate matmul → softmax → matmul operations
    Opportunities:
    - Fused attention kernels using mx.fast primitives
    - Combined operations to reduce memory transfers
    - Optimized scaling and masking integration

    **3. Memory Access Optimization:**
    Apple Silicon unified memory allows specific optimizations:
    - Coalesced memory access for the 40-head query tensor
    - Cache-friendly KV head access patterns
    - Reduced intermediate tensor allocations
    - Better transpose operation ordering

    **4. Apple Silicon Specific Optimizations:**
    - bfloat16 native operations
    - Metal Performance Shaders integration
    - Unified memory bandwidth optimization
    - SIMD-friendly computation patterns

    **5. Sequence Length Scaling:**
    Current performance degrades with longer contexts.
    Opportunities:
    - Better attention computation chunking
    - Optimized causal mask application
    - Memory-efficient large sequence handling

    # EVOLUTION CONSTRAINTS:
    1. ONLY modify code inside the single EVOLVE-BLOCK-START/END section
    2. Use MLX primitives: mx.matmul, mx.softmax, mx.repeat, mx.where, etc.
    3. Maintain numerical correctness (same output as the baseline)
    4. Keep tensor shapes compatible: input [B, 40, L, 128] → output [B, 40, L, 128]
    5. Support causal masking for autoregressive generation
    # SPECIFIC EVOLUTION STRATEGIES TO EXPLORE:

    **Strategy 1: Chunked GQA Computation**
    Instead of broadcasting the KV tensors, process the 5 query heads that share each KV head together
    (this preserves the baseline's head-to-KV-head mapping):
    ```python
    outputs = []
    for i in range(8):  # one iteration per KV head
        q_chunk = queries[:, i * self.gqa_ratio:(i + 1) * self.gqa_ratio]  # [B, 5, L, 128]
        k_chunk = keys[:, i:i + 1]    # [B, 1, L, 128], broadcasts over the 5 query heads
        v_chunk = values[:, i:i + 1]  # [B, 1, L, 128]
        scores = mx.matmul(q_chunk, k_chunk.transpose(0, 1, 3, 2)) * self.scale
        attn_weights = mx.softmax(scores, axis=-1)
        outputs.append(mx.matmul(attn_weights, v_chunk))
    output = mx.concatenate(outputs, axis=1)  # [B, 40, L, 128]
    ```
    **Strategy 2: Optimized Broadcasting**
    Use a broadcast view plus reshape instead of explicit `mx.repeat` copies:
    ```python
    # More memory-efficient broadcasting (same head grouping as mx.repeat)
    B, n_kv, L, D = keys.shape
    keys_expanded = mx.broadcast_to(
        mx.expand_dims(keys, 2), (B, n_kv, self.gqa_ratio, L, D)
    ).reshape(B, n_kv * self.gqa_ratio, L, D)
    ```
    **Strategy 3: Fused Operations**
    Combine multiple operations to reduce memory transfers:
    ```python
    # Fused scaled dot-product attention using mx.fast primitives
    # This might leverage optimized Metal kernels
    ```
    **Strategy 4: Memory Layout Optimization**
    Optimize tensor layouts for Apple Silicon:
    ```python
    # Ensure contiguous memory layouts
    # Optimize transpose operations
    # Reduce intermediate allocations
    ```
    # SUCCESS METRICS (from benchmark suite):
    - Average decode speed: 70.3 → 80+ tokens/sec (14%+ improvement)
    - Memory efficiency: maintain <2GB usage
    - Scaling: reduce performance drop with longer contexts
    - Correctness: identical outputs to baseline implementation

    Focus on CONCRETE kernel optimizations using MLX primitives.
    Test different GQA computation strategies systematically.
    Prioritize memory bandwidth efficiency and computation fusion.
  num_top_programs: 4
  num_diverse_programs: 2

# Database configuration
database:
  db_path: "./openevolve_output/qwen3_custom_gqa"
  population_size: 25
  archive_size: 12
  num_islands: 2
  elite_selection_ratio: 0.25
  exploitation_ratio: 0.7
  exploration_ratio: 0.3

# Evaluator configuration
evaluator:
  timeout: 300  # 5 minutes per evaluation
  parallel_evaluations: 1

# Evolution settings
diff_based_evolution: true
allow_full_rewrites: false
max_code_length: 50000
