 MLX LoRA Fine-tuning Optimization Evaluator

 This evaluator performs real LoRA fine-tuning benchmarks using the mlx-lm library,
-comparing evolved implementations against standard MLX-LM LoRA implementations.
+comparing standard MLX-LM against MLX-LM with evolved kernels injected.
 The goal is to achieve the same training loss with improved memory efficiency and/or speed.
 """

@@ -62,8 +62,8 @@ def clear_mlx_cache_and_gc():

 class MLXLoRABenchmark:
     """
-    Benchmark for comparing MLX-LM LoRA fine-tuning implementations.
-    Measures training loss convergence, speed, and memory usage using real mlx-lm.
+    Benchmark for comparing standard MLX-LM vs MLX-LM with evolved kernels.
+    Uses proper sequential evaluation to avoid monkey patching interference.
     """

     def __init__(self, model_name: str = "mlx-community/Qwen2.5-0.5B-Instruct-4bit"):
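
The hunk header above names `clear_mlx_cache_and_gc()`, whose body lies outside this diff. For orientation, here is a minimal sketch of what such a helper typically does, assuming only the standard `gc` module and MLX's cache-clearing API; the exact body is a guess, not code from this change:

```python
import gc

import mlx.core as mx


def clear_mlx_cache_and_gc():
    """Free cached Metal buffers and collect Python garbage between trials."""
    gc.collect()  # drop Python-side references so buffers become reclaimable
    if hasattr(mx, "clear_cache"):
        mx.clear_cache()  # newer MLX releases expose this at the top level
    elif hasattr(mx, "metal"):
        mx.metal.clear_cache()  # older releases keep it under mx.metal
```

Calling this between trials matters because MLX's buffer cache would otherwise blur the per-trial memory-delta measurements.
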
@@ -78,13 +78,6 @@ def cleanup(self):
             except:
                 pass
         self.temp_dirs.clear()
-
-        # Also run general cleanup
-        try:
-            from cleanup import cleanup_temp_files
-            cleanup_temp_files()
-        except ImportError:
-            pass

     def create_test_config(self, data_dir: str, adapter_dir: str) -> Dict[str, Any]:
         """Create test configuration for LoRA fine-tuning with all MLX-LM expected attributes."""
@@ -96,21 +89,21 @@ def create_test_config(self, data_dir: str, adapter_dir: str) -> Dict[str, Any]:
             "optimizer_config": {"adam": {}},
             "data": data_dir,
             "seed": 42,
-            "num_layers": 2,  # Small for fast testing
-            "batch_size": 1,  # Small for memory efficiency
-            "iters": 10,
-            "val_batches": 5,
+            "num_layers": 4,  # More layers for comprehensive evaluation
+            "batch_size": 2,  # Reasonable batch size for larger dataset
+            "iters": 25,  # More iterations for larger dataset
+            "val_batches": 10,
             "learning_rate": 1e-4,
-            "steps_per_report": 5,
-            "steps_per_eval": 20,
+            "steps_per_report": 10,
+            "steps_per_eval": 50,
             "adapter_path": adapter_dir,
             "save_every": 100,
-            "max_seq_length": 256,  # Shorter sequences
-            "lora_parameters": {"rank": 8, "dropout": 0.0, "scale": 16.0},  # Smaller rank
+            "max_seq_length": 512,  # Full sequence length
+            "lora_parameters": {"rank": 16, "dropout": 0.0, "scale": 16.0},  # Standard rank
             "mask_prompt": False,
             # Additional MLX-LM expected attributes
             "test": True,
-            "test_batches": 5,
+            "test_batches": 10,
             "resume_adapter_file": None,
             "config": None,
             "grad_checkpoint": False,
@@ -121,90 +114,130 @@ def create_test_config(self, data_dir: str, adapter_dir: str) -> Dict[str, Any]:
     def compare_implementations(
         self,
         evolved_kernels: Dict,
-        num_trials: int = 1
+        num_trials: int = 3
     ) -> Dict[str, Any]:
-        """Compare standard MLX-LM vs MLX-LM with evolved kernels injected."""
+        """
+        Compare standard MLX-LM vs MLX-LM with evolved kernels.
+
+        PROPER EVALUATION STRUCTURE:
+        1. Run ALL baseline trials first (no patching)
+        2. Calculate baseline metrics
+        3. Apply evolved kernels patching ONCE
+        4. Run ALL evolved trials
+        5. Calculate evolved metrics
+        6. Compare results
+
+        This avoids monkey patching interference between trials.
+        """

         if not MLX_LM_AVAILABLE:
             return {"error": "MLX-LM not available for real benchmarking"}

-        print(f"\n📊 MLX-LM LORA FINE-TUNING COMPARISON")
+        print(f"\n📊 MLX-LM LORA KERNEL COMPARISON")
         print(f"   Model: {self.model_name}")
-        print(f"   Trials: {num_trials}")
+        print(f"   Trials per implementation: {num_trials}")
+        print(f"   Evaluation strategy: Sequential (baseline first, then evolved)")

-        results = {
-            'baseline': [],
-            'evolved': []
-        }
+        baseline_results = []
+        evolved_results = []
+
+        # ========================================
+        # PHASE 1: Run ALL baseline trials first
+        # ========================================
+        print(f"\n🔬 PHASE 1: Running {num_trials} BASELINE trials (standard MLX-LM)")

         for trial in range(num_trials):
-            print(f"\n--- Trial {trial + 1}/{num_trials} ---")
+            print(f"\n--- Baseline Trial {trial + 1}/{num_trials} ---")

             # Create temporary directories for this trial
             baseline_data_dir = tempfile.mkdtemp(prefix="baseline_data_")
             baseline_adapter_dir = tempfile.mkdtemp(prefix="baseline_adapters_")
-            evolved_data_dir = tempfile.mkdtemp(prefix="evolved_data_")
-            evolved_adapter_dir = tempfile.mkdtemp(prefix="evolved_adapters_")
-
-            self.temp_dirs.extend([
-                baseline_data_dir, baseline_adapter_dir,
-                evolved_data_dir, evolved_adapter_dir
-            ])
+            self.temp_dirs.extend([baseline_data_dir, baseline_adapter_dir])

-            # Test baseline implementation (standard MLX-LM)
             try:
-                print("🔬 Testing BASELINE implementation (standard MLX-LM)...")
-
                 # Create test dataset
                 self._create_test_dataset(baseline_data_dir)
                 baseline_config = self.create_test_config(baseline_data_dir, baseline_adapter_dir)

                 clear_mlx_cache_and_gc()
-                baseline_result = self._run_lora_benchmark_with_kernels(
+
+                # Run baseline (standard MLX-LM)
+                baseline_result = self._run_single_trial(
                     baseline_config,
-                    "BASELINE",
-                    evolved_kernels=None  # No evolved kernels = standard MLX-LM
+                    f"BASELINE-{trial + 1}",
+                    evolved_kernels=None  # No kernels = standard MLX-LM
                 )
-                results['baseline'].append(baseline_result)
+                baseline_results.append(baseline_result)
+
+                # Early exit if first baseline trial fails
+                if trial == 0 and 'error' in baseline_result:
+                    print("   🚨 First baseline trial failed - stopping evaluation")
+                    return {"error": f"First baseline trial failed: {baseline_result['error']}"}

             except Exception as e:
-                print(f"   ❌ Baseline trial failed: {e}")
-                results['baseline'].append({"error": str(e)})
-                # FAIL FAST: If first trial fails, don't continue
+                print(f"   ❌ Baseline trial {trial + 1} failed: {e}")
+                baseline_results.append({"error": str(e)})
+
+                # Early exit if first trial fails
                 if trial == 0:
-                    print("   🚨 First trial failed - stopping evaluation early")
-                    return {"error": f"First trial failed: {e}"}
+                    print("   🚨 First baseline trial failed - stopping evaluation")
+                    return {"error": f"First baseline trial failed: {e}"}
+
+        # ========================================
+        # PHASE 2: Run ALL evolved trials
+        # ========================================
+        print(f"\n🚀 PHASE 2: Running {num_trials} EVOLVED trials (MLX-LM + evolved kernels)")
+
+        for trial in range(num_trials):
+            print(f"\n--- Evolved Trial {trial + 1}/{num_trials} ---")
+
+            # Create temporary directories for this trial
+            evolved_data_dir = tempfile.mkdtemp(prefix="evolved_data_")
+            evolved_adapter_dir = tempfile.mkdtemp(prefix="evolved_adapters_")
+            self.temp_dirs.extend([evolved_data_dir, evolved_adapter_dir])

-            # Test evolved implementation
             try:
-                print("🚀 Testing EVOLVED implementation...")
-
                 # Create test dataset (same as baseline)
                 self._create_test_dataset(evolved_data_dir)
                 evolved_config = self.create_test_config(evolved_data_dir, evolved_adapter_dir)

                 clear_mlx_cache_and_gc()
-                evolved_result = self._run_lora_benchmark_with_kernels(
+
+                # Run evolved (MLX-LM + evolved kernels)
+                evolved_result = self._run_single_trial(
                     evolved_config,
-                    "EVOLVED",
+                    f"EVOLVED-{trial + 1}",
                     evolved_kernels=evolved_kernels  # Inject evolved kernels
                 )
-                results['evolved'].append(evolved_result)
+                evolved_results.append(evolved_result)
+
+                # Early exit if first evolved trial fails
+                if trial == 0 and 'error' in evolved_result:
+                    print("   🚨 First evolved trial failed - stopping evaluation")
+                    return {"error": f"First evolved trial failed: {evolved_result['error']}"}

             except Exception as e:
-                print(f"   ❌ Evolved trial failed: {e}")
-                results['evolved'].append({"error": str(e)})
-                # FAIL FAST: If first trial fails, don't continue
+                print(f"   ❌ Evolved trial {trial + 1} failed: {e}")
+                evolved_results.append({"error": str(e)})
+
+                # Early exit if first trial fails
                 if trial == 0:
-                    print("   🚨 First trial failed - stopping evaluation early")
-                    return {"error": f"First trial failed: {e}"}
+                    print("   🚨 First evolved trial failed - stopping evaluation")
+                    return {"error": f"First evolved trial failed: {e}"}

-        # Cleanup after all trials
+        # ========================================
+        # PHASE 3: Analyze and compare results
+        # ========================================
         self.cleanup()

+        results = {
+            'baseline': baseline_results,
+            'evolved': evolved_results
+        }
+
         return self._analyze_results(results)

-    def _create_test_dataset(self, output_dir: str, num_samples: int = 300):
+    def _create_test_dataset(self, output_dir: str, num_samples: int = 500):
         """Create a comprehensive test dataset for LoRA fine-tuning with diverse examples."""
         examples = [
             # AI and Machine Learning
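
The injection mechanism itself is not shown in this diff (it lives inside the trial runner that a later hunk renames to `_run_single_trial`), but the "monkey patching interference" the new docstring mentions comes from swapping attributes on shared mlx-lm modules. A sketch of the usual patch-and-restore shape, where the module handle and the name-to-kernel mapping are assumptions:

```python
from contextlib import contextmanager

@contextmanager
def inject_kernels(module, evolved_kernels):
    """Temporarily replace attributes on `module` with evolved kernels."""
    originals = {}
    try:
        for name, kernel in (evolved_kernels or {}).items():
            originals[name] = getattr(module, name)
            setattr(module, name, kernel)    # patch in the evolved kernel
        yield
    finally:
        for name, original in originals.items():
            setattr(module, name, original)  # always restore the baseline
```

Running every baseline trial before any patching, as the new phase structure does, means a missed restore can no longer contaminate a later baseline measurement.
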
@@ -378,47 +411,42 @@ def _create_test_dataset(self, output_dir: str, num_samples: int = 300):
             {"text": "What is dietetics?\nDietetics applies nutrition science to promote health and treat disease through proper food and eating habits."},
         ]

-        # Ensure we have enough diverse examples
+        # Cycle examples when num_samples exceeds the unique example count
         if num_samples > len(examples):
-            # Cycle through examples to reach desired number
             dataset = []
             for i in range(num_samples):
                 dataset.append(examples[i % len(examples)])
         else:
-            # Use subset if we have more examples than needed
             dataset = examples[:num_samples]

-        # Create balanced splits with sufficient validation data
+        # Create balanced splits with minimum sizes
         train_size = max(10, int(0.7 * num_samples))
         val_size = max(5, int(0.2 * num_samples))
-        test_size = num_samples - train_size - val_size
-        if test_size < 3:
-            test_size = 3
-            val_size = num_samples - train_size - test_size
+        test_size = max(3, num_samples - train_size - val_size)

         train_data = dataset[:train_size]
         val_data = dataset[train_size:train_size + val_size]
         test_data = dataset[train_size + val_size:train_size + val_size + test_size]

-        print(f"📊 Creating comprehensive dataset: {len(train_data)} train, {len(val_data)} valid, {len(test_data)} test examples")
+        print(f"📊 Dataset: {len(train_data)} train, {len(val_data)} valid, {len(test_data)} test examples")

-        # Write datasets - CRITICAL: Use "valid" not "val" for MLX-LM
+        # Write datasets - Use "valid" not "val" for MLX-LM
         os.makedirs(output_dir, exist_ok=True)
         for split, data in [("train", train_data), ("valid", val_data), ("test", test_data)]:
             file_path = os.path.join(output_dir, f"{split}.jsonl")
             with open(file_path, "w") as f:
                 for example in data:
                     f.write(json.dumps(example) + "\n")

-    def _run_lora_benchmark_with_kernels(
+    def _run_single_trial(
         self,
         config: Dict[str, Any],
-        implementation_name: str,
+        trial_name: str,
         evolved_kernels: Optional[Dict] = None
     ) -> Dict[str, Union[float, str]]:
-        """Run LoRA fine-tuning benchmark with optional evolved kernel injection."""
+        """Run a single LoRA fine-tuning trial."""

-        print(f"   🧪 Running {implementation_name} LoRA fine-tuning...")
+        print(f"   🧪 Running {trial_name}...")

         try:
             # Memory before
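
Back to the split logic in the hunk above: with the default `num_samples=500` it resolves as below. Note one behavioral change: the removed branch shrank `val_size` so the three splits never exceeded `num_samples`, whereas the `max(3, ...)` form can nominally overshoot for very small datasets, and the later slicing then silently truncates the test split.

```python
num_samples = 500
train_size = max(10, int(0.7 * num_samples))             # 350
val_size = max(5, int(0.2 * num_samples))                # 100
test_size = max(3, num_samples - train_size - val_size)  # 50
print(train_size + val_size + test_size)                 # 500 -- exact here

# Small-dataset edge case: with num_samples=12, the minimums give
# 10 + 5 + 3 = 18 > 12, so dataset[train+val : train+val+test] comes back
# short or empty.
```
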
@@ -452,14 +480,14 @@ def _run_lora_benchmark_with_kernels(
             # Extract additional metrics
             training_time = metrics.get('training_time', total_time)

-            # Calculate approximate tokens/second (rough estimate)
+            # Calculate approximate tokens/second
             estimated_tokens = config['iters'] * config['batch_size'] * config['max_seq_length']
             tokens_per_second = estimated_tokens / training_time if training_time > 0 else 0

             print(f"   Final loss: {final_loss:.4f}")
             print(f"   Training time: {training_time:.2f}s")
             print(f"   Memory delta: {memory_delta:.1f}MB")
-            print(f"   Used evolved kernels: {evolved_kernels is not None}")
+            print(f"   Evolved kernels: {evolved_kernels is not None}")

             return {
                 'final_loss': float(final_loss),
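
Under the updated config, the throughput estimate above is a fixed per-trial token budget; real token counts are lower whenever examples are shorter than `max_seq_length`, so tokens/sec is best read as an upper-bound proxy for relative comparisons. The arithmetic, with an illustrative training time:

```python
iters, batch_size, max_seq_length = 25, 2, 512
estimated_tokens = iters * batch_size * max_seq_length  # 25,600
training_time = 64.0                                    # illustrative seconds
print(estimated_tokens / training_time)                 # 400.0 tokens/sec
```
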
@@ -538,13 +566,16 @@ def _analyze_results(self, results: Dict[str, List[Dict]]) -> Dict[str, Any]:

 def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
     """
-    Evaluate MLX-LM LoRA fine-tuning optimization program.
+    Evaluate MLX-LM LoRA kernel optimization program.
+
+    Uses sequential evaluation approach:
+    1. Run ALL baseline trials (standard MLX-LM)
+    2. Run ALL evolved trials (MLX-LM + evolved kernels)
+    3. Compare results

-    Performs real LoRA fine-tuning comparison using mlx-lm library between
-    baseline and evolved implementations. Success metric: achieve same training
-    loss with efficiency improvements.
+    This avoids monkey patching interference between trials.
     """
-    print(f"🚀 Evaluating MLX-LM LoRA Fine-tuning Optimization: {program_path}")
+    print(f"🚀 Evaluating MLX LoRA Kernel Optimization: {program_path}")

     if not MLX_LM_AVAILABLE:
         return {
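
`evaluate()` receives a file path, and the next hunk calls functions on an already-imported `evolved_program` module; the loading step sits between them, outside the lines shown. The standard way to do it is sketched below; the module name and the `evolved_lora_kernels` accessor are assumptions, since only `baseline_lora_kernels` appears in this diff:

```python
import importlib.util

def load_program(program_path: str):
    spec = importlib.util.spec_from_file_location("evolved_program", program_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the candidate file top to bottom
    return module

# evolved_program = load_program(program_path)
# evolved_kernels = evolved_program.evolved_lora_kernels()    # hypothetical name
# baseline_kernels = evolved_program.baseline_lora_kernels()  # shown below
```
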
@@ -575,15 +606,15 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
         baseline_kernels = evolved_program.baseline_lora_kernels()  # Returns None

         print(f"✅ Evolved kernels loaded: {list(evolved_kernels.keys())}")
-        print(f"✅ Baseline kernels: {baseline_kernels} (standard MLX-LM)")
+        print(f"✅ Baseline: Standard MLX-LM (no custom kernels)")

         # Setup benchmark
         benchmark = MLXLoRABenchmark()

-        # Run comparison
+        # Run sequential comparison (baseline first, then evolved)
         comparison_results = benchmark.compare_implementations(
             evolved_kernels=evolved_kernels,
-            num_trials=5
+            num_trials=5
         )

         if 'error' in comparison_results:
@@ -606,7 +637,7 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
         baseline_avg = comparison_results['baseline_avg']
         evolved_avg = comparison_results['evolved_avg']

-        print(f"\n📊 MLX-LM LORA FINE-TUNING OPTIMIZATION RESULTS:")
+        print(f"\n📊 MLX LORA KERNEL OPTIMIZATION RESULTS:")
         print(f"   Loss Convergence: {'✅' if loss_convergence_ok else '❌'} (diff: {loss_difference:.4f})")
         print(f"   Speed Improvement: {speed_improvement:.2f}x")
         print(f"   Memory Improvement: {memory_improvement:.2f}x")
@@ -664,7 +695,7 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
             "successful_evolved_trials": comparison_results['successful_trials']['evolved'],

             # Metadata
-            "evaluation_type": "mlx_lm_lora_finetuning",
+            "evaluation_type": "mlx_lora_kernel_optimization",
             "achieves_convergence": bool(loss_convergence_ok),
             "has_efficiency_improvements": bool(speed_improvement > 1.05 or memory_improvement > 1.05),
             "target_achieved": bool(loss_convergence_ok and (speed_improvement > 1.1 or memory_improvement > 1.1)),
@@ -683,7 +714,7 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:

 if __name__ == "__main__":
-    print("Testing MLX-LM LoRA Fine-tuning Optimization Evaluator...")
+    print("Testing MLX LoRA Kernel Optimization Evaluator...")

     initial_program_path = os.path.join(os.path.dirname(__file__), "initial_program.py")
