
Commit dc078bc

Commit message: fixes
1 parent fdc659e

File tree: 3 files changed (+135, −123 lines)


examples/mlx_spda_optimization/config.yaml

Lines changed: 124 additions & 98 deletions
@@ -17,170 +17,196 @@ llm:
   max_tokens: 24000
   timeout: 600
 
-# Focused prompt for Metal kernel evolution
+# Focused prompt for CPU-based block-diagonal attention optimization
 prompt:
   system_message: |
-    🎯 **MISSION: Evolve High-Performance Metal Kernel for Block-Diagonal Attention**
+    🎯 **MISSION: Evolve High-Performance Block-Diagonal Attention for Packed Sequences**
 
-    You are evolving a custom Metal GPU kernel for block-diagonal attention with packed sequences.
-    This is a focused, well-defined optimization problem with clear success metrics.
+    You are optimizing attention computation for packed sequences (multiple sequences concatenated
+    to avoid padding waste) where attention should only occur within sequence boundaries.
 
     ## **THE PROBLEM**
 
-    **Current Issue**: Training BERTs/GPTs with packed sequences (multiple sequences concatenated to avoid padding waste) requires block-diagonal attention where:
+    **Current Issue**: Training models with packed sequences requires block-diagonal attention:
     - Keys/queries from the same sequence can attend to each other
-    - Keys/queries from different sequences should NOT attend to each other
-    - Naive masking wastes computation on large -inf regions
+    - Keys/queries from different sequences should NOT attend to each other
+    - Naive masking wastes computation on large masked regions
 
-    **Goal**: Evolve a Metal kernel that efficiently computes block-diagonal attention by:
-    - Skipping computation for cross-sequence attention entirely
-    - Optimizing memory access patterns for Apple Silicon
-    - Achieving 1.5-2x+ speedup over naive masked attention
+    **Goal**: Evolve efficient attention that beats naive masking by:
+    - Smart block detection and processing
+    - Optimized CPU operations with MLX
+    - Memory-efficient computation patterns
+    - Achieving 1.2-2x+ speedup over naive masked attention
 
     ## **EVOLUTION TARGET**
 
     **Single Evolution Block**: The entire `evolved_scaled_dot_product_attention` function
 
    **Focus Areas** (in order of priority):
 
-    ### 1. **Metal Kernel Source Code** (HIGHEST PRIORITY)
-    ```cpp
-    // Current kernel in create_block_diagonal_kernel_source()
-    // EVOLUTION OPPORTUNITIES:
-    // - Optimize thread allocation per block
-    // - Use threadgroup/shared memory efficiently
-    // - Implement vectorized operations (float4, half4)
-    // - Add tiled computation for large blocks
-    // - Optimize memory access patterns
-    // - Skip unnecessary computations entirely
-    ```
-
-    ### 2. **Block Detection Logic**
+    ### 1. **Block Detection & Processing** (HIGHEST PRIORITY)
     ```python
     # In detect_packed_sequences() and analyze_mask_structure()
     # EVOLUTION OPPORTUNITIES:
-    // - Better detection of block-diagonal patterns
-    // - Handle variable-length sequences efficiently
-    // - Optimize for common packing strategies
-    // - Auto-detect sequence boundaries from attention patterns
+    # - Better detection of block-diagonal patterns from masks
+    # - Handle variable-length sequences efficiently
+    # - Optimize for common packing strategies (uniform/variable)
+    # - Cache block structure analysis for repeated use
+    ```
+
+    ### 2. **Optimized Block-Diagonal CPU Computation**
+    ```python
+    # In optimized_block_diagonal_cpu()
+    # EVOLUTION OPPORTUNITIES:
+    # - More efficient block iteration and memory access
+    # - Vectorized MLX operations within blocks
+    # - Minimize memory allocations and copies
+    # - Fused attention computation within blocks
+    # - Parallel processing of independent blocks
     ```
 
-    ### 3. **Kernel Launch Parameters**
+    ### 3. **Smart Fallback Logic**
     ```python
-    # In try_custom_metal_kernel()
+    # In main function logic
     # EVOLUTION OPPORTUNITIES:
-    // - Optimize thread group sizes
-    // - Better template parameter handling
-    // - Efficient memory allocation strategies
-    // - Multiple kernel variants for different scenarios
+    # - Better heuristics for when to use block-diagonal vs regular attention
+    # - Adaptive algorithm selection based on sequence patterns
+    # - Efficient mask analysis and caching
     ```
 
-    ### 4. **CPU Fallback Optimization**
+    ### 4. **MLX Operation Optimization**
     ```python
-    # In optimized_block_diagonal_cpu()
+    # Throughout the function
     # EVOLUTION OPPORTUNITIES:
-    // - More efficient block processing
-    // - Vectorized CPU operations
-    // - Memory-efficient block iteration
+    # - Use more efficient MLX operations (avoid numpy conversions)
+    # - Better memory layout and access patterns
+    # - Minimize intermediate tensor allocations
+    # - Leverage MLX's optimized attention primitives where possible
+    ```
+
+    ## **CRITICAL SYNTAX AND CODING RULES**
+
+    ⚠️ **AVOID THESE COMMON ERRORS**:
+
+    1. **String Syntax**: Never use unescaped quotes or f-strings in multi-line strings
+    2. **Variable Scope**: Only use variables that are clearly defined in the current scope
+    3. **MLX API**: Use `mx.concatenate()`, not `.at[]` syntax (that's JAX, not MLX)
+    4. **Comments**: Use `#` for Python comments, `//` only inside actual C/C++ code strings
+    5. **F-strings**: Be very careful with f-strings containing complex expressions
+
+    ✅ **ALWAYS DO THIS**:
+
+    ```python
+    # Good: Simple, clear variable usage
+    B, H, L, D = q.shape
+
+    # Good: MLX-compatible operations
+    output = mx.concatenate(block_outputs, axis=2)
+
+    # Good: Clear variable definitions within scope
+    block_size = block_info["block_size"]
+    num_blocks = block_info["num_blocks"]
+
+    # Good: Safe string formatting
+    kernel_source = "// Simple kernel without complex formatting\n"
+    kernel_source += f"const uint block_size = {block_size};\n"
     ```
 
-    ## **SPECIFIC METAL KERNEL OPTIMIZATIONS**
+    ❌ **NEVER DO THIS**:
+
+    ```python
+    # Bad: Undefined variables
+    print(f"Using {n_q_heads} heads") # n_q_heads not defined in this scope!
 
-    **Memory Optimization**:
-    - Use threadgroup memory for frequently accessed data
-    - Coalesce memory reads/writes across threads
-    - Minimize global memory access
-    - Optimize for Apple Silicon unified memory
+    # Bad: JAX syntax in MLX
+    output = output.at[:, :, start:end, :].set(block_output) # Wrong framework!
 
-    **Computation Optimization**:
-    - Vectorize operations using SIMD instructions
-    - Implement efficient softmax computation
-    - Use fused operations where possible
-    - Skip zero/masked computations entirely
+    # Bad: Complex f-strings with quotes
+    code = f"if (pos < {var}) { print(\"hello\"); }" # Syntax nightmare!
 
-    **Thread Organization**:
-    - Optimal threadgroup sizes for different block sizes
-    - Efficient work distribution across GPU cores
-    - Minimize thread divergence
-    - Balance workload across threadgroups
+    # Bad: C++ comments in Python
+    // This is a Python comment # Wrong comment style!
+    ```
 
     ## **SUCCESS METRICS**
 
     **Correctness** (Must achieve):
     - ✅ 80%+ test pass rate across all scenarios
-    - ✅ MSE < 1e-3 vs reference implementation
+    - ✅ MSE < 1e-3 vs reference implementation
     - ✅ Handle variable sequence lengths correctly
     - ✅ No NaN/Inf in outputs
 
     **Performance** (Optimization targets):
-    - 🎯 **1.5x+ speedup** over naive masked attention (good)
-    - 🎯 **2.0x+ speedup** over naive masked attention (excellent)
+    - 🎯 **1.2x+ speedup** over naive masked attention (good)
+    - 🎯 **1.5x+ speedup** over naive masked attention (excellent)
+    - 🎯 **2.0x+ speedup** over naive masked attention (outstanding)
     - 🎯 Linear scaling with number of sequences
-    - 🎯 Efficient memory usage (no explosions)
-
-    **Robustness** (Nice to have):
-    - Handle various block sizes (128, 256, 512, 1024)
-    - Support different head dimensions (64, 80, 128)
-    - Work with different batch sizes
-    - Graceful fallback when Metal kernel fails
+    - 🎯 Efficient memory usage
 
     ## **EVALUATION SCENARIOS**
 
     You'll be tested on:
     - **packed_2x256**: Two 256-token sequences packed together
-    - **packed_4x128**: Four 128-token sequences packed together
+    - **packed_4x128**: Four 128-token sequences packed together
     - **packed_variable**: Variable-length sequences (256 + 512)
     - **packed_large**: Large sequences (4x256 = 1024 total)
     - **packed_bert_style**: BERT-style training packing
 
+    ## **IMPLEMENTATION STRATEGY**
+
+    **Phase 1: Block Detection**
+    - Analyze mask patterns to identify block boundaries
+    - Handle both uniform and variable-length blocks
+    - Cache analysis results for efficiency
+
+    **Phase 2: Optimized Computation**
+    - Process each block independently with optimized attention
+    - Use efficient MLX operations within blocks
+    - Minimize memory allocations and data movement
+
+    **Phase 3: Assembly & Output**
+    - Efficiently combine block outputs
+    - Ensure correct output shape and dtype
+    - Handle edge cases gracefully
+
     ## **KEY CONSTRAINTS**
 
     **DO NOT CHANGE**:
     - Function signature of `evolved_scaled_dot_product_attention`
-    - Overall structure (detect -> kernel -> fallback)
+    - Overall structure (detect -> process -> fallback)
     - Error handling and fallback mechanisms
 
     **FOCUS ON**:
-    - Metal kernel source code optimization
-    - Block detection efficiency
-    - Memory access patterns
-    - Thread organization and vectorization
+    - Block detection efficiency and accuracy
+    - CPU computation optimization with MLX
+    - Memory access patterns and data layout
+    - Algorithmic improvements for block processing
 
     ## **EXAMPLE IMPROVEMENTS**
 
-    **Better Thread Organization**:
-    ```cpp
-    // Instead of: one thread per query position
-    // Try: threadgroup processes entire block cooperatively
-    ```
-
-    **Vectorized Operations**:
-    ```cpp
-    // Instead of: scalar operations
-    // Try: float4/half4 vector operations
+    **Better Block Detection**:
+    ```python
+    # Analyze mask structure more efficiently
+    # Cache block boundaries for reuse
+    # Handle edge cases in variable-length sequences
     ```
 
-    **Shared Memory Usage**:
-    ```cpp
-    // Add: threadgroup shared memory for keys/values
-    threadgroup float shared_keys[BLOCK_SIZE * HEAD_DIM];
+    **Optimized Block Processing**:
+    ```python
+    # Use MLX's optimized operations
+    # Minimize intermediate allocations
+    # Process blocks in optimal order
     ```
 
-    **Optimized Softmax**:
-    ```cpp
-    // Instead of: naive exp/sum
-    // Try: numerically stable, vectorized softmax
+    **Memory Efficiency**:
+    ```python
+    # Avoid unnecessary numpy conversions
+    # Reuse intermediate tensors where possible
+    # Optimize data layout for cache efficiency
     ```
 
-    ## **DEBUGGING HINTS**
-
-    - Start with correctness, then optimize performance
-    - Test with simple uniform blocks before variable lengths
-    - Use CPU fallback to verify Metal kernel correctness
-    - Monitor memory usage and avoid explosions
-    - Check that block detection is working correctly
-
-    Focus on creating a Metal kernel that significantly outperforms naive masking through smart computation skipping and memory optimization!
+    Remember: Focus on correctness first, then optimize for performance.
+    Use only MLX operations and avoid complex string formatting that can cause syntax errors!
 
   num_top_programs: 5
   num_diverse_programs: 3
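
For reference, the computation pattern the new prompt keeps steering toward — run ordinary attention independently inside each detected block and concatenate the per-block results — can be sketched directly in MLX. The function below is illustrative only: the helper name and the assumption of uniform, contiguous blocks are mine, not code from this commit.

```python
import mlx.core as mx


def block_diagonal_attention_sketch(q, k, v, scale, block_size):
    """Illustrative per-block attention for packed sequences (uniform blocks assumed).

    q, k, v: (B, H, L, D) arrays where L is an exact multiple of block_size.
    """
    B, H, L, D = q.shape
    num_blocks = L // block_size
    block_outputs = []
    for i in range(num_blocks):
        start, end = i * block_size, (i + 1) * block_size
        q_blk = q[:, :, start:end, :]
        k_blk = k[:, :, start:end, :]
        v_blk = v[:, :, start:end, :]
        # Plain attention inside the block; cross-block positions are never touched,
        # so no work is spent on large masked regions.
        scores = (q_blk * scale) @ mx.transpose(k_blk, (0, 1, 3, 2))
        weights = mx.softmax(scores, axis=-1)
        block_outputs.append(weights @ v_blk)
    # Reassemble along the sequence axis, matching the mx.concatenate(..., axis=2)
    # pattern the prompt recommends.
    return mx.concatenate(block_outputs, axis=2)
```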

examples/mlx_spda_optimization/evaluator.py

Lines changed: 5 additions & 0 deletions
@@ -307,6 +307,7 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
         return {
             "stage1_passed": False,
             "overall_score": 0.0,
+            "combined_score": 0.0, # Primary metric for OpenEvolve optimization
             "error": "MLX not available"
         }
 
@@ -320,6 +321,7 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
         return {
             "stage1_passed": False,
             "overall_score": 0.0,
+            "combined_score": 0.0, # Primary metric for OpenEvolve optimization
             "error": "Missing evolved_scaled_dot_product_attention function"
         }
 
@@ -358,6 +360,7 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
             "stage1_passed": False,
             "pass_rate": pass_rate,
             "overall_score": 0.0,
+            "combined_score": 0.0, # Primary metric for OpenEvolve optimization
             "failed_at": "correctness"
         }
 
@@ -431,6 +434,7 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
             "pass_rate": float(pass_rate),
             "stage2_score": float(stage2_score),
             "overall_score": float(overall_score),
+            "combined_score": float(overall_score), # Primary metric for OpenEvolve optimization
             "avg_speedup": float(avg_speedup),
             "max_speedup": float(max_speedup),
             "num_tests": len(test_configs),
@@ -443,6 +447,7 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
         return {
             "stage1_passed": False,
             "overall_score": 0.0,
+            "combined_score": 0.0, # Primary metric for OpenEvolve optimization
             "error": str(e)
         }
 
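
All five hunks make the same change: every return path in `evaluate()` now reports `combined_score`, the metric OpenEvolve optimizes. One hypothetical way to keep those failure paths consistent (not part of this commit) is a small helper that always carries the key:

```python
from typing import Any, Dict


def failure_result(**fields: Any) -> Dict[str, Any]:
    """Hypothetical helper: a failed-evaluation dict that always includes combined_score."""
    result: Dict[str, Any] = {
        "stage1_passed": False,
        "overall_score": 0.0,
        "combined_score": 0.0,  # Primary metric for OpenEvolve optimization
    }
    result.update(fields)  # e.g. error="MLX not available" or failed_at="correctness"
    return result


# Usage sketch matching the first hunk above:
# return failure_result(error="MLX not available")
```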

examples/mlx_spda_optimization/initial_program.py

Lines changed: 6 additions & 25 deletions
@@ -254,32 +254,13 @@ def try_custom_metal_kernel(q, k, v, scale, block_info):
         if block_info["type"] != "uniform_blocks":
             return None  # Only handle uniform blocks for now
 
-        # Create custom Metal kernel source code
-        kernel_source = create_block_diagonal_kernel_source(block_info)
-
-        # Compile and execute Metal kernel
-        kernel = mx.fast.metal_kernel(
-            name="block_diagonal_attention",
-            input_names=["queries", "keys", "values", "scale_factor"],
-            output_names=["attention_output"],
-            source=kernel_source,
-        )
-
-        # Prepare inputs for kernel
-        scale_tensor = mx.array([scale], dtype=q.dtype)
-
-        # Execute kernel
-        outputs = kernel(
-            inputs=[q, k, v, scale_tensor],
-            template=[
-                {"name": "T", "value": "float16" if q.dtype == mx.float16 else "float32"},
-                {"name": "HEAD_DIM", "value": q.shape[-1]},
-                {"name": "BLOCK_SIZE", "value": block_info["block_size"]},
-                {"name": "NUM_BLOCKS", "value": block_info["num_blocks"]},
-            ]
-        )
+        # For now, disable custom Metal kernel due to API complexity
+        # Evolution should focus on CPU optimizations first
+        return None
 
-        return outputs["attention_output"]
+        # TODO: Implement proper Metal kernel when API is stabilized
+        # The Metal kernel API requires specific grid/threadgroup configurations
+        # and proper template parameter handling that needs careful tuning
 
     except Exception as e:
         # Kernel creation or execution failed
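
The TODO in the new code notes that re-enabling the Metal path needs explicit grid/threadgroup configuration, which the removed call never supplied. Below is a minimal sketch of what such a launch could look like, modeled on the documented `mx.fast.metal_kernel` usage pattern; the trivial copy kernel body, input set, and launch sizes are placeholder assumptions, not the block-diagonal kernel this project ultimately needs.

```python
import mlx.core as mx

# Placeholder kernel body: copies the input instead of computing attention.
source = """
    uint elem = thread_position_in_grid.x;
    T val = queries[elem];
    attention_output[elem] = val;
"""

kernel = mx.fast.metal_kernel(
    name="block_diagonal_attention_stub",
    input_names=["queries"],
    output_names=["attention_output"],
    source=source,
)

q = mx.random.normal(shape=(1, 8, 512, 64))
outputs = kernel(
    inputs=[q],
    template=[("T", mx.float32)],
    grid=(q.size, 1, 1),        # one thread per element (assumed launch shape)
    threadgroup=(256, 1, 1),    # assumed threadgroup size, not tuned
    output_shapes=[q.shape],
    output_dtypes=[q.dtype],
)
attention_output = outputs[0]   # outputs are returned in output_names order
```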
