@@ -17,170 +17,196 @@ llm:
1717 max_tokens: 24000
1818 timeout: 600
1919
20- # Focused prompt for Metal kernel evolution
20+ # Focused prompt for CPU-based block-diagonal attention optimization
2121prompt:
2222 system_message: |
23- 🎯 **MISSION: Evolve High-Performance Metal Kernel for Block-Diagonal Attention**
23+ 🎯 **MISSION: Evolve High-Performance Block-Diagonal Attention for Packed Sequences**
2424
25- You are evolving a custom Metal GPU kernel for block-diagonal attention with packed sequences.
26- This is a focused, well-defined optimization problem with clear success metrics .
25+ You are optimizing attention computation for packed sequences (multiple sequences concatenated
26+ to avoid padding waste) where attention should only occur within sequence boundaries.
2727
2828 ## **THE PROBLEM**
2929
30- **Current Issue**: Training BERTs/GPTs with packed sequences (multiple sequences concatenated to avoid padding waste) requires block-diagonal attention where :
30+ **Current Issue**: Training models with packed sequences requires block-diagonal attention:
3131 - Keys/queries from the same sequence can attend to each other
32- - Keys/queries from different sequences should NOT attend to each other
33- - Naive masking wastes computation on large -inf regions
32+ - Keys/queries from different sequences should NOT attend to each other
33+ - Naive masking wastes computation on large masked regions
3434
35- **Goal**: Evolve a Metal kernel that efficiently computes block-diagonal attention by:
36- - Skipping computation for cross-sequence attention entirely
37- - Optimizing memory access patterns for Apple Silicon
38- - Achieving 1.5-2x+ speedup over naive masked attention
35+ **Goal**: Evolve efficient attention that beats naive masking by:
36+ - Smart block detection and processing
37+ - Optimized CPU operations with MLX
38+ - Memory-efficient computation patterns
39+ - Achieving 1.2-2x+ speedup over naive masked attention
3940
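For reference, here is a minimal sketch of the naive masked baseline these speedups are measured against (a boolean mask with True marking allowed positions is assumed; this is an illustration, not the evaluated implementation):

```python
import mlx.core as mx

def naive_masked_attention(q, k, v, mask, scale):
    # q, k, v: (B, H, L, D); mask: (L, L) boolean, True = may attend.
    # Scores are computed for every pair, including the cross-sequence pairs
    # the mask immediately throws away -- that discarded work is what the
    # block-diagonal approach avoids.
    scores = (q @ mx.swapaxes(k, -1, -2)) * scale
    scores = mx.where(mask, scores, mx.array(-1e9, dtype=scores.dtype))
    return mx.softmax(scores, axis=-1) @ v
```
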
4041 ## **EVOLUTION TARGET**
4142
4243 **Single Evolution Block**: The entire `evolved_scaled_dot_product_attention` function
4344
4445 **Focus Areas** (in order of priority):
4546
46- ### 1. **Metal Kernel Source Code** (HIGHEST PRIORITY)
47- ```cpp
48- // Current kernel in create_block_diagonal_kernel_source()
49- // EVOLUTION OPPORTUNITIES:
50- // - Optimize thread allocation per block
51- // - Use threadgroup/shared memory efficiently
52- // - Implement vectorized operations (float4, half4)
53- // - Add tiled computation for large blocks
54- // - Optimize memory access patterns
55- // - Skip unnecessary computations entirely
56- ```
57-
58- ### 2. **Block Detection Logic**
47+ ### 1. **Block Detection & Processing** (HIGHEST PRIORITY)
5948 ```python
6049 # In detect_packed_sequences() and analyze_mask_structure()
6150 # EVOLUTION OPPORTUNITIES:
62- // - Better detection of block-diagonal patterns
63- // - Handle variable-length sequences efficiently
64- // - Optimize for common packing strategies
65- // - Auto-detect sequence boundaries from attention patterns
51+ # - Better detection of block-diagonal patterns from masks
52+ # - Handle variable-length sequences efficiently
53+ # - Optimize for common packing strategies (uniform/variable)
54+ # - Cache block structure analysis for repeated use
55+ ```
56+
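To make the target concrete, here is a hedged sketch of boundary recovery from a boolean block-diagonal mask (the function name is illustrative, not the actual helper):

```python
import mlx.core as mx

def find_block_boundaries(mask):
    # Return [(start, end), ...] block spans of a boolean block-diagonal mask,
    # where True means the position may attend. A new block starts at row i
    # whenever position i cannot attend to position i - 1. Sketch only.
    L = mask.shape[0]
    rows = mask.tolist()  # one-time analysis; cheap next to attention itself
    starts = [0] + [i for i in range(1, L) if not rows[i][i - 1]]
    return list(zip(starts, starts[1:] + [L]))

# Tiny check: sequences of length 3 and 2 packed together.
seg = mx.concatenate([mx.full((3,), 0), mx.full((2,), 1)])
mask = seg[:, None] == seg[None, :]
assert find_block_boundaries(mask) == [(0, 3), (3, 5)]
```
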
57+ ### 2. **Optimized Block-Diagonal CPU Computation**
58+ ```python
59+ # In optimized_block_diagonal_cpu()
60+ # EVOLUTION OPPORTUNITIES:
61+ # - More efficient block iteration and memory access
62+ # - Vectorized MLX operations within blocks
63+ # - Minimize memory allocations and copies
64+ # - Fused attention computation within blocks
65+ # - Parallel processing of independent blocks
6666 ```
6767
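A minimal sketch of the per-block computation this section is after, assuming boundaries have already been detected (names are illustrative):

```python
import mlx.core as mx

def block_diagonal_attention(q, k, v, boundaries, scale):
    # q, k, v: (B, H, L, D); boundaries: [(start, end), ...] covering [0, L).
    # Each block attends only within itself, so no mask is needed per block
    # and cross-sequence score entries are never materialized at all.
    outputs = []
    for start, end in boundaries:
        q_b = q[:, :, start:end, :]
        k_b = k[:, :, start:end, :]
        v_b = v[:, :, start:end, :]
        scores = (q_b @ mx.swapaxes(k_b, -1, -2)) * scale
        outputs.append(mx.softmax(scores, axis=-1) @ v_b)
    return mx.concatenate(outputs, axis=2)  # reassemble along the sequence axis
```
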
68- ### 3. **Kernel Launch Parameters **
68+ ### 3. **Smart Fallback Logic**
6969 ```python
70- # In try_custom_metal_kernel()
70+ # In main function logic
7171 # EVOLUTION OPPORTUNITIES:
72- // - Optimize thread group sizes
73- // - Better template parameter handling
74- // - Efficient memory allocation strategies
75- // - Multiple kernel variants for different scenarios
72+ # - Better heuristics for when to use block-diagonal vs regular attention
73+ # - Adaptive algorithm selection based on sequence patterns
74+ # - Efficient mask analysis and caching
7675 ```
7776
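One hedged shape such a heuristic could take: estimate how much of the score matrix the blocks actually cover and only take the block path when skipping the rest is likely to pay off (the 0.5 threshold is an assumption to tune, not a measured constant):

```python
def should_use_block_path(mask, boundaries):
    # Fraction of the L x L score matrix that lies inside the detected blocks.
    L = mask.shape[0]
    kept = sum((end - start) ** 2 for start, end in boundaries)
    density = kept / float(L * L)
    # A single block (no packing) or a nearly dense mask gains nothing from
    # block processing, so fall back to regular masked attention there.
    return len(boundaries) > 1 and density < 0.5
```
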
78- ### 4. **CPU Fallback Optimization**
77+ ### 4. **MLX Operation Optimization**
7978 ```python
80- # In optimized_block_diagonal_cpu()
79+ # Throughout the function
8180 # EVOLUTION OPPORTUNITIES:
82- // - More efficient block processing
83- // - Vectorized CPU operations
84- // - Memory-efficient block iteration
81+ # - Use more efficient MLX operations (avoid numpy conversions)
82+ # - Better memory layout and access patterns
83+ # - Minimize intermediate tensor allocations
84+ # - Leverage MLX's optimized attention primitives where possible
85+ ```
86+
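Where the installed MLX version provides a fused attention primitive, each block can be handed to it directly instead of a hand-rolled softmax; a hedged sketch (check the exact signature of `mx.fast.scaled_dot_product_attention` against your MLX version):

```python
import mlx.core as mx

def block_attention_fused(q, k, v, boundaries, scale):
    # Same block loop as the plain version, but each block goes through
    # MLX's fused kernel; no mask is needed since a block only attends
    # to itself.
    outputs = []
    for start, end in boundaries:
        outputs.append(
            mx.fast.scaled_dot_product_attention(
                q[:, :, start:end, :],
                k[:, :, start:end, :],
                v[:, :, start:end, :],
                scale=scale,
            )
        )
    return mx.concatenate(outputs, axis=2)
```
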
87+ ## **CRITICAL SYNTAX AND CODING RULES**
88+
89+ ⚠️ **AVOID THESE COMMON ERRORS**:
90+
91+ 1. **String Syntax**: Never use unescaped quotes or f-strings in multi-line strings
92+ 2. **Variable Scope**: Only use variables that are clearly defined in the current scope
93+ 3. **MLX API**: Use `mx.concatenate()`, not `.at[]` syntax (that's JAX, not MLX)
94+ 4. **Comments**: Use `#` for Python comments, `//` only inside actual C/C++ code strings
95+ 5. **F-strings**: Be very careful with f-strings containing complex expressions
96+
97+ ✅ **ALWAYS DO THIS**:
98+
99+ ```python
100+ # Good: Simple, clear variable usage
101+ B, H, L, D = q.shape
102+
103+ # Good: MLX-compatible operations
104+ output = mx.concatenate(block_outputs, axis=2)
105+
106+ # Good: Clear variable definitions within scope
107+ block_size = block_info["block_size"]
108+ num_blocks = block_info["num_blocks"]
109+
110+ # Good: Safe string formatting
111+ kernel_source = "// Simple kernel without complex formatting\n"
112+ kernel_source += f"const uint block_size = {block_size};\n"
85113 ```
86114
87- ## **SPECIFIC METAL KERNEL OPTIMIZATIONS**
115+ ❌ **NEVER DO THIS**:
116+
117+ ```python
118+ # Bad: Undefined variables
119+ print(f"Using {n_q_heads} heads") # n_q_heads not defined in this scope!
88120
89- **Memory Optimization**:
90- - Use threadgroup memory for frequently accessed data
91- - Coalesce memory reads/writes across threads
92- - Minimize global memory access
93- - Optimize for Apple Silicon unified memory
121+ # Bad: JAX syntax in MLX
122+ output = output.at[:, :, start:end, :].set(block_output) # Wrong framework!
94123
95- **Computation Optimization**:
96- - Vectorize operations using SIMD instructions
97- - Implement efficient softmax computation
98- - Use fused operations where possible
99- - Skip zero/masked computations entirely
124+ # Bad: Complex f-strings with quotes
125+ code = f"if (pos < {var}) { print(\"hello\"); }" # Syntax nightmare!
100126
101- **Thread Organization**:
102- - Optimal threadgroup sizes for different block sizes
103- - Efficient work distribution across GPU cores
104- - Minimize thread divergence
105- - Balance workload across threadgroups
127+ # Bad: C++ comments in Python
128+ // This is a Python comment # Wrong comment style!
129+ ```
106130
107131 ## **SUCCESS METRICS**
108132
109133 **Correctness** (Must achieve):
110134 - ✅ 80%+ test pass rate across all scenarios
111- - ✅ MSE < 1e-3 vs reference implementation
135+ - ✅ MSE < 1e-3 vs reference implementation
112136 - ✅ Handle variable sequence lengths correctly
113137 - ✅ No NaN/Inf in outputs
114138
115139 **Performance** (Optimization targets):
116- - 🎯 **1.5x+ speedup** over naive masked attention (good)
117- - 🎯 **2.0x+ speedup** over naive masked attention (excellent)
140+ - 🎯 **1.2x+ speedup** over naive masked attention (good)
141+ - 🎯 **1.5x+ speedup** over naive masked attention (excellent)
142+ - 🎯 **2.0x+ speedup** over naive masked attention (outstanding)
118143 - 🎯 Linear scaling with number of sequences
119- - 🎯 Efficient memory usage (no explosions)
120-
121- **Robustness** (Nice to have):
122- - Handle various block sizes (128, 256, 512, 1024)
123- - Support different head dimensions (64, 80, 128)
124- - Work with different batch sizes
125- - Graceful fallback when Metal kernel fails
144+ - 🎯 Efficient memory usage
126145
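The correctness gate (MSE < 1e-3, no NaN/Inf) amounts to roughly the following kind of check against a reference output; this is a sketch of the metric, not the actual evaluator:

```python
import mlx.core as mx

def passes_correctness(evolved_out, reference_out, tol=1e-3):
    mse = mx.mean(mx.square(evolved_out - reference_out)).item()
    finite = not (mx.any(mx.isnan(evolved_out)).item()
                  or mx.any(mx.isinf(evolved_out)).item())
    return finite and mse < tol
```
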
127146 ## **EVALUATION SCENARIOS**
128147
129148 You'll be tested on:
130149 - **packed_2x256**: Two 256-token sequences packed together
131- - **packed_4x128**: Four 128-token sequences packed together
150+ - **packed_4x128**: Four 128-token sequences packed together
132151 - **packed_variable**: Variable-length sequences (256 + 512)
133152 - **packed_large**: Large sequences (4x256 = 1024 total)
134153 - **packed_bert_style**: BERT-style training packing
135154
155+ ## **IMPLEMENTATION STRATEGY**
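For instance, packed_2x256 corresponds to a block-diagonal mask built from two 256-token segments; a sketch of how such an input might look (shapes and dtypes here are assumptions, not the harness's exact values):

```python
import mlx.core as mx

# packed_2x256: two 256-token sequences packed into one 512-token row.
lengths = [256, 256]
seg = mx.concatenate([mx.full((n,), i) for i, n in enumerate(lengths)])
mask = seg[:, None] == seg[None, :]   # (512, 512) boolean, True = may attend

B, H, D = 1, 8, 64
L = sum(lengths)
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
scale = D ** -0.5
```
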
156+
157+ **Phase 1: Block Detection**
158+ - Analyze mask patterns to identify block boundaries
159+ - Handle both uniform and variable-length blocks
160+ - Cache analysis results for efficiency
161+
162+ **Phase 2: Optimized Computation**
163+ - Process each block independently with optimized attention
164+ - Use efficient MLX operations within blocks
165+ - Minimize memory allocations and data movement
166+
167+ **Phase 3: Assembly & Output**
168+ - Efficiently combine block outputs
169+ - Ensure correct output shape and dtype
170+ - Handle edge cases gracefully
171+
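Put together, the three phases suggest a control flow like the skeleton below; the helper names reuse the sketches above and are illustrative, and only the detect -> process -> fallback shape is required:

```python
import mlx.core as mx

def evolved_attention_sketch(q, k, v, scale, mask=None):
    # Phase 1: detect block structure; with no mask there is nothing to exploit.
    if mask is None:
        return mx.softmax((q @ mx.swapaxes(k, -1, -2)) * scale, axis=-1) @ v
    boundaries = find_block_boundaries(mask)
    if not should_use_block_path(mask, boundaries):
        # Smart fallback: regular masked attention when blocks will not pay off.
        return naive_masked_attention(q, k, v, mask, scale)
    # Phases 2 and 3: per-block attention, reassembled along the sequence axis.
    return block_diagonal_attention(q, k, v, boundaries, scale)
```
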
136172 ## **KEY CONSTRAINTS**
137173
138174 **DO NOT CHANGE**:
139175 - Function signature of `evolved_scaled_dot_product_attention`
140- - Overall structure (detect -> kernel -> fallback)
176+ - Overall structure (detect -> process -> fallback)
141177 - Error handling and fallback mechanisms
142178
143179 **FOCUS ON**:
144- - Metal kernel source code optimization
145- - Block detection efficiency
146- - Memory access patterns
147- - Thread organization and vectorization
180+ - Block detection efficiency and accuracy
181+ - CPU computation optimization with MLX
182+ - Memory access patterns and data layout
183+ - Algorithmic improvements for block processing
148184
149185 ## **EXAMPLE IMPROVEMENTS**
150186
151- **Better Thread Organization**:
152- ```cpp
153- // Instead of: one thread per query position
154- // Try: threadgroup processes entire block cooperatively
155- ```
156-
157- **Vectorized Operations**:
158- ```cpp
159- // Instead of: scalar operations
160- // Try: float4/half4 vector operations
187+ **Better Block Detection**:
188+ ```python
189+ # Analyze mask structure more efficiently
190+ # Cache block boundaries for reuse
191+ # Handle edge cases in variable-length sequences
161192 ```
162193
163- **Shared Memory Usage**:
164- ```cpp
165- // Add: threadgroup shared memory for keys/values
166- threadgroup float shared_keys[BLOCK_SIZE * HEAD_DIM];
194+ **Optimized Block Processing**:
195+ ```python
196+ # Use MLX's optimized operations
197+ # Minimize intermediate allocations
198+ # Process blocks in optimal order
167199 ```
168200
169- **Optimized Softmax**:
170- ```cpp
171- // Instead of: naive exp/sum
172- // Try: numerically stable, vectorized softmax
201+ **Memory Efficiency**:
202+ ```python
203+ # Avoid unnecessary numpy conversions
204+ # Reuse intermediate tensors where possible
205+ # Optimize data layout for cache efficiency
173206 ```
174207
175- ## **DEBUGGING HINTS**
176-
177- - Start with correctness, then optimize performance
178- - Test with simple uniform blocks before variable lengths
179- - Use CPU fallback to verify Metal kernel correctness
180- - Monitor memory usage and avoid explosions
181- - Check that block detection is working correctly
182-
183- Focus on creating a Metal kernel that significantly outperforms naive masking through smart computation skipping and memory optimization!
208+ Remember: Focus on correctness first, then optimize for performance.
209+ Use only MLX operations and avoid complex string formatting that can cause syntax errors!
184210
185211 num_top_programs: 5
186212 num_diverse_programs: 3