@@ -120,11 +120,10 @@ def create_test_config(self, data_dir: str, adapter_dir: str) -> Dict[str, Any]:

     def compare_implementations(
         self,
-        baseline_kernels: Dict,
         evolved_kernels: Dict,
-        num_trials: int = 1
+        num_trials: int = 1
     ) -> Dict[str, Any]:
-        """Compare baseline vs evolved LoRA implementations using real mlx-lm."""
+        """Compare standard MLX-LM vs MLX-LM with evolved kernels injected."""

         if not MLX_LM_AVAILABLE:
             return {"error": "MLX-LM not available for real benchmarking"}
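With `baseline_kernels` gone from the signature, the benchmark has a single entry point: the baseline arm is produced internally by forwarding `evolved_kernels=None`. A minimal sketch of the new call convention, reusing the `evolved_program` handle from the `evaluate` hunk further down:

# Sketch only: mirrors the updated evaluate() in the final hunk below.
# evolved_program is the module that evaluate() loads from program_path.
benchmark = MLXLoRABenchmark()
comparison_results = benchmark.compare_implementations(
    evolved_kernels=evolved_program.evolved_lora_kernels(),
    num_trials=5,
)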
@@ -152,19 +151,19 @@ def compare_implementations(
             evolved_data_dir, evolved_adapter_dir
         ])

-        # Test baseline implementation
+        # Test baseline implementation (standard MLX-LM)
         try:
-            print("🔬 Testing BASELINE implementation...")
+            print("🔬 Testing BASELINE implementation (standard MLX-LM)...")

             # Create test dataset
             self._create_test_dataset(baseline_data_dir)
             baseline_config = self.create_test_config(baseline_data_dir, baseline_adapter_dir)

             clear_mlx_cache_and_gc()
-            baseline_result = self._run_lora_benchmark(
-                baseline_kernels['optimized_lora_fine_tuning'],
+            baseline_result = self._run_lora_benchmark_with_kernels(
                 baseline_config,
-                "BASELINE"
+                "BASELINE",
+                evolved_kernels=None  # No evolved kernels = standard MLX-LM
             )
             results['baseline'].append(baseline_result)

@@ -185,10 +184,10 @@ def compare_implementations(
             evolved_config = self.create_test_config(evolved_data_dir, evolved_adapter_dir)

             clear_mlx_cache_and_gc()
-            evolved_result = self._run_lora_benchmark(
-                evolved_kernels['optimized_lora_fine_tuning'],
+            evolved_result = self._run_lora_benchmark_with_kernels(
                 evolved_config,
-                "EVOLVED"
+                "EVOLVED",
+                evolved_kernels=evolved_kernels  # Inject evolved kernels
             )
             results['evolved'].append(evolved_result)

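Both arms now run through the same code path, differing only in the `evolved_kernels` argument. For orientation, a purely hypothetical shape for the dict that `evolved_lora_kernels()` might return; the real keys are defined by the evolved program, which is not part of this diff:

import numpy as np

# Hypothetical stand-in for an evolved kernel: the name and signature are
# assumptions for illustration, not the project's actual kernels.
def fused_lora_matmul(x, w, a, b, scale):
    # LoRA forward pass: y = x @ W + scale * ((x @ A) @ B)
    return x @ w + scale * ((x @ a) @ b)

evolved_kernels = {"fused_lora_matmul": fused_lora_matmul}

# Tiny smoke test with toy shapes (hidden dim d=8, LoRA rank r=2).
x, w = np.ones((1, 8)), np.zeros((8, 8))
a, b = np.ones((8, 2)), np.ones((2, 8))
print(fused_lora_matmul(x, w, a, b, scale=0.5).shape)  # (1, 8)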
@@ -411,13 +410,13 @@ def _create_test_dataset(self, output_dir: str, num_samples: int = 300):
             for example in data:
                 f.write(json.dumps(example) + "\n")

-    def _run_lora_benchmark(
+    def _run_lora_benchmark_with_kernels(
         self,
-        lora_fine_tuning_fn,
         config: Dict[str, Any],
-        implementation_name: str
+        implementation_name: str,
+        evolved_kernels: Optional[Dict] = None
     ) -> Dict[str, Union[float, str]]:
-        """Run LoRA fine-tuning benchmark."""
+        """Run LoRA fine-tuning benchmark with optional evolved kernel injection."""

         print(f" 🧪 Running {implementation_name} LoRA fine-tuning...")

@@ -426,12 +425,21 @@ def _run_lora_benchmark(
         memory_before = get_memory_usage()
         start_time = time.perf_counter()

-        # Run LoRA fine-tuning
-        final_loss, metrics = lora_fine_tuning_fn(
+        # Import and run the training function
+        import sys
+        import os
+        current_dir = os.path.dirname(os.path.abspath(__file__))
+        sys.path.insert(0, current_dir)
+
+        from initial_program import standard_lora_fine_tuning_with_kernels
+
+        # Run training with or without evolved kernels
+        final_loss, metrics = standard_lora_fine_tuning_with_kernels(
             model_name=config['model'],
             train_data_path=config['data'],
             config=config,
-            adapter_save_path=config['adapter_path']
+            adapter_save_path=config['adapter_path'],
+            evolved_kernels=evolved_kernels
         )

         # Timing and memory
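`standard_lora_fine_tuning_with_kernels` itself lives in `initial_program.py` and is not shown in this diff. One plausible mechanism for the optional injection, sketched under the assumption that evolved kernels replace attributes on some target module for the duration of training:

import contextlib
from typing import Any, Dict, Iterator, Optional

@contextlib.contextmanager
def kernel_injection(target: Any, kernels: Optional[Dict[str, Any]]) -> Iterator[None]:
    """Temporarily swap evolved kernels onto `target`; a no-op when kernels is None."""
    if not kernels:
        yield  # evolved_kernels=None -> baseline path, standard MLX-LM untouched
        return
    originals = {name: getattr(target, name) for name in kernels if hasattr(target, name)}
    try:
        for name, kernel in kernels.items():
            setattr(target, name, kernel)  # patch in the evolved implementation
        yield
    finally:
        for name, original in originals.items():
            setattr(target, name, original)  # restore stock kernels even on failure

A context manager keeps the patching reversible, so a crashed EVOLVED trial cannot leak evolved kernels into the next BASELINE trial.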
@@ -451,6 +459,7 @@ def _run_lora_benchmark(
         print(f" Final loss: {final_loss:.4f}")
         print(f" Training time: {training_time:.2f}s")
         print(f" Memory delta: {memory_delta:.1f}MB")
+        print(f" Used evolved kernels: {evolved_kernels is not None}")

         return {
             'final_loss': float(final_loss),
@@ -561,28 +570,20 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
             "error": "Missing baseline_lora_kernels function"
         }

-    # Get LoRA implementations
+    # Get evolved kernels
     evolved_kernels = evolved_program.evolved_lora_kernels()
-    baseline_kernels = evolved_program.baseline_lora_kernels()
-
-    # Check required kernels
-    required_key = 'optimized_lora_fine_tuning'
-    if required_key not in evolved_kernels or required_key not in baseline_kernels:
-        return {
-            "overall_score": 0.0,
-            "error": f"Missing kernel: {required_key}"
-        }
+    baseline_kernels = evolved_program.baseline_lora_kernels()  # Returns None

-    print(f"✅ LoRA implementations loaded successfully")
+    print(f"✅ Evolved kernels loaded: {list(evolved_kernels.keys())}")
+    print(f"✅ Baseline kernels: {baseline_kernels} (standard MLX-LM)")

     # Setup benchmark
     benchmark = MLXLoRABenchmark()

     # Run comparison
     comparison_results = benchmark.compare_implementations(
-        baseline_kernels=baseline_kernels,
         evolved_kernels=evolved_kernels,
-        num_trials=5
+        num_trials=5
     )

     if 'error' in comparison_results:
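Each trial appends a result dict like the one returned above to `results['baseline']` or `results['evolved']`. A small sketch of how those lists could be reduced to a speedup figure; note that only `'final_loss'` is visible in the return statement of this diff, while `'training_time'` and `'memory_delta'` are inferred from the printed output:

from statistics import mean
from typing import Any, Dict, List

def summarize_trials(baseline: List[Dict[str, Any]],
                     evolved: List[Dict[str, Any]]) -> Dict[str, float]:
    """Reduce per-trial result dicts to mean loss and a time-based speedup."""
    baseline_time = mean(t['training_time'] for t in baseline)
    evolved_time = mean(t['training_time'] for t in evolved)
    return {
        'speedup': baseline_time / evolved_time,  # > 1.0 means evolved kernels are faster
        'baseline_loss': mean(t['final_loss'] for t in baseline),
        'evolved_loss': mean(t['final_loss'] for t in evolved),
    }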