@@ -134,6 +134,7 @@ def compare_implementations(self, evolved_kernels: Dict, num_trials: int = 3) ->
         print(f" Model: {self.model_name}")
         print(f" Trials per implementation: {num_trials}")
         print(f" Evaluation strategy: Sequential (baseline first, then evolved)")
+        print(f" Evolved kernels available: {list(evolved_kernels.keys()) if evolved_kernels else 'None'}")
 
         baseline_results = []
         evolved_results = []
@@ -184,6 +185,15 @@ def compare_implementations(self, evolved_kernels: Dict, num_trials: int = 3) ->
         # PHASE 2: Run ALL evolved trials
         # ========================================
         print(f"\n 🚀 PHASE 2: Running {num_trials} EVOLVED trials (MLX-LM + evolved kernels)")
+
+        # Verify evolved kernels are valid before running trials
+        if evolved_kernels:
+            print(f" ✅ Testing evolved kernels: {list(evolved_kernels.keys())}")
+            for kernel_name, kernel_func in evolved_kernels.items():
+                if kernel_func is None:
+                    print(f" ⚠️ Warning: {kernel_name} is None")
+                else:
+                    print(f" ✅ {kernel_name}: {type(kernel_func)}")
 
         for trial in range(num_trials):
             print(f"\n --- Evolved Trial {trial + 1}/{num_trials} ---")
@@ -729,6 +739,10 @@ def _run_single_trial(
729739 """Run a single LoRA fine-tuning trial."""
730740
731741 print (f" 🧪 Running { trial_name } ..." )
742+ if evolved_kernels :
743+ print (f" 📦 Using evolved kernels: { list (evolved_kernels .keys ())} " )
744+ else :
745+ print (f" 📋 Using standard MLX-LM (no kernels)" )
732746
733747 try :
734748 # Memory before
@@ -762,6 +776,13 @@ def _run_single_trial(
 
             # Extract additional metrics
             training_time = metrics.get("training_time", total_time)
+
+            # Check if kernels were actually used
+            kernels_used = metrics.get("used_evolved_kernels", False)
+            if evolved_kernels and not kernels_used:
+                print(f" ⚠️ Warning: Evolved kernels provided but not used")
+            elif evolved_kernels and kernels_used:
+                print(f" ✅ Evolved kernels successfully applied")
 
             # Calculate approximate tokens/second
             estimated_tokens = config["iters"] * config["batch_size"] * config["max_seq_length"]
@@ -770,7 +791,8 @@ def _run_single_trial(
             print(f" Final loss: {final_loss:.4f}")
             print(f" Training time: {training_time:.2f}s")
             print(f" Memory delta: {memory_delta:.1f}MB")
-            print(f" Evolved kernels: {evolved_kernels is not None}")
+            print(f" Tokens/sec: {tokens_per_second:.1f}")
+            print(f" Kernels used: {kernels_used}")
 
             return {
                 "final_loss": float(final_loss),
@@ -780,10 +802,13 @@ def _run_single_trial(
780802 "tokens_per_second" : float (tokens_per_second ),
781803 "lora_rank" : config ["lora_parameters" ]["rank" ],
782804 "num_layers" : config ["num_layers" ],
805+ "kernels_used" : bool (kernels_used ),
783806 }
784807
785808 except Exception as e :
786809 print (f" ❌ Failed: { e } " )
810+ import traceback
811+ traceback .print_exc ()
787812 return {"error" : str (e )}
788813
789814 def _analyze_results (self , results : Dict [str , List [Dict ]]) -> Dict [str , Any ]:
@@ -897,18 +922,32 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
         return {"overall_score": 0.0, "error": "Missing baseline_lora_kernels function"}
 
     # Get evolved kernels
-    evolved_kernels = evolved_program.evolved_lora_kernels()
-    baseline_kernels = evolved_program.baseline_lora_kernels()  # Returns None
-
-    print(f"✅ Evolved kernels loaded: {list(evolved_kernels.keys())}")
-    print(f"✅ Baseline: Standard MLX-LM (no custom kernels)")
+    print("📦 Loading evolved kernels...")
+    try:
+        evolved_kernels = evolved_program.evolved_lora_kernels()
+        baseline_kernels = evolved_program.baseline_lora_kernels()  # Returns None
+
+        print(f"✅ Evolved kernels loaded: {list(evolved_kernels.keys()) if evolved_kernels else 'None'}")
+        print(f"✅ Baseline: Standard MLX-LM (no custom kernels)")
+
+        # Validate evolved kernels
+        if evolved_kernels:
+            for kernel_name, kernel_func in evolved_kernels.items():
+                if kernel_func is None:
+                    print(f" ⚠️ Warning: {kernel_name} is None")
+                else:
+                    print(f" ✅ {kernel_name}: {type(kernel_func)}")
+
+    except Exception as e:
+        print(f"❌ Failed to load evolved kernels: {e}")
+        return {"overall_score": 0.0, "error": f"Failed to load evolved kernels: {e}"}
 
     # Setup benchmark
     benchmark = MLXLoRABenchmark()
 
     # Run sequential comparison (baseline first, then evolved)
     comparison_results = benchmark.compare_implementations(
-        evolved_kernels=evolved_kernels, num_trials=5
+        evolved_kernels=evolved_kernels, num_trials=3  # Reduced for faster testing
     )
 
     if "error" in comparison_results:
@@ -947,6 +986,16 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
947986 f" Evolved - Loss: { evolved_avg ['final_loss' ]:.4f} , Time: { evolved_avg ['training_time' ]:.1f} s, Memory: { evolved_avg ['memory_delta' ]:.1f} MB"
948987 )
949988
989+ # Check if kernels were actually used in evolved trials
990+ evolved_success = [r for r in comparison_results .get ("evolved" , []) if "error" not in r ]
991+ if evolved_success :
992+ kernels_actually_used = any (r .get ("kernels_used" , False ) for r in evolved_success )
993+ if evolved_kernels and not kernels_actually_used :
994+ print (f" ⚠️ WARNING: Evolved kernels were provided but not used in trials" )
995+ print (f" 🔍 This suggests the kernel injection mechanism may not be working" )
996+ elif evolved_kernels and kernels_actually_used :
997+ print (f" ✅ Evolved kernels were successfully used in trials" )
998+
950999 # Success interpretation
9511000 if overall_score >= 0.8 :
9521001 print (" 🥇 EXCELLENT: Strong improvements while maintaining convergence!" )
@@ -994,6 +1043,7 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
9941043 "target_achieved" : bool (
9951044 loss_convergence_ok and (speed_improvement > 1.1 or memory_improvement > 1.1 )
9961045 ),
1046+ "kernels_actually_used" : bool (evolved_success and any (r .get ("kernels_used" , False ) for r in evolved_success )) if evolved_success else False ,
9971047 }
9981048
9991049 return results
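
Note: the evolved-trial check in this diff reads metrics.get("used_evolved_kernels", False), but the code that actually sets that flag lives outside this commit. The sketch below illustrates the contract the evaluator appears to assume, with a dict of kernel callables like the one it validates. The wrapper name run_training_with_kernels, the example kernel names, and the point where the flag is set are all hypothetical and not part of this repository.

from typing import Any, Callable, Dict, Optional


def run_training_with_kernels(
    config: Dict[str, Any],
    evolved_kernels: Optional[Dict[str, Callable]] = None,
) -> Dict[str, Any]:
    """Hypothetical trial runner showing how the metrics flag could be reported."""
    metrics: Dict[str, Any] = {"used_evolved_kernels": False}

    # Drop None entries, mirroring the validation warnings in the evaluator.
    usable = {name: fn for name, fn in (evolved_kernels or {}).items() if fn is not None}
    if usable:
        # In the real harness, this is where evolved kernels would be patched
        # into the LoRA layers before the fine-tuning loop starts.
        metrics["used_evolved_kernels"] = True

    # ... run the actual MLX-LM fine-tuning here and record real metrics ...
    metrics.setdefault("training_time", 0.0)
    return metrics


if __name__ == "__main__":
    # Kernels dict shaped like the one the evaluator validates:
    # kernel name -> callable (a None entry would trigger the warnings above).
    fake_kernels = {"fused_lora_matmul": lambda x: x, "quantized_update": None}
    print(run_training_with_kernels({"iters": 10}, fake_kernels))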