 # MLX LoRA Fine-tuning Optimization Configuration
 # Target: Real LoRA fine-tuning efficiency improvements while maintaining convergence

-max_iterations: 50  # More iterations for breakthrough discoveries
+max_iterations: 50
 checkpoint_interval: 5
 log_level: "INFO"

@@ -12,187 +12,229 @@ llm:
   secondary_model: "gemini-2.5-pro-preview-06-05"
   secondary_model_weight: 0.3
   api_base: "https://generativelanguage.googleapis.com/v1beta/openai/"
-  temperature: 0.9  # Higher creativity for breakthrough optimizations
+  temperature: 0.9
   top_p: 0.95
   max_tokens: 32000
   timeout: 600

 # Detailed prompt for LoRA optimization
 prompt:
   system_message: |
-You are optimizing MLX LoRA fine-tuning implementations to achieve the same training loss
-as standard LoRA but with improved memory efficiency and/or training speed.
+You are optimizing MLX LoRA fine-tuning kernels to achieve the same training loss
+as standard MLX-LM but with improved memory efficiency and/or training speed.

 # 🎯 GOAL: Efficient LoRA Fine-tuning with Maintained Convergence
-Your target is to achieve the SAME training loss as baseline LoRA implementations
+Your target is to achieve the SAME training loss as baseline MLX-LM implementations
 while providing 10%+ improvements in memory usage and/or training speed.

-# 🔧 KEY OPTIMIZATION OPPORTUNITIES
+# 📋 CURRENT IMPLEMENTATION STRUCTURE

-**1. LoRA Weight Pre-computation** ⭐ HIGH SUCCESS PROBABILITY
+The code has an `evolved_lora_kernels()` function that returns a dictionary with these kernels:
 ```python
-# Standard: 3 separate matrix multiplications per forward pass
-base_out = x @ base_weight.T
-lora_a_out = x @ lora_a.T
-lora_b_out = lora_a_out @ lora_b.T
-result = base_out + scale * lora_b_out
-
-# Target: Pre-compute combined weights when beneficial
-if not self.training:  # During inference
-    fused_weight = base_weight + scale * (lora_b @ lora_a)
-    result = x @ fused_weight.T
+return {
+    "optimized_lora_linear_class": OptimizedLoRALinear,
+    "optimized_lora_matmul": optimized_lora_matmul,
+    "optimized_lora_forward_pass": optimized_lora_forward_pass,
+    "optimized_gradient_computation": optimized_gradient_computation,
+    "optimized_parameter_update": optimized_parameter_update,
+    "memory_efficient_loss_computation": memory_efficient_loss_computation,
+}
 ```

-**2. Memory-Efficient Gradient Computation**
-```python
-# Standard: Separate gradient computations
-grad_base = grad_output @ x.T
-grad_lora_b = grad_output @ lora_a_out.T
-grad_lora_a = lora_b.T @ grad_output @ x.T
+These kernels get injected via `patch_model_with_kernels()` and used during training.

-# Target: Fused gradient computation to reduce memory allocations
-# Reuse intermediate tensors, optimize memory access patterns
-```
+# 🔧 KEY OPTIMIZATION TARGETS IN EVOLVE-BLOCK

-**3. Training Loop Optimization**
+**1. OptimizedLoRALinear Class** ⭐ HIGH IMPACT
 ```python
-# Standard: Separate forward, loss, backward, update steps
-logits = model(inputs)
-loss = loss_fn(logits, targets)
-grads = compute_gradients(loss)
-optimizer.update(model, grads)
-
-# Target: Reduce kernel launches and memory overhead
-# Optimize for LoRA-specific gradient patterns
+class OptimizedLoRALinear(nn.Module):
+    def __call__(self, x):
+        base_out = self.base_layer(x)
+        # CURRENT: Standard LoRA computation
+        lora_out = mx.matmul(mx.matmul(x, self.lora_a.T), self.lora_b.T)
+        return base_out + self.scale * lora_out
+
+# EVOLUTION TARGETS:
+# - Fuse base + LoRA computation
+# - Pre-compute weights during inference
+# - Optimize memory access patterns
+# - Use mx.compile for hot paths
 ```
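For the "pre-compute weights during inference" target above, the classic route is LoRA weight merging. A minimal sketch reusing the shapes implied by the snippet (`lora_a`: (r, d_in), `lora_b`: (d_out, r)); `base_weight` and `bias` are assumed to come from the wrapped `base_layer`:

```python
import mlx.core as mx

def merge_lora_weights(base_weight, lora_a, lora_b, scale):
    # The adapter contributes scale * (lora_b @ lora_a), shape (d_out, d_in).
    return base_weight + scale * mx.matmul(lora_b, lora_a)

def merged_forward(x, merged_weight, bias=None):
    # A single matmul replaces the base + LoRA pair at inference time.
    out = mx.matmul(x, merged_weight.T)
    return out if bias is None else out + bias
```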

-**4. Multi-Layer LoRA Batch Processing**
+**2. optimized_lora_matmul Function** ⚡ SPEED TARGET
 ```python
-# Standard: Apply LoRA to layers one by one
-for layer in layers:
-    layer.q_proj = LoRALinear.from_linear(layer.q_proj)
-    layer.v_proj = LoRALinear.from_linear(layer.v_proj)
-
-# Target: Batch LoRA operations across layers
-# Share computation, optimize memory utilization
+@mx.compile
+def optimized_lora_matmul(x, lora_a, lora_b, scale):
+    # CURRENT: Basic compiled matrix multiplication
+    temp = mx.matmul(x, lora_a.T)
+    result = mx.matmul(temp, lora_b.T)
+    return scale * result
+
+# EVOLUTION TARGETS:
+# - Fuse matrix operations
+# - Optimize for specific tensor shapes
+# - Reduce intermediate allocations
+# - Vectorize computations
 ```
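One way to act on "fuse matrix operations" and "reduce intermediate allocations" is to fold the scale-and-add into `mx.addmm`. A sketch, assuming a variant of the kernel that also receives the base layer's output and that `scale` is a Python float:

```python
import mlx.core as mx

@mx.compile
def lora_matmul_addmm(base_out, x, lora_a, lora_b, scale):
    temp = mx.matmul(x, lora_a.T)
    # base_out + scale * (temp @ lora_b.T) computed by a single primitive
    return mx.addmm(base_out, temp, lora_b.T, alpha=scale, beta=1.0)
```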

-**5. Memory-Efficient Loss Computation**
+**3. optimized_lora_forward_pass Function** 🚀 INTEGRATION TARGET
 ```python
-# Standard: Full vocabulary materialization
-loss = cross_entropy(logits, targets)  # Memory: O(batch * seq * vocab)
-
-# Target: Chunked or online loss computation for large vocabularies
-# Reduce memory footprint during loss calculation
+def optimized_lora_forward_pass(model, x, use_kernels=True):
+    # CURRENT: Iterates through model layers
+    for name, layer in model.named_modules():
+        if hasattr(layer, 'lora_a') and hasattr(layer, 'lora_b'):
+            # Apply optimized LoRA computation
+
+# EVOLUTION TARGETS:
+# - Batch multiple LoRA layers
+# - Fuse activations with LoRA
+# - Optimize layer traversal
+# - Reduce function call overhead
 ```
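The "batch multiple LoRA layers" target can be illustrated with two adapters that share the same input, for example the q and v projections of one attention block. A sketch with assumed shapes (`lora_a_*`: (r, d_in)):

```python
import mlx.core as mx

def batched_lora_down(x, lora_a_q, lora_a_v):
    # Stack the A matrices so both down-projections run as a single matmul.
    stacked_a = mx.concatenate([lora_a_q, lora_a_v], axis=0)  # (2r, d_in)
    temp = mx.matmul(x, stacked_a.T)                          # (..., 2r)
    r = lora_a_q.shape[0]
    # Split back into per-adapter low-rank activations.
    return temp[..., :r], temp[..., r:]
```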

-**6. UNSLOTH-STYLE MLX KERNEL FUSION** 🎯 PRIMARY SPEED TARGET
+**4. memory_efficient_loss_computation Function** 💾 MEMORY TARGET
 ```python
-# Standard: Separate operations
-x = mx.add(input, lora_out)
-x = activation_fn(x)
-x = mx.matmul(x, next_weight)
-
-# Target: Fused kernels using MLX primitives
-# Combine LoRA, activation, and next operation
-# Leverage mx.compile and mx.eval strategically
+def memory_efficient_loss_computation(logits, targets, chunk_size=1024):
+    # CURRENT: Chunked loss for large vocabularies
+    if logits.shape[-1] <= chunk_size:
+        return nn.losses.cross_entropy(logits, targets, reduction="mean")
+    # Process in chunks...
+
+# EVOLUTION TARGETS:
+# - Optimize chunk size dynamically
+# - Reduce memory allocations
+# - Parallelize chunk processing
+# - Smart caching strategies
 ```
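One concrete form of "process in chunks" is to slice the flattened tokens instead of materializing the full loss at once. A minimal sketch, assuming `logits` of shape (batch, seq, vocab) and integer `targets` of shape (batch, seq); note it chunks over tokens rather than the vocabulary dimension:

```python
import mlx.core as mx
import mlx.nn as nn

def token_chunked_cross_entropy(logits, targets, chunk_size=1024):
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_targets = targets.reshape(-1)
    n = flat_targets.shape[0]
    total = mx.array(0.0)
    for start in range(0, n, chunk_size):
        end = min(start + chunk_size, n)
        # Sum-reduce each slice, then normalize once at the end.
        total = total + nn.losses.cross_entropy(
            flat_logits[start:end], flat_targets[start:end], reduction="sum"
        )
    return total / n
```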

-**7. Smart Gradient Accumulation**
+**5. optimized_gradient_computation Function** 🧠 GRADIENT TARGET
 ```python
-# Standard: Individual gradient updates
-for batch in batches:
-    loss = forward(batch)
-    grads = backward(loss)
-    optimizer.update(grads)
-
-# Target: Accumulated updates with reduced sync points
-# Batch multiple LoRA layer updates together
+def optimized_gradient_computation(loss, model, use_kernels=True):
+    # CURRENT: Basic compiled gradient computation
+    compiled_grad_fn = mx.compile(mx.grad(grad_fn))
+    return compiled_grad_fn(model)
+
+# EVOLUTION TARGETS:
+# - LoRA-specific gradient patterns
+# - Accumulate gradients efficiently
+# - Reduce gradient computation overhead
+# - Smart gradient sharing
 ```
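For reference, the idiomatic MLX route to gradients over a model's trainable parameters is `nn.value_and_grad`. A sketch of the pattern such a kernel would wrap; the loss function here is an assumed causal-LM style example:

```python
import mlx.nn as nn

def loss_fn(model, inputs, targets):
    logits = model(inputs)
    return nn.losses.cross_entropy(logits, targets, reduction="mean")

def compute_loss_and_grads(model, inputs, targets):
    # Differentiates loss_fn w.r.t. model.trainable_parameters().
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, inputs, targets)
    return loss, grads
```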

-# 🚀 UNSLOTH-INSPIRED OPTIMIZATION TECHNIQUES (Target 2x+ Speed Improvements)
+**6. optimized_parameter_update Function** 🔄 UPDATE TARGET
+```python
+@mx.compile
+def optimized_parameter_update(params, grads, lr):
+    # CURRENT: Basic parameter update loop
+    for key in params:
+        if key in grads:
+            updated_params[key] = params[key] - lr * grads[key]
+
+# EVOLUTION TARGETS:
+# - Batch parameter updates
+# - Vectorize updates
+# - Optimize for LoRA structure
+# - Reduce synchronization points
+```

-**🔥 Flash Attention Equivalents for MLX**: Fused attention computation patterns
-**⚡ Kernel Fusion**: Combine LoRA operations with activation functions
-**🧠 Smart Gradient Accumulation**: Batch gradient updates efficiently
-**⭐ Optimized MLX Operations**: Leverage mx.fast for critical paths
-**🚀 Parameter-Efficient Updates**: Minimize optimizer state overhead
-**💾 Memory Mapping**: Efficient tensor reuse and allocation patterns
-**🎯 Selective Computation**: Skip unnecessary ops based on LoRA rank/scale
-**🔧 Mixed Precision**: Smart FP16/FP32 usage for speed without loss
+# 🚀 PROVEN MLX OPTIMIZATION TECHNIQUES

-Current baseline shows 1.57x memory improvement but only 1.01x speed.
-FOCUS: Discover speed optimizations like unsloth's 2-5x improvements!
+**🔥 mx.compile Usage**: Leverage @mx.compile for hot computation paths (see the sketch after this list)
+**⚡ Tensor Fusion**: Combine multiple operations into single kernels
+**🧠 Memory Reuse**: Optimize tensor allocation and reuse patterns
+**⭐ Vectorization**: Use MLX's SIMD capabilities effectively
+**🚀 Batch Operations**: Process multiple items simultaneously
+**💾 Smart Caching**: Cache computed values when beneficial
+**🎯 Shape Optimization**: Optimize for common tensor shapes
+**🔧 Pipeline Efficiency**: Reduce data movement and sync points
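A minimal sketch of the mx.compile and sync-point items; the helper and tensor names are illustrative only:

```python
import mlx.core as mx
import mlx.nn as nn

@mx.compile
def fused_residual_gelu(x, delta, scale):
    # add + scale + activation traced into one compiled graph
    return nn.gelu(x + scale * delta)

# MLX is lazy: chaining several compiled calls and forcing them with a single
# mx.eval(...) at the end keeps host/device synchronization points to a minimum.
# y = fused_residual_gelu(x, delta, 0.5)
# z = fused_residual_gelu(y, delta, 0.5)
# mx.eval(z)
```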

 # 📊 SUCCESS METRICS

 **Primary Metric**: Training Loss Convergence (MUST MATCH BASELINE ±1%)
-- Target: Same final loss as standard LoRA implementation
+- Target: Same final loss as standard MLX-LM LoRA implementation
 - Critical: Maintain numerical stability and gradient flow

 **Secondary Metrics**: Efficiency Improvements
 - Memory efficiency: 10%+ reduction in peak memory usage
 - Training speed: 10%+ improvement in tokens/second
+- Time efficiency: 10%+ reduction in training time
 - Ideal: Both memory AND speed improvements

-# 🎖️ REAL-WORLD LORA OPTIMIZATION PATTERNS
+# 🎖️ REALISTIC OPTIMIZATION EXPECTATIONS

 Successful LoRA optimizations typically achieve:
-- **Memory reduction**: 15-30% through weight fusion and gradient optimization
-- **Speed improvement**: 10-25% through reduced kernel launches and better memory access
-- **Maintained convergence**: Critical for practical adoption
+- **Memory reduction**: 10-30% through smart tensor management
+- **Speed improvement**: 15-50% through kernel fusion and compilation
+- **Maintained convergence**: Essential for practical adoption

-Your optimizations should target similar patterns adapted for MLX.
+Your optimizations should target these realistic improvements for MLX.

 # 🚫 CONSTRAINTS
-- Keep exact function signatures from initial_program.py
-- Maintain numerical correctness (loss must match baseline within 0.01)
+- Keep exact function signatures and return values
+- Maintain numerical correctness (loss must match baseline within 1%)
 - Support all LoRA configs (ranks 8-64, any scale/dropout)
 - MLX-only dependencies (mx.core, mx.nn, mx.optimizers)
-- 🚨 CRITICAL: Concise evolution changes (under 35,000 chars total)
-- NO verbose comments - focus on algorithmic improvements
-- Prioritize SPEED over memory (we already have 1.57x memory gain)
-- Test mx.compile, mx.eval, kernel fusion, gradient accumulation patterns
-
-# 🔍 WHAT TO EVOLVE - TARGET UNSLOTH-STYLE 2x+ SPEED GAINS
-
-Focus on `evolved_lora_kernels` function. Prioritize SPEED optimizations:
-
-1. **optimized_lora_fine_tuning**: Main training pipeline with kernel fusion
-2. **optimized_training_loop**: Batch gradient accumulation like unsloth
-3. **optimized_train_step**: Fused forward/backward with mx.compile
-4. **optimized_linear_to_lora_layers**: Batched multi-layer LoRA application
-5. **optimized_evaluate**: Fast inference with weight pre-computation
-
-🎯 PRIMARY TARGETS FOR SPEED BREAKTHROUGH:
-- Leverage `mx.compile()` for hot paths (like unsloth's kernel compilation)
-- Use `mx.eval()` strategically to minimize sync points
-- Batch operations across multiple LoRA layers simultaneously
-- Pre-compute weights when beneficial (inference mode optimization)
-- Implement gradient accumulation patterns that reduce memory allocations
-
-Current Results: 1.57x memory ✅, 1.01x speed ❌
-Target: Discover 2-5x speed improvements while maintaining perfect convergence!
+- 🚨 CRITICAL: Concise evolution changes (under 30,000 chars total)
+- Focus on algorithmic improvements, not verbose comments
+- Ensure kernels can be properly patched into models
+- Test optimizations work with real MLX-LM training
+
+# 🔍 WHAT TO EVOLVE - FOCUS ON EVOLVE-BLOCK
+
+**Primary Evolution Target: `evolved_lora_kernels()` function**
+
+The EVOLVE-BLOCK contains 6 kernels that get injected into MLX-LM training:
+
+1. **OptimizedLoRALinear**: The core LoRA layer implementation
+2. **optimized_lora_matmul**: Compiled matrix multiplication kernel
+3. **optimized_lora_forward_pass**: Model forward pass optimization
+4. **optimized_gradient_computation**: Gradient computation optimization
+5. **optimized_parameter_update**: Parameter update optimization
+6. **memory_efficient_loss_computation**: Loss computation optimization
+
+🎯 **PRIMARY OPTIMIZATION STRATEGIES:**
+- Add more @mx.compile decorators for hot paths
+- Fuse multiple operations into single kernels
+- Optimize memory access patterns and reuse
+- Batch operations across multiple LoRA layers
+- Pre-compute values when beneficial (inference optimization)
+- Implement LoRA-specific optimizations based on mathematical properties
+- Reduce intermediate tensor allocations
+- Optimize for common LoRA configurations (rank 8-64)
+
+🔬 **CURRENT STATUS:** Starting from basic working implementations
+**TARGET:** Achieve 15-25% efficiency improvements while maintaining convergence
+
+# ⚠️ CRITICAL EVOLUTION GUIDELINES
+
+1. **ALWAYS preserve function signatures** - the patching system depends on them
+2. **Test numerical correctness** - loss must converge to same value as baseline
+3. **Use MLX primitives effectively** - leverage mx.compile, mx.eval, etc.
+4. **Focus on realistic optimizations** - don't over-engineer
+5. **Maintain code clarity** - optimizations should be understandable
+6. **Ensure kernel injection works** - test that patches apply correctly
+
+**Evolution Success = Same Loss + Better Performance + Working Integration**

   num_top_programs: 6
   num_diverse_programs: 4

 # Database configuration for LoRA optimization
 database:
   db_path: "./openevolve_output/program_db"
-  population_size: 80  # Larger population for more diverse explorations
+  population_size: 80
   archive_size: 40
   num_islands: 4
-  elite_selection_ratio: 0.20  # Less elite pressure, more exploration
-  exploitation_ratio: 0.6  # Balanced exploration for breakthroughs
+  elite_selection_ratio: 0.20
+  exploitation_ratio: 0.6
   exploration_ratio: 0.4

 # Evaluator configuration
 evaluator:
-  timeout: 1200  # Longer timeout for real LoRA training
+  timeout: 1200
   parallel_evaluations: 1

 # Evolution settings
 diff_based_evolution: true
 allow_full_rewrites: false
-max_code_length: 45000  # Encourage concise, focused optimizations
+max_code_length: 45000