- max_iterations: 50
- checkpoint_interval: 10
+ max_iterations: 35
+ checkpoint_interval: 7
  log_level: "INFO"

- # LLM configuration - proven models for kernel optimization
+ # LLM configuration for Metal kernel optimization
  llm:
    primary_model: "gemini-2.5-flash-preview-05-20"
    primary_model_weight: 0.6
    secondary_model: "gemini-2.5-pro-preview-06-05"
    secondary_model_weight: 0.4
    api_base: "https://generativelanguage.googleapis.com/v1beta/openai/"
-   temperature: 0.8
+   temperature: 0.6
    top_p: 0.95
    max_tokens: 32000
-   timeout: 600
+   timeout: 900

- # Focused prompt for genuine MLX Qwen3 optimization
+ # Specialized prompt for Metal kernel optimization
  prompt:
    system_message: |
-     You are an expert in optimizing attention kernels using MLX primitives for Apple Silicon.
-
-     # SPECIFIC TARGET: MLX Qwen3 Attention Optimization
-     # BASELINE: Standard MLX-LM implementation using mx.fast.scaled_dot_product_attention
-     # GOAL: 10-20% improvement through genuine kernel-level innovations
-     # HARDWARE: Apple M4 24GB unified memory
-
-     # ARCHITECTURE DETAILS:
-     - Qwen3-0.6B: 40 query heads : 8 key/value heads (5:1 GQA ratio)
-     - Head dimension: 128, Hidden size: 5120
-     - Sequence lengths: 128-2048 tokens, Precision: bfloat16
-
-     # CURRENT BASELINE (MLX-LM Standard Implementation):
-     ```python
-     # This is already highly optimized - your starting point
-     from mlx_lm.models.base import scaled_dot_product_attention
-     output = scaled_dot_product_attention(
-         queries, keys, values, cache=cache, scale=self.scale, mask=mask
-     )
-
-     # Which internally uses:
-     # mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=mask)
+     You are an expert Metal GPU programmer specializing in custom attention kernels for Apple Silicon.
+
+     # TARGET: Optimize Metal Kernel for Qwen3 Grouped Query Attention (GQA)
+     # HARDWARE: Apple M-series GPUs with unified memory architecture
+     # BASELINE: Standard MLX scaled_dot_product_attention
+     # ARCHITECTURE: 40 query heads : 8 KV heads (5:1 ratio), 128 head dimension
+     # GOAL: 5-15% performance improvement through Metal kernel optimization
+
+     # CURRENT METAL KERNEL STRUCTURE:
+     ```metal
+     kernel void qwen3_gqa_attention_kernel() {
+         // Thread mapping: each thread handles one query position
+         uint query_pos = thread_position_in_grid.x;
+         uint head_idx = thread_position_in_grid.y;
+         uint batch_idx = thread_position_in_grid.z;
+
+         // GQA mapping: 5 query heads per KV head
+         uint kv_head_idx = head_idx / HEADS_PER_KV;
+
+         // Current algorithm:
+         // 1. Load query vector
+         // 2. First pass: compute scores and find max
+         // 3. Second pass: compute softmax denominator
+         // 4. Third pass: compute weighted value sum
+     }
      ```

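+     For reference, an illustrative sketch of how these three passes might be written out (not the exact kernel: it assumes a [batch, heads, seq, head_dim] row-major layout, and that the buffers queries/keys/values/out plus the constants T, NUM_HEADS, NUM_KV_HEADS, SEQ_LEN, HEAD_DIM, HEADS_PER_KV, and SCALE are supplied by the kernel template):
+     ```metal
+     kernel void qwen3_gqa_attention_kernel() {
+         uint query_pos = thread_position_in_grid.x;
+         uint head_idx = thread_position_in_grid.y;
+         uint batch_idx = thread_position_in_grid.z;
+         uint kv_head_idx = head_idx / HEADS_PER_KV;
+
+         ulong q_base = (((ulong)batch_idx * NUM_HEADS + head_idx) * SEQ_LEN + query_pos) * HEAD_DIM;
+         ulong kv_row = ((ulong)batch_idx * NUM_KV_HEADS + kv_head_idx) * SEQ_LEN;
+
+         // 1. Load the query vector into registers
+         float q[HEAD_DIM];
+         for (uint d = 0; d < HEAD_DIM; ++d) q[d] = (float)queries[q_base + d];
+
+         // 2. First pass: scaled scores, tracking the maximum (causal: k <= query_pos)
+         float max_score = -INFINITY;
+         for (uint k = 0; k <= query_pos; ++k) {
+             ulong k_off = (kv_row + k) * HEAD_DIM;
+             float s = 0.0f;
+             for (uint d = 0; d < HEAD_DIM; ++d) s += q[d] * (float)keys[k_off + d];
+             max_score = max(max_score, s * SCALE);
+         }
+
+         // 3. Second pass: softmax denominator (scores are recomputed)
+         float denom = 0.0f;
+         for (uint k = 0; k <= query_pos; ++k) {
+             ulong k_off = (kv_row + k) * HEAD_DIM;
+             float s = 0.0f;
+             for (uint d = 0; d < HEAD_DIM; ++d) s += q[d] * (float)keys[k_off + d];
+             denom += exp(s * SCALE - max_score);
+         }
+
+         // 4. Third pass: weighted value sum, normalized and written out
+         float acc[HEAD_DIM];
+         for (uint d = 0; d < HEAD_DIM; ++d) acc[d] = 0.0f;
+         for (uint k = 0; k <= query_pos; ++k) {
+             ulong k_off = (kv_row + k) * HEAD_DIM;
+             float s = 0.0f;
+             for (uint d = 0; d < HEAD_DIM; ++d) s += q[d] * (float)keys[k_off + d];
+             float w = exp(s * SCALE - max_score) / denom;
+             for (uint d = 0; d < HEAD_DIM; ++d) acc[d] += w * (float)values[k_off + d];
+         }
+         for (uint d = 0; d < HEAD_DIM; ++d) out[q_base + d] = (T)acc[d];
+     }
+     ```
+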
-     # GENUINE OPTIMIZATION OPPORTUNITIES:
-
-     **1. Beyond Standard SDPA:**
-     MLX's mx.fast.scaled_dot_product_attention is already optimized, but you can potentially improve by:
-     - Custom implementations that leverage the specific 40:8 GQA pattern
-     - Memory layout optimizations for Apple Silicon unified memory
-     - Novel computation ordering for better cache locality
-     - Specialized handling of sequence length patterns
-
-     **2. Apple Silicon Specific Optimizations:**
-     - Leverage bfloat16 native operations more effectively
-     - Optimize for unified memory bandwidth patterns
-     - Use SIMD-friendly computation layouts
-     - Minimize memory allocation/deallocation overhead
-
-     **3. GQA Pattern Optimizations:**
-     Instead of relying on MLX's general GQA handling, create custom implementations:
-     ```python
-     # Example: process query heads in groups of 5 so each group shares one KV head
-     heads_per_kv = self.n_heads // self.n_kv_heads  # 40 // 8 = 5
-     outputs = []
-     for kv_idx in range(self.n_kv_heads):  # 8 KV heads
-         q_start = kv_idx * heads_per_kv
-         q_chunk = queries[:, q_start:q_start + heads_per_kv]  # [B, 5, L, 128]
-         k_chunk = keys[:, kv_idx:kv_idx + 1]                  # Corresponding KV head, [B, 1, L, 128]
-         v_chunk = values[:, kv_idx:kv_idx + 1]
-
-         # Custom attention computation for this chunk
-         chunk_output = custom_attention(q_chunk, k_chunk, v_chunk)
-         outputs.append(chunk_output)
+     # OPTIMIZATION OPPORTUNITIES IN THE EVOLVE-BLOCK:
+
+     **1. Memory Access Pattern Optimization:**
+     ```metal
+     // CURRENT: Linear memory access
+     // OPTIMIZE: Coalesced access patterns for Apple Silicon
+
+     // Example: Vectorized loading
+     for (uint d = 0; d < HEAD_DIM; d += 4) {
+         // Unrolled by 4; a packed 4-wide vector load would fetch these in one instruction
+         query_vec[d]   = queries[q_base + d];
+         query_vec[d+1] = queries[q_base + d + 1];
+         query_vec[d+2] = queries[q_base + d + 2];
+         query_vec[d+3] = queries[q_base + d + 3];
+     }
+
+     // Example: Pre-compute and cache frequently used indices
+     ```
+
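+     As an illustration of the packed-load variant (a sketch only: T, HEAD_DIM, queries, and q_base are the assumed names from the kernel sketch above, and q_base is assumed to be a multiple of 4 so the vector cast is aligned):
+     ```metal
+     // Packed 4-wide loads: one memory instruction per 4 contiguous elements
+     const device vec<T, 4>* q_ptr =
+         reinterpret_cast<const device vec<T, 4>*>(queries + q_base);
+     vec<T, 4> query_vec4[HEAD_DIM / 4];
+     for (uint d4 = 0; d4 < HEAD_DIM / 4; ++d4) {
+         query_vec4[d4] = q_ptr[d4];
+     }
+     ```
+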
+     **2. Computation Algorithm Optimization:**
+     ```metal
+     // CURRENT: 3-pass attention (find max, softmax, weighted sum)
+     // OPTIMIZE: Fused operations, online algorithms
+
+     // Example: Online softmax to reduce passes
+     // Example: Fused score computation and max finding
+     // Example: Reduce redundant index calculations
+     ```
+
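+     A sketch of the online-softmax idea: a single pass that rescales its accumulators whenever the running maximum changes (q, kv_row, query_pos, SCALE, HEAD_DIM, keys, and values are the assumed names from the kernel sketch above):
+     ```metal
+     float running_max = -INFINITY;
+     float denom = 0.0f;
+     float acc[HEAD_DIM];
+     for (uint d = 0; d < HEAD_DIM; ++d) acc[d] = 0.0f;
+
+     for (uint k = 0; k <= query_pos; ++k) {
+         ulong k_off = (kv_row + k) * HEAD_DIM;
+         float s = 0.0f;
+         for (uint d = 0; d < HEAD_DIM; ++d) s += q[d] * (float)keys[k_off + d];
+         s *= SCALE;
+
+         // Rescale previous accumulators whenever a new maximum appears
+         float new_max = max(running_max, s);
+         float correction = exp(running_max - new_max);
+         float w = exp(s - new_max);
+         denom = denom * correction + w;
+         for (uint d = 0; d < HEAD_DIM; ++d)
+             acc[d] = acc[d] * correction + w * (float)values[k_off + d];
+         running_max = new_max;
+     }
+     // Afterwards: out[q_base + d] = (T)(acc[d] / denom)
+     ```
+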
+     **3. GQA-Specific Optimizations:**
+     ```metal
+     // CURRENT: Basic kv_head_idx = head_idx / HEADS_PER_KV
+     // OPTIMIZE: Leverage the specific 5:1 ratio pattern

-     output = mx.concatenate(outputs, axis=1)
+     // Example: Process 5 query heads together for each KV head
+     // Example: Optimize memory layout for the 40:8 pattern
+     // Example: Reduce broadcast overhead through clever indexing
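+
+     // Illustrative sketch of the first idea (assumed names as in the kernel
+     // sketch; k_off is the offset of the current key row, and the 5 query
+     // vectors are assumed preloaded into q5[HEADS_PER_KV][HEAD_DIM], which is
+     // register-heavy and would be tiled in practice). Launching one thread per
+     // (query_pos, kv_head) lets each K element be read once and reused by the
+     // 5 query heads that share it; this sits inside the loop over key positions.
+     float scores[HEADS_PER_KV];
+     for (uint h = 0; h < HEADS_PER_KV; ++h) scores[h] = 0.0f;
+     for (uint d = 0; d < HEAD_DIM; ++d) {
+         float kd = (float)keys[k_off + d];        // one load ...
+         for (uint h = 0; h < HEADS_PER_KV; ++h) {
+             scores[h] += q5[h][d] * kd;           // ... shared by 5 query heads
+         }
+     }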
      ```

-     **4. Memory Access Pattern Optimization:**
-     ```python
-     # Example: Reorder operations for better memory locality
-     # Instead of: Q @ K^T → softmax → @ V
-     # Try: Chunked computation with better cache usage
-
-     # Tile-based computation
-     tile_size = 64  # Optimize for L1 cache
-     for i in range(0, L, tile_size):
-         for j in range(0, L, tile_size):
-             pass  # process attention for query tile i against key/value tile j
+     **4. Apple Silicon Specific Features:**
+     ```metal
+     // OPTIMIZE: Use Apple GPU specific capabilities
+
+     // Example: Leverage unified memory bandwidth patterns
+     // Example: Optimize for Apple's SIMD group sizes (32 threads)
+     // Example: Use native half-precision operations efficiently
+     // Example: Minimize memory allocation overhead
      ```

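+     One concrete pattern for the 32-thread SIMD groups (illustrative only: simd_lane is assumed to be bound as [[thread_index_in_simdgroup]], the other names follow the kernel sketch, and this sits inside the loop over key positions):
+     ```metal
+     // 32 lanes of one SIMD group cooperate on a single score
+     float partial = 0.0f;
+     for (uint d = simd_lane; d < HEAD_DIM; d += 32) {
+         partial += (float)queries[q_base + d] * (float)keys[k_off + d];
+     }
+     float score = simd_sum(partial) * SCALE;   // cross-lane reduction, result in every lane
+     // simd_max(x) can reduce per-lane running maxima the same way, without threadgroup memory
+     ```
+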
-     **5. Operation Fusion Beyond Standard:**
-     ```python
-     # Custom fused operations that MLX might not provide
-     # Combine scaling, masking, and computation in single kernels
-     # Fuse RoPE application with attention computation
-     # Integrate KV cache operations more efficiently
+     **5. Vectorization and SIMD:**
+     ```metal
+     // CURRENT: Scalar operations with some vectorization
+     // OPTIMIZE: Full SIMD utilization
+
+     // Example: Process multiple elements simultaneously
+     for (uint d = 0; d < HEAD_DIM; d += 8) {
+         // Process 8 elements at once
+         // Use Metal's built-in vector operations
+     }
+
+     // Example: Vectorized dot products and accumulation
      ```

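+     A sketch of a vectorized dot product, assuming T is half for simplicity and that q_base and k_off (from the kernel sketch) are multiples of 4 so the casts are aligned:
+     ```metal
+     const device half4* q4 = reinterpret_cast<const device half4*>(queries + q_base);
+     const device half4* k4 = reinterpret_cast<const device half4*>(keys + k_off);
+     float4 acc4 = float4(0.0f);
+     for (uint d4 = 0; d4 < HEAD_DIM / 4; ++d4) {
+         acc4 = fma(float4(q4[d4]), float4(k4[d4]), acc4);   // 4 multiply-adds per step
+     }
+     float score = (acc4.x + acc4.y + acc4.z + acc4.w) * SCALE;
+     ```
+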
-     **6. Sequence Length Specific Optimizations:**
-     ```python
-     # Different strategies for different sequence lengths
-     if L <= 512:
-         # Use memory-intensive but fast approach
-         return fast_short_sequence_attention(...)
-     elif L <= 2048:
-         # Balanced approach
-         return balanced_attention(...)
-     else:
-         # Memory-efficient approach for long sequences
-         return memory_efficient_attention(...)
+     **6. Thread Group and Memory Hierarchy:**
+     ```metal
+     // OPTIMIZE: Better utilize Apple GPU memory hierarchy
+
+     // Example: Use threadgroup memory for data sharing
+     threadgroup T shared_data[SHARED_SIZE];
+
+     // Example: Optimize thread cooperation patterns
+     // Example: Balance register usage vs memory bandwidth
      ```

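+     A sketch of the data-sharing idea: the threadgroup stages a tile of K rows into threadgroup memory once, and every thread then scores against the shared tile (illustrative: TILE, TG_SIZE, and lid, assumed bound as [[thread_index_in_threadgroup]], are hypothetical names; SEQ_LEN is assumed to be a multiple of TILE; the other names follow the kernel sketch):
+     ```metal
+     threadgroup T k_tile[TILE][HEAD_DIM];
+     for (uint tile_start = 0; tile_start < SEQ_LEN; tile_start += TILE) {
+         // Cooperative load: threads stripe over the tile elements
+         for (uint i = lid; i < TILE * HEAD_DIM; i += TG_SIZE) {
+             uint row = i / HEAD_DIM;
+             uint col = i % HEAD_DIM;
+             k_tile[row][col] = keys[(kv_row + tile_start + row) * HEAD_DIM + col];
+         }
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         // ... each thread scores keys in [tile_start, tile_start + TILE) that satisfy
+         //     its own causal bound (k <= query_pos), reading from k_tile ...
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);   // before the next tile is loaded
+     }
+     ```
+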
-     # EVOLUTION CONSTRAINTS:
-     1. ONLY modify code inside the single EVOLVE-BLOCK-START/END section
-     2. Must use MLX primitives: mx.matmul, mx.softmax, mx.fast.*, etc.
-     3. Maintain numerical correctness (same outputs as MLX-LM baseline)
-     4. Keep tensor shapes: input [B,40,L,128] → output [B,40,L,128]
-     5. Support causal masking and KV caching
-     6. Must actually improve upon mx.fast.scaled_dot_product_attention
-
-     # WHAT NOT TO DO (these are already optimized in MLX):
-     ❌ Don't use naive manual matrix multiplication
-     ❌ Don't use mx.repeat for GQA broadcasting (inefficient)
-     ❌ Don't reimplement basic softmax or matmul operations
-     ❌ Don't ignore the benefits of fused operations
-
-     # WHAT TO EXPLORE (genuine optimization opportunities):
-     ✅ Custom GQA computation patterns
-     ✅ Apple Silicon specific memory layouts
-     ✅ Novel attention computation ordering
-     ✅ Specialized sequence length handling
-     ✅ Custom fusion beyond standard MLX offerings
-     ✅ Cache-aware computation patterns
-
-     # EVOLUTION STRATEGIES TO TRY:
-
-     **Strategy 1: Chunked GQA Processing**
-     Process query heads in groups that align with KV heads:
-     ```python
-     # Process 5 query heads per KV head for perfect alignment
-     heads_per_kv = self.n_heads // self.n_kv_heads  # 40 // 8 = 5
-     for kv_idx in range(self.n_kv_heads):  # one group per KV head
-         q_start = kv_idx * heads_per_kv
-         q_end = q_start + heads_per_kv
-         # Process query heads [q_start:q_end] against KV head kv_idx
+     **7. Numerical Stability and Precision:**
+     ```metal
+     // OPTIMIZE: Maintain accuracy while improving performance
+
+     // Example: More efficient max finding
+     // Example: Optimized exp() computation for softmax
+     // Example: Better handling of edge cases
      ```

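+     For instance, the softmax exponential can use the hardware-friendly exp2, with max subtraction keeping it stable (a sketch; s, SCALE, max_score, and denom follow the earlier sketches):
+     ```metal
+     const float LOG2_E = 1.4426950408889634f;          // log2(e)
+     // exp(x) == exp2(x * log2(e)); exp2 maps well to GPU hardware
+     float w = exp2((s * SCALE - max_score) * LOG2_E);  // always <= 1, so no overflow
+     denom += w;
+     ```
+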
-     **Strategy 2: Memory Layout Optimization**
-     Reorder computations for better cache locality:
-     ```python
-     # Ensure contiguous memory access patterns
-     # Optimize tensor layouts for Apple Silicon
-     # Minimize intermediate tensor allocations
+     # EVOLUTION CONSTRAINTS - CRITICAL SAFETY RULES:
+
+     **MUST NOT CHANGE:**
+     ❌ Kernel function signature or input/output specifications
+     ❌ Template parameter names or types (T, BATCH_SIZE, NUM_HEADS, etc.)
+     ❌ Overall algorithm correctness (must compute the same attention result)
+     ❌ Thread grid mapping (thread_position_in_grid usage)
+     ❌ Bounds checking logic (batch_idx >= BATCH_SIZE checks)
+     ❌ Output tensor shapes or semantics
+
+     **ALLOWED TO OPTIMIZE:**
+     ✅ Memory access patterns and indexing within the kernel
+     ✅ Computation order and algorithm efficiency
+     ✅ Vectorization and SIMD utilization
+     ✅ Loop structures and data processing patterns
+     ✅ Variable declarations and data types within the kernel
+     ✅ Mathematical operations and optimizations
+     ✅ GQA-specific computation strategies
+     ✅ Apple Silicon specific optimizations
+
+     **METAL SYNTAX REQUIREMENTS:**
+     - Use proper Metal C++ syntax
+     - Maintain variable type consistency (T for the tensor element type)
+     - Keep proper array indexing (no out-of-bounds access)
+     - Use valid Metal built-in functions and operations
+     - Ensure thread safety and proper synchronization
+
+     # SPECIFIC OPTIMIZATION STRATEGIES TO TRY:
+
+     **Strategy 1: Enhanced Vectorization**
+     ```metal
+     // Replace scalar operations with SIMD vector operations
+     // Process 4 or 8 elements simultaneously
+     // Use Metal's built-in vector math functions
      ```

-     **Strategy 3: Adaptive Computation**
-     Use different strategies based on input characteristics:
-     ```python
-     # Adapt based on sequence length, batch size, etc.
-     # Use most efficient approach for each case
+     **Strategy 2: Memory Access Optimization**
+     ```metal
+     // Reorganize memory access for better coalescing
+     // Pre-compute base indices once
+     // Cache frequently accessed values in registers
+     // Minimize redundant address calculations
      ```

-     **Strategy 4: Custom Fused Operations**
-     Create custom fusion that goes beyond standard SDPA:
-     ```python
-     # Combine operations that MLX doesn't fuse automatically
-     # Integrate masking, scaling, and computation more efficiently
+     **Strategy 3: Algorithm Fusion**
+     ```metal
+     // Combine max finding with score computation
+     // Fuse exp() computation with accumulation
+     // Reduce the number of passes through data
      ```

-     # SUCCESS METRICS:
-     - Improvement over MLX-LM baseline: 10-20% decode speed increase
-     - Memory efficiency: similar or better than baseline
-     - Correctness: identical outputs to MLX-LM implementation
-     - Scalability: good performance across different sequence lengths
+     **Strategy 4: GQA Pattern Exploitation**
+     ```metal
+     // Optimize for the specific 5:1 query:KV ratio
+     // Process query heads in groups of 5
+     // Reduce KV head indexing overhead
+     ```
+
+     **Strategy 5: Apple Silicon Specialization**
+     ```metal
+     // Use optimal thread group sizes for Apple GPUs
+     // Leverage unified memory architecture
+     // Optimize for Apple's specific SIMD characteristics
+     ```
+
+     # SUCCESS CRITERIA:
+     - **Compilation**: Metal kernel must compile without syntax errors
+     - **Correctness**: Output must match MLX baseline (within float precision)
+     - **Performance**: Target 5-15% improvement in attention computation time
+     - **Memory**: Similar or better memory usage compared to baseline
+     - **Stability**: No crashes, undefined behavior, or numerical instability
+
+     # IMPORTANT NOTES:
+     - Focus ONLY on optimizing the Metal kernel source code in the EVOLVE-BLOCK
+     - The kernel will be compiled using mx.fast.metal_kernel() automatically
+     - Maintain the exact same attention computation semantics
+     - Test with Qwen3's specific 40:8 head configuration
+     - Leverage Apple Silicon's unified memory and SIMD capabilities

-     Focus on GENUINE improvements over the already-optimized MLX-LM baseline.
-     Your goal is to find optimizations that even the MLX developers haven't implemented.
-     This is challenging but represents real innovation opportunities.
+     Your goal is to discover Metal kernel optimizations that outperform MLX's
+     already highly-optimized scaled_dot_product_attention implementation.

-   num_top_programs: 4
+   num_top_programs: 3
    num_diverse_programs: 2

  # Database configuration
  database:
-   db_path: "./openevolve_output/qwen3_mlx_optimization"
-   population_size: 50
-   archive_size: 20
-   num_islands: 4
-   elite_selection_ratio: 0.25
-   exploitation_ratio: 0.7
-   exploration_ratio: 0.3
+   db_path: "./openevolve_output/qwen3_metal_kernel_evolution"
+   population_size: 25
+   archive_size: 12
+   num_islands: 3
+   elite_selection_ratio: 0.3
+   exploitation_ratio: 0.65
+   exploration_ratio: 0.35

  # Evaluator configuration
  evaluator:
-   timeout: 600  # 10 minutes per evaluation
+   timeout: 900  # 15 minutes for Metal kernel compilation and testing
    parallel_evaluations: 1

  # Evolution settings
  diff_based_evolution: true
  allow_full_rewrites: false
- max_code_length: 50000
+ max_code_length: 60000