Commit ef0fde9

1 parent f8bc941 commit ef0fde9

2 files changed: +202 -201 lines changed

examples/mlx_metal_kernel_opt/config.yaml
Lines changed: 131 additions & 88 deletions

@@ -14,131 +14,174 @@ llm:
   max_tokens: 32000
   timeout: 600
 
-# Focused prompt for custom GQA kernel evolution
+# Focused prompt for genuine MLX Qwen3 optimization
 prompt:
   system_message: |
     You are an expert in optimizing attention kernels using MLX primitives for Apple Silicon.
 
-    # SPECIFIC TARGET: Custom GQA Attention Kernel Evolution
-    # CURRENT PERFORMANCE: 70.3 tokens/sec average decode speed
-    # GOAL: 80+ tokens/sec (14%+ improvement) through kernel-level optimizations
+    # SPECIFIC TARGET: MLX Qwen3 Attention Optimization
+    # BASELINE: Standard MLX-LM implementation using mx.fast.scaled_dot_product_attention
+    # GOAL: 10-20% improvement through genuine kernel-level innovations
     # HARDWARE: Apple M4 24GB unified memory
 
     # ARCHITECTURE DETAILS:
     - Qwen3-0.6B: 40 query heads : 8 key/value heads (5:1 GQA ratio)
     - Head dimension: 128, Hidden size: 5120
     - Sequence lengths: 128-2048 tokens, Precision: bfloat16
 
-    # CURRENT CUSTOM IMPLEMENTATION (Baseline to Evolve):
+    # CURRENT BASELINE (MLX-LM Standard Implementation):
     ```python
-    # Manual GQA broadcasting approach (can be optimized)
-    keys_expanded = mx.repeat(keys, self.gqa_ratio, axis=1)      # [B, 40, L, 128]
-    values_expanded = mx.repeat(values, self.gqa_ratio, axis=1)  # [B, 40, L, 128]
-
-    # Standard attention computation (room for optimization)
-    scores = mx.matmul(queries, keys_expanded.transpose(0, 1, 3, 2)) * self.scale
-    attn_weights = mx.softmax(scores, axis=-1, precise=True)
-    output = mx.matmul(attn_weights, values_expanded)
+    # This is already highly optimized - your starting point
+    from mlx_lm.models.base import scaled_dot_product_attention
+    output = scaled_dot_product_attention(
+        queries, keys, values, cache=cache, scale=self.scale, mask=mask
+    )
+
+    # Which internally uses:
+    # mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=mask)
     ```
 
-    # KEY OPTIMIZATION OPPORTUNITIES:
-
-    **1. GQA Broadcasting Strategies:**
-    Current: `mx.repeat` creates explicit copies of KV tensors
-    Alternatives:
-    - Chunked computation: Process 5 query heads per KV head separately
-    - On-demand broadcasting: Avoid materialized copies
-    - Strided access patterns: Direct indexing instead of repeat
-    - Memory-efficient reshaping: Better tensor layouts
-
-    **2. Computation Fusion:**
-    Current: Separate matmul → softmax → matmul operations
-    Opportunities:
-    - Fused attention kernels using mx.fast primitives
-    - Combined operations to reduce memory transfers
-    - Optimized scaling and masking integration
-
-    **3. Memory Access Optimization:**
-    Apple Silicon unified memory allows specific optimizations:
-    - Coalesced memory access for 40-head query tensor
-    - Cache-friendly KV head access patterns
-    - Reduced intermediate tensor allocations
-    - Better transpose operation ordering
-
-    **4. Apple Silicon Specific Optimizations:**
-    - bfloat16 native operations
-    - Metal Performance Shaders integration
-    - Unified memory bandwidth optimization
-    - SIMD-friendly computation patterns
-
-    **5. Sequence Length Scaling:**
-    Current performance degrades with longer contexts
-    Opportunities:
-    - Better attention computation chunking
-    - Optimized causal mask application
-    - Memory-efficient large sequence handling
+    # GENUINE OPTIMIZATION OPPORTUNITIES:
 
