
Commit 70e2421

g
1 parent 8fadaf5 commit 70e2421

File tree

2 files changed: +342 -115 lines changed


examples/mlx_fine_tuning_kernels/evaluator.py

Lines changed: 57 additions & 7 deletions
@@ -134,6 +134,7 @@ def compare_implementations(self, evolved_kernels: Dict, num_trials: int = 3) ->
         print(f" Model: {self.model_name}")
         print(f" Trials per implementation: {num_trials}")
         print(f" Evaluation strategy: Sequential (baseline first, then evolved)")
+        print(f" Evolved kernels available: {list(evolved_kernels.keys()) if evolved_kernels else 'None'}")

         baseline_results = []
         evolved_results = []
@@ -184,6 +185,15 @@ def compare_implementations(self, evolved_kernels: Dict, num_trials: int = 3) ->
         # PHASE 2: Run ALL evolved trials
         # ========================================
         print(f"\n🚀 PHASE 2: Running {num_trials} EVOLVED trials (MLX-LM + evolved kernels)")
+
+        # Verify evolved kernels are valid before running trials
+        if evolved_kernels:
+            print(f" ✅ Testing evolved kernels: {list(evolved_kernels.keys())}")
+            for kernel_name, kernel_func in evolved_kernels.items():
+                if kernel_func is None:
+                    print(f" ⚠️ Warning: {kernel_name} is None")
+                else:
+                    print(f" ✅ {kernel_name}: {type(kernel_func)}")

         for trial in range(num_trials):
             print(f"\n--- Evolved Trial {trial + 1}/{num_trials} ---")
@@ -729,6 +739,10 @@ def _run_single_trial(
         """Run a single LoRA fine-tuning trial."""

         print(f" 🧪 Running {trial_name}...")
+        if evolved_kernels:
+            print(f" 📦 Using evolved kernels: {list(evolved_kernels.keys())}")
+        else:
+            print(f" 📋 Using standard MLX-LM (no kernels)")

         try:
             # Memory before
@@ -762,6 +776,13 @@ def _run_single_trial(

             # Extract additional metrics
             training_time = metrics.get("training_time", total_time)
+
+            # Check if kernels were actually used
+            kernels_used = metrics.get("used_evolved_kernels", False)
+            if evolved_kernels and not kernels_used:
+                print(f" ⚠️ Warning: Evolved kernels provided but not used")
+            elif evolved_kernels and kernels_used:
+                print(f" ✅ Evolved kernels successfully applied")

             # Calculate approximate tokens/second
             estimated_tokens = config["iters"] * config["batch_size"] * config["max_seq_length"]
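The `used_evolved_kernels` flag above is read from the metrics dict returned by the training run; where that flag gets set is not part of this diff, so the producer side here is an assumption. A minimal sketch of a metrics dict that would make the check report success:

```python
# Hypothetical metrics dict from a training run. Only the keys read in the diff
# ("training_time", "used_evolved_kernels") are implied; the values are made up,
# and how the flag is actually set by the kernel-injection path is assumed.
metrics = {
    "training_time": 42.5,         # seconds
    "used_evolved_kernels": True,  # would be set by the (assumed) injection path
}
evolved_kernels = {"fused_lora_matmul": lambda x: x}  # stand-in for the real dict

kernels_used = metrics.get("used_evolved_kernels", False)
if evolved_kernels and not kernels_used:
    print("⚠️ Warning: Evolved kernels provided but not used")
elif evolved_kernels and kernels_used:
    print("✅ Evolved kernels successfully applied")
```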
@@ -770,7 +791,8 @@ def _run_single_trial(
             print(f" Final loss: {final_loss:.4f}")
             print(f" Training time: {training_time:.2f}s")
             print(f" Memory delta: {memory_delta:.1f} MB")
-            print(f" Evolved kernels: {evolved_kernels is not None}")
+            print(f" Tokens/sec: {tokens_per_second:.1f}")
+            print(f" Kernels used: {kernels_used}")

             return {
                 "final_loss": float(final_loss),
@@ -780,10 +802,13 @@ def _run_single_trial(
                 "tokens_per_second": float(tokens_per_second),
                 "lora_rank": config["lora_parameters"]["rank"],
                 "num_layers": config["num_layers"],
+                "kernels_used": bool(kernels_used),
             }

         except Exception as e:
             print(f" ❌ Failed: {e}")
+            import traceback
+            traceback.print_exc()
             return {"error": str(e)}

     def _analyze_results(self, results: Dict[str, List[Dict]]) -> Dict[str, Any]:
@@ -897,18 +922,32 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
         return {"overall_score": 0.0, "error": "Missing baseline_lora_kernels function"}

     # Get evolved kernels
-    evolved_kernels = evolved_program.evolved_lora_kernels()
-    baseline_kernels = evolved_program.baseline_lora_kernels() # Returns None
-
-    print(f"✅ Evolved kernels loaded: {list(evolved_kernels.keys())}")
-    print(f"✅ Baseline: Standard MLX-LM (no custom kernels)")
+    print("📦 Loading evolved kernels...")
+    try:
+        evolved_kernels = evolved_program.evolved_lora_kernels()
+        baseline_kernels = evolved_program.baseline_lora_kernels() # Returns None
+
+        print(f"✅ Evolved kernels loaded: {list(evolved_kernels.keys()) if evolved_kernels else 'None'}")
+        print(f"✅ Baseline: Standard MLX-LM (no custom kernels)")
+
+        # Validate evolved kernels
+        if evolved_kernels:
+            for kernel_name, kernel_func in evolved_kernels.items():
+                if kernel_func is None:
+                    print(f" ⚠️ Warning: {kernel_name} is None")
+                else:
+                    print(f" ✅ {kernel_name}: {type(kernel_func)}")
+
+    except Exception as e:
+        print(f"❌ Failed to load evolved kernels: {e}")
+        return {"overall_score": 0.0, "error": f"Failed to load evolved kernels: {e}"}

     # Setup benchmark
     benchmark = MLXLoRABenchmark()

     # Run sequential comparison (baseline first, then evolved)
     comparison_results = benchmark.compare_implementations(
-        evolved_kernels=evolved_kernels, num_trials=5
+        evolved_kernels=evolved_kernels, num_trials=3 # Reduced for faster testing
     )

     if "error" in comparison_results:
@@ -947,6 +986,16 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
         f" Evolved - Loss: {evolved_avg['final_loss']:.4f}, Time: {evolved_avg['training_time']:.1f}s, Memory: {evolved_avg['memory_delta']:.1f} MB"
     )

+    # Check if kernels were actually used in evolved trials
+    evolved_success = [r for r in comparison_results.get("evolved", []) if "error" not in r]
+    if evolved_success:
+        kernels_actually_used = any(r.get("kernels_used", False) for r in evolved_success)
+        if evolved_kernels and not kernels_actually_used:
+            print(f" ⚠️ WARNING: Evolved kernels were provided but not used in trials")
+            print(f" 🔍 This suggests the kernel injection mechanism may not be working")
+        elif evolved_kernels and kernels_actually_used:
+            print(f" ✅ Evolved kernels were successfully used in trials")
+
     # Success interpretation
     if overall_score >= 0.8:
         print(" 🥇 EXCELLENT: Strong improvements while maintaining convergence!")
@@ -994,6 +1043,7 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
         "target_achieved": bool(
             loss_convergence_ok and (speed_improvement > 1.1 or memory_improvement > 1.1)
         ),
+        "kernels_actually_used": bool(evolved_success and any(r.get("kernels_used", False) for r in evolved_success)) if evolved_success else False,
     }

     return results
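To make the new `kernels_actually_used` field concrete, here is a standalone sketch of the aggregation over per-trial results; the sample trial dicts are illustrative, while the filtering and `any(...)` logic mirror the diff:

```python
# Illustrative per-trial results. The keys ("error", "kernels_used") mirror the
# dicts returned by _run_single_trial() in this diff; the values are made up.
evolved_trials = [
    {"final_loss": 1.92, "kernels_used": True},
    {"final_loss": 1.95, "kernels_used": True},
    {"error": "trial failed"},
]

evolved_success = [r for r in evolved_trials if "error" not in r]
kernels_actually_used = (
    bool(evolved_success and any(r.get("kernels_used", False) for r in evolved_success))
    if evolved_success
    else False
)
print(kernels_actually_used)  # True: at least one successful trial used the kernels
```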
