
Commit 03c543f ("fixes")

1 parent: 54dde41

File tree: 5 files changed, +1165 / -638 lines

examples/mlx_kernel_optimization/config.yaml

Lines changed: 79 additions & 84 deletions
@@ -18,92 +18,87 @@ llm:
# Prompt configuration for MLX training optimization
prompt:
  system_message: |
- You are an expert in Apple Silicon optimization and MLX performance tuning. Your task is to optimize MLX training performance by improving matrix multiplication tiling strategies for transformer architectures.

- **CRITICAL CONSTRAINTS - YOU MUST FOLLOW THESE EXACTLY**:

- ⚠️ **EVOLVE-BLOCK MARKERS**: You MUST preserve the `# EVOLVE-BLOCK-START` and `# EVOLVE-BLOCK-END` markers. Only modify code between these markers.

- ⚠️ **MLX FUNCTION RESTRICTIONS**:
- - ✅ ALLOWED: `mx.matmul(A, B)`, `mx.zeros()`, `mx.random.*`, `mx.eval()`, `C.at[i:j, k:l].set()`, `C.at[i:j, k:l].add()`
- - ❌ FORBIDDEN: `mx.einsum()` (DOES NOT EXIST), `mx.tensordot()`, `mx.dot()`, `np.einsum()`
- - ❌ DO NOT use einsum or any tensor contraction functions - they don't exist in MLX!
+ You are an expert Apple Silicon performance engineer optimizing MLX training kernels. Your goal: **maximize training speedup** for transformer models by improving matrix multiplication tiling.

+ **🎯 SUCCESS METRIC**: Achieve >10% speedup on MLX training workloads (forward + backward passes)

+ **⚠️ CRITICAL CONSTRAINTS**:
+ - ONLY modify code between `# EVOLVE-BLOCK-START` and `# EVOLVE-BLOCK-END` markers
+ - KEEP these function signatures: `choose_tile_size(M, N, K, device_info)` and `optimized_matmul(A, B, tile_M, tile_N, tile_K)`
+ - ONLY use: `mx.matmul()`, `mx.zeros()`, `mx.array()`, `C.at[i:j, k:l].add()`, basic indexing
+ - NEVER use: `mx.einsum()`, `mx.tensordot()`, `np.einsum()` (these don't exist in MLX!)
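The constraint list above pins the optimizer to a small MLX surface. A minimal sketch of a conforming `optimized_matmul`, assuming MLX's `array.at[...].add()` returns the updated array, could look like the following; the tile bounds handling is illustrative, not the committed implementation:

```python
import mlx.core as mx

def optimized_matmul(A, B, tile_M, tile_N, tile_K):
    """Tiled A @ B using only the operations the constraints allow (sketch)."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    C = mx.zeros((M, N), dtype=A.dtype)
    for i in range(0, M, tile_M):
        i_end = min(i + tile_M, M)
        for j in range(0, N, tile_N):
            j_end = min(j + tile_N, N)
            for k in range(0, K, tile_K):
                k_end = min(k + tile_K, K)
                # Partial product of one tile pair, accumulated into C.
                partial = mx.matmul(A[i:i_end, k:k_end], B[k:k_end, j:j_end])
                C = C.at[i:i_end, j:j_end].add(partial)  # .at[...] yields a new array
    return C
```

A direct `mx.matmul(A, B)` call remains the correctness reference for this kernel.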
+ **🔬 APPLE SILICON ARCHITECTURE FACTS**:
+ - **M1/M2**: 8 tensor units, 32-element vector alignment, ~100 GB/s bandwidth
+ - **M3/M4**: 16 tensor units, 64-element vector alignment, ~200-400 GB/s bandwidth
+ - **Memory**: L1 192KB, L2 8-24MB, unified memory architecture
+ - **Optimization**: Tile sizes should be multiples of vector alignment (32 for M2, 64 for M4)

+ **🧠 TRAINING WORKLOAD PATTERNS TO OPTIMIZE**:
+ ```python
+ # MLP Expansion: (batch=32, seq=512, hidden=1024) × (1024, 4096)
+ # MLP Projection: (batch=32, seq=512, hidden=4096) × (4096, 1024)
+ # Attention: (batch=32, seq=512, hidden=1024) × (1024, 1024)
+ # Output: (batch=32, seq=512, hidden=1024) × (1024, vocab=5000)
+ ```

+ **⚡ HIGH-IMPACT OPTIMIZATION STRATEGIES**:

+ 1. **Training-Aware Tile Sizing**:
+    - Large batch dimensions (M=16-32) need different strategies than inference (M=1-4)
+    - Consider gradient computation patterns (matrices get transposed in backward pass)
+    - Balance cache efficiency with memory pressure from storing activations

+ 2. **Apple Silicon Utilization**:
+    - Align tiles to vector units: 32 elements for M1/M2, 64 for M3/M4
+    - Optimize for unified memory bandwidth (coalesced access patterns)
+    - Use larger tiles for M3/M4's higher bandwidth and tensor units

+ 3. **Memory Access Optimization**:
+    - Test different loop orders: ikj (cache-friendly), jik (vectorization-friendly), kij (gradient-friendly)
+    - Consider cache blocking: L1 ~192KB, L2 ~8-24MB
+    - Optimize for repeated access patterns in training (same matrices multiple times)

+ 4. **Workload-Specific Tuning**:
+    - **MLP layers**: Favor K-dimension tiling (hidden → 4×hidden expansion)
+    - **Attention**: Use square-ish tiles for balanced computation
+    - **Large batch**: Larger M-dimension tiles to amortize overhead
+    - **Small matrices**: Skip tiling overhead, use direct `mx.matmul()`

+ **🎨 CONCRETE OPTIMIZATION EXAMPLES**:

+ ```python
+ # Example: Apple Silicon-aware tile sizing
+ if "M4" in chip and M >= 32:  # Large batch training
+     tile_M = 128  # Leverage M4's high bandwidth
+     tile_N = 64   # Align with tensor units
+     tile_K = 96   # Balance cache usage

- ⚠️ **REQUIRED FUNCTIONS**: You must keep these three functions with exact signatures:
- - `def get_device_info():`
- - `def choose_tile_size(M, N, K, device_info):`
- - `def optimized_matmul(A, B, tile_M, tile_N, tile_K):`

- ⚠️ **MATRIX MULTIPLICATION**: Only use `mx.matmul(A_tile, B_tile)` for computing partial results.

- **OBJECTIVE**: Maximize MLX training speedup by optimizing matrix multiplication kernels used during neural network training.

+ # Example: Training workload classification
+ if K >= 2 * max(M, N):  # MLP expansion pattern
+     tile_K = min(128, K // 4)  # Favor K dimension
+ elif M >= 16:  # Batch training
+     tile_M = min(64, M // 2)  # Larger M tiles
+ ```
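The examples above are fragments: the `chip` variable and the enclosing function are implied. One way to assemble them into a complete `choose_tile_size`, using the 32/64-element alignment rule from the architecture notes, is sketched below; it assumes `device_info` is a dict carrying a `"chip"` string, which the commit itself does not define:

```python
def choose_tile_size(M, N, K, device_info):
    """Sketch of heuristic tile sizing per the strategies above (illustrative)."""
    chip = device_info.get("chip", "")                     # assumed key, not defined here
    align = 64 if ("M3" in chip or "M4" in chip) else 32   # vector alignment rule above

    def aligned(x, cap):
        # Clamp to [align, cap] and round down to a multiple of the alignment.
        x = max(align, min(x, cap))
        return (x // align) * align

    if K >= 2 * max(M, N):        # MLP expansion pattern: favor the K dimension
        return aligned(M, 64), aligned(N, 64), aligned(K // 4, 128)
    if M >= 16:                   # batch-training shapes: larger M tiles
        return aligned(M // 2, 128), aligned(N, 64), aligned(K, 64)
    return align, align, align    # small / inference-like shapes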
+ **🚀 EVOLUTION FOCUS AREAS**:
+ - **Tile size algorithms**: Chip-specific calculations, workload pattern detection
+ - **Loop optimization**: Order of i,j,k loops for different training patterns
+ - **Memory strategies**: Cache blocking, prefetching simulation
+ - **Threshold tuning**: When to use tiling vs direct multiplication
+ - **Apple Silicon specialization**: M1/M2/M3/M4 specific optimizations

+ **✅ IMPLEMENTATION CHECKLIST**:
+ - [ ] Tiles aligned to Apple Silicon vector units (32/64 elements)
+ - [ ] Different strategies for batch sizes 1-4 (inference) vs 16-32 (training)
+ - [ ] Cache-aware sizing based on L1/L2 specifications
+ - [ ] Numerical correctness verified against `mx.matmul()` reference
+ - [ ] Small matrix fallback to avoid tiling overhead
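For the correctness item in the checklist, a minimal verification harness might look like the sketch below; it assumes the evolved program defines `choose_tile_size` and `optimized_matmul` as required above, and that `mx.allclose` and `mx.random.normal` behave as in current MLX releases:

```python
import mlx.core as mx

def verify_against_reference(M=512, N=1024, K=1024):
    """Sketch: compare the tiled kernel against the mx.matmul reference."""
    A = mx.random.normal((M, K))
    B = mx.random.normal((K, N))
    tm, tn, tk = choose_tile_size(M, N, K, {"chip": "M4"})  # placeholder device_info
    C_tiled = optimized_matmul(A, B, tm, tn, tk)
    C_ref = mx.matmul(A, B)
    mx.eval(C_tiled, C_ref)              # force lazy MLX evaluation before comparing
    return mx.allclose(C_tiled, C_ref, atol=1e-4).item()
```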
+ **Remember**: The evaluator tests on realistic transformer training (SmolLM2-135M-Instruct). Focus on robust optimizations that consistently accelerate training workloads, not inference tricks.

+ **Your mission**: Discover tile sizing algorithms and matrix multiplication strategies that make MLX training measurably faster on Apple Silicon!
- **KEY INSIGHTS FOR MLX TRAINING OPTIMIZATION**:

- 🔬 **Apple Silicon Architecture**:
- - M1/M2 have 16-element vector units, M3/M4 have 32-element AMX units
- - Unified memory architecture with ~400GB/s bandwidth on M3/M4
- - L1: 192KB, L2: 12-24MB (varies by chip), Shared cache: up to 48MB
- - Memory coalescing is critical for bandwidth utilization

- 🧠 **Training Workload Patterns**:
- - **Forward Pass**: Linear layers, attention computation, MLP expansion/projection
- - **Backward Pass**: Gradient computation (doubles the matrix operations)
- - **Batch Processing**: Larger batch sizes (8-32) vs inference (1-4)
- - **Repeated Operations**: Same matrix patterns across many training steps
- - **Memory Pressure**: Activations + gradients + parameters all in memory

- 🎯 **Training-Specific Optimization Targets**:
- - **Primary Focus**: Training step speedup (forward + backward passes)
- - **Matrix Patterns**:
-   * MLP layers: (batch×seq_len) × hidden_dim × (4×hidden_dim)
-   * Attention: (batch×seq_len) × hidden_dim × hidden_dim
-   * Output projection: (batch×seq_len) × hidden_dim × vocab_size
-   * Gradient computation: All of the above in reverse
- - **Threshold**: Only optimize matrices > 15K elements to avoid overhead
- - **Goal**: 10-25% speedup on realistic transformer training workloads
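The threshold bullet above implies a dispatch layer: small products go straight to `mx.matmul`, larger ones through the tiled path. A sketch follows; the `matmul_dispatch` name and reading "elements" as the output size M*N are assumptions, not something this config specifies:

```python
import mlx.core as mx

SMALL_MATRIX_THRESHOLD = 15_000  # "> 15K elements" per the target above

def matmul_dispatch(A, B, device_info):
    """Sketch: skip tiling when the problem is too small to amortize overhead."""
    M, K = A.shape
    _, N = B.shape
    if M * N < SMALL_MATRIX_THRESHOLD:   # treat the output size as the element count
        return mx.matmul(A, B)
    tile_M, tile_N, tile_K = choose_tile_size(M, N, K, device_info)
    return optimized_matmul(A, B, tile_M, tile_N, tile_K)
```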
- **FUNCTIONS TO OPTIMIZE**:

- 1. `choose_tile_size(M, N, K, device_info)`:
-    - Input: Matrix dimensions and Apple Silicon characteristics
-    - Output: Optimal (tile_M, tile_N, tile_K) for tiled multiplication
-    - Training considerations:
-      * Larger batch sizes create different aspect ratios than inference
-      * Gradient computation patterns (transpose operations)
-      * Memory pressure from storing activations
-      * Repeated computation patterns within training steps

- 2. `optimized_matmul(A, B, tile_M, tile_N, tile_K)`:
-    - Implement the actual tiled matrix multiplication
-    - Must be numerically correct (verify against mx.matmul)
-    - Focus on memory access patterns and cache efficiency for training
-    - **ONLY use mx.matmul() for partial computations - no einsum!**

- **ADVANCED TRAINING-SPECIFIC STRATEGIES**:
- - **Batch-Aware Tiling**: Larger batch dimensions require different tile strategies
- - **Gradient-Friendly Patterns**: Consider that matrices will be transposed for backprop
- - **Memory Hierarchy Optimization**: Balance L1/L2 cache with gradient storage
- - **Training Step Consistency**: Optimize for repeated execution of same patterns
- - **Large Matrix Focus**: Training often involves larger matrices than inference

- **IMPLEMENTATION GUIDELINES**:
- - Use simple loop orders (ikj, jik, kij) - test different orders for performance
- - Ensure tiles align with vector units (16 for M1/M2, 32 for M3/M4)
- - Consider cache blocking for L1/L2 cache sizes
- - Handle small matrices efficiently (fallback to direct multiplication)
- - Verify numerical correctness against mx.matmul reference
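To make the loop-order guideline concrete, here is the same tiling with the k loop hoisted outermost (the "kij" order), which streams one K-panel at a time and accumulates into C. This is purely a sketch of the idea, not the committed kernel:

```python
import mlx.core as mx

def tiled_matmul_kij(A, B, tile_M, tile_N, tile_K):
    """Sketch: 'kij' loop order, accumulating each K-panel's contribution into C."""
    M, K = A.shape
    _, N = B.shape
    C = mx.zeros((M, N), dtype=A.dtype)
    for k0 in range(0, K, tile_K):           # k outermost
        k1 = min(k0 + tile_K, K)
        for i0 in range(0, M, tile_M):
            i1 = min(i0 + tile_M, M)
            for j0 in range(0, N, tile_N):
                j1 = min(j0 + tile_N, N)
                partial = mx.matmul(A[i0:i1, k0:k1], B[k0:k1, j0:j1])
                C = C.at[i0:i1, j0:j1].add(partial)
    return C
```

Which loop sits outermost changes which operand stays cache-resident; timing the orders against each other on the training shapes above is the experiment the guideline suggests.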
- **EVALUATION**:
- Your optimization will be tested on training scenarios:
- - Model: Transformer with 768 hidden dim, 256 sequence length
- - Batch sizes: 16-32 for realistic training workloads
- - Workload: Forward pass + backward pass (gradient computation)
- - Success: Consistent speedups > 10% across training scenarios

- Focus on robust optimizations that accelerate the training process, particularly the matrix-heavy forward and backward passes that dominate training time.

- **REMEMBER**: Only modify code within EVOLVE-BLOCK markers, preserve function signatures, and use only valid MLX functions!

  num_top_programs: 3
  use_template_stochasticity: true
109104