-    # EVOLUTION CONSTRAINTS:
-    1. ONLY modify code inside the single EVOLVE-BLOCK-START/END section
-    2. Use MLX primitives: mx.matmul, mx.softmax, mx.repeat, mx.where, etc.
-    3. Maintain numerical correctness (same output as baseline)
-    4. Keep tensor shapes compatible: input [B,40,L,128] output [B,40,L,128]
-    5. Support causal masking for autoregressive generation
+    **1. Beyond Standard SDPA:**
+    MLX's mx.fast.scaled_dot_product_attention is already optimized, but you can potentially improve by:
+    - Custom implementations that leverage the specific 40:8 GQA pattern
+    - Memory layout optimizations for Apple Silicon unified memory
+    - Novel computation ordering for better cache locality
+    - Specialized handling of sequence length patterns
 
-    # SPECIFIC EVOLUTION STRATEGIES TO EXPLORE:
+    **2. Apple Silicon Specific Optimizations:**
+    - Leverage bfloat16 native operations more effectively
+    - Optimize for unified memory bandwidth patterns
+    - Use SIMD-friendly computation layouts
+    - Minimize memory allocation/deallocation overhead
 
-    **Strategy 1: Chunked GQA Computation**
-    Instead of broadcasting, process query heads in groups:
+    **3. GQA Pattern Optimizations:**
+    Instead of relying on MLX's general GQA handling, create custom implementations:
     ```python
+    # Example: Process in 8-head chunks to match KV heads exactly
+    chunk_size = self.n_kv_heads  # 8
     outputs = []
-    for i in range(self.gqa_ratio):  # 5 iterations
-        q_chunk = queries[:, i*8:(i+1)*8, :, :]  # [B, 8, L, 128]
-        scores = mx.matmul(q_chunk, keys.transpose(0, 1, 3, 2)) * self.scale
-        attn_weights = mx.softmax(scores, axis=-1)
-        output_chunk = mx.matmul(attn_weights, values)
-        outputs.append(output_chunk)
+    for i in range(0, self.n_heads, chunk_size):
+        q_chunk = queries[:, i:i+chunk_size, :, :]  # [B, 8, L, 128]
+        k_chunk = keys[:, i//5, :, :].unsqueeze(1)  # Corresponding KV head
+        v_chunk = values[:, i//5, :, :].unsqueeze(1)
+
+        # Custom attention computation for this chunk
+        chunk_output = custom_attention(q_chunk, k_chunk, v_chunk)
+        outputs.append(chunk_output)
+
     output = mx.concatenate(outputs, axis=1)
     ```
 
-    **Strategy 2: Optimized Broadcasting**
-    Use reshape and tile operations instead of repeat:
+    **4. Memory Access Pattern Optimization:**
+    ```python
+    # Example: Reorder operations for better memory locality
+    # Instead of: Q @ K^T → softmax → @ V
+    # Try: Chunked computation with better cache usage
+
+    # Tile-based computation
+    tile_size = 64  # Optimize for L1 cache
+    for i in range(0, L, tile_size):
+        for j in range(0, L, tile_size):
+            # Process attention in tiles for better memory locality
+    ```
+
+    **5. Operation Fusion Beyond Standard:**
+    ```python
+    # Custom fused operations that MLX might not provide
+    # Combine scaling, masking, and computation in single kernels
+    # Fuse RoPE application with attention computation
+    # Integrate KV cache operations more efficiently
+    ```
+
+    **6. Sequence Length Specific Optimizations:**
+    ```python
+    # Different strategies for different sequence lengths
+    if L <= 512:
+        # Use memory-intensive but fast approach
+        return fast_short_sequence_attention(...)
+    elif L <= 2048:
+        # Balanced approach
+        return balanced_attention(...)
+    else:
+        # Memory-efficient approach for long sequences
+        return memory_efficient_attention(...)
+    ```
+
+    # EVOLUTION CONSTRAINTS:
+    1. ONLY modify code inside the single EVOLVE-BLOCK-START/END section
+    2. Must use MLX primitives: mx.matmul, mx.softmax, mx.fast.*, etc.
+    3. Maintain numerical correctness (same outputs as MLX-LM baseline)
+    4. Keep tensor shapes: input [B,40,L,128] output [B,40,L,128]
+    5. Support causal masking and KV caching
+    6. Must actually improve upon mx.fast.scaled_dot_product_attention
+
+    # WHAT NOT TO DO (these are already optimized in MLX):
+    ❌ Don't use naive manual matrix multiplication
+    ❌ Don't use mx.repeat for GQA broadcasting (inefficient)
+    ❌ Don't reimplement basic softmax or matmul operations
+    ❌ Don't ignore the benefits of fused operations
+
+    # WHAT TO EXPLORE (genuine optimization opportunities):
+    ✅ Custom GQA computation patterns
+    ✅ Apple Silicon specific memory layouts
+    ✅ Novel attention computation ordering
+    ✅ Specialized sequence length handling
+    ✅ Custom fusion beyond standard MLX offerings
+    ✅ Cache-aware computation patterns
+
+    # EVOLUTION STRATEGIES TO TRY:
+
+    **Strategy 1: Chunked GQA Processing**
+    Process query heads in groups that align with KV heads:
+    ```python
+    # Process 8 query heads per KV head for perfect alignment
+    n_chunks = self.n_heads // self.n_kv_heads  # 5 chunks of 8 heads each
+    for chunk_idx in range(n_chunks):
+        q_start = chunk_idx * self.n_kv_heads
+        q_end = q_start + self.n_kv_heads
+        # Process this 8-head chunk with corresponding KV head
+    ```
+
+    **Strategy 2: Memory Layout Optimization**
+    Reorder computations for better cache locality:
     ```python
-    # More memory-efficient broadcasting
-    keys_reshaped = keys[:, :, None, :, :].repeat(self.gqa_ratio, axis=2)
-    keys_expanded = keys_reshaped.reshape(B, -1, L, 128)
+    # Ensure contiguous memory access patterns
+    # Optimize tensor layouts for Apple Silicon
+    # Minimize intermediate tensor allocations
     ```
 
-    **Strategy 3: Fused Operations**
-    Combine multiple operations to reduce memory transfers:
+    **Strategy 3: Adaptive Computation**
+    Use different strategies based on input characteristics:
     ```python
-    # Fused scaled dot-product attention using mx.fast primitives
-    # This might leverage optimized Metal kernels
+    # Adapt based on sequence length, batch size, etc.
+    # Use most efficient approach for each case
     ```
 
-    **Strategy 4: Memory Layout Optimization**
-    Optimize tensor layouts for Apple Silicon:
+    **Strategy 4: Custom Fused Operations**
+    Create custom fusion that goes beyond standard SDPA:
     ```python
-    # Ensure contiguous memory layouts
-    # Optimize transpose operations
-    # Reduce intermediate allocations
+    # Combine operations that MLX doesn't fuse automatically
+    # Integrate masking, scaling, and computation more efficiently
     ```
 
-    # SUCCESS METRICS (from benchmark suite):
-    - Average decode speed: 70.3 → 80+ tokens/sec (14%+ improvement)
-    - Memory efficiency: maintain <2GB usage
-    - Scaling: reduce performance drop with longer contexts
-    - Correctness: identical outputs to baseline implementation
+    # SUCCESS METRICS:
+    - Improvement over MLX-LM baseline: 10-20% decode speed increase
+    - Memory efficiency: similar or better than baseline
+    - Correctness: identical outputs to MLX-LM implementation
+    - Scalability: good performance across different sequence lengths
 
-    Focus on CONCRETE kernel optimizations using MLX primitives.
-    Test different GQA computation strategies systematically.
-    Prioritize memory bandwidth efficiency and computation fusion.
+    Focus on GENUINE improvements over the already-optimized MLX-LM baseline.
+    Your goal is to find optimizations that even the MLX developers haven't implemented.
+    This is challenging but represents real innovation opportunities.
 
   num_top_programs: 4
   num_diverse_programs: 2
 
 # Database configuration
 database:
-  db_path: "./openevolve_output/qwen3_custom_gqa"
+  db_path: "./openevolve_output/qwen3_mlx_optimization"
   population_size: 50
   archive_size: 20
   num_islands: 4
num_islands: 4
@@ -154,4 +197,4 @@ evaluator:
 # Evolution settings
 diff_based_evolution: true
 allow_full_rewrites: false
-max_code_length: 50000
+max_code_length: 50000

Comments (0)