Skip to content

Commit ad13d3e

Browse files
committed
h
1 parent eac05b8 commit ad13d3e

File tree

2 files changed

+257
-470
lines changed

2 files changed

+257
-470
lines changed

examples/mlx_fine_tuning_kernels/evaluator.py

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,10 @@ def create_test_config(self, data_dir: str, adapter_dir: str) -> Dict[str, Any]:
120120

121121
def compare_implementations(
122122
self,
123-
baseline_kernels: Dict,
124123
evolved_kernels: Dict,
125-
num_trials: int = 1
124+
num_trials: int = 1
126125
) -> Dict[str, Any]:
127-
"""Compare baseline vs evolved LoRA implementations using real mlx-lm."""
126+
"""Compare standard MLX-LM vs MLX-LM with evolved kernels injected."""
128127

129128
if not MLX_LM_AVAILABLE:
130129
return {"error": "MLX-LM not available for real benchmarking"}
@@ -152,19 +151,19 @@ def compare_implementations(
152151
evolved_data_dir, evolved_adapter_dir
153152
])
154153

155-
# Test baseline implementation
154+
# Test baseline implementation (standard MLX-LM)
156155
try:
157-
print("🔬 Testing BASELINE implementation...")
156+
print("🔬 Testing BASELINE implementation (standard MLX-LM)...")
158157

159158
# Create test dataset
160159
self._create_test_dataset(baseline_data_dir)
161160
baseline_config = self.create_test_config(baseline_data_dir, baseline_adapter_dir)
162161

163162
clear_mlx_cache_and_gc()
164-
baseline_result = self._run_lora_benchmark(
165-
baseline_kernels['optimized_lora_fine_tuning'],
163+
baseline_result = self._run_lora_benchmark_with_kernels(
166164
baseline_config,
167-
"BASELINE"
165+
"BASELINE",
166+
evolved_kernels=None # No evolved kernels = standard MLX-LM
168167
)
169168
results['baseline'].append(baseline_result)
170169

@@ -185,10 +184,10 @@ def compare_implementations(
185184
evolved_config = self.create_test_config(evolved_data_dir, evolved_adapter_dir)
186185

187186
clear_mlx_cache_and_gc()
188-
evolved_result = self._run_lora_benchmark(
189-
evolved_kernels['optimized_lora_fine_tuning'],
187+
evolved_result = self._run_lora_benchmark_with_kernels(
190188
evolved_config,
191-
"EVOLVED"
189+
"EVOLVED",
190+
evolved_kernels=evolved_kernels # Inject evolved kernels
192191
)
193192
results['evolved'].append(evolved_result)
194193

@@ -411,13 +410,13 @@ def _create_test_dataset(self, output_dir: str, num_samples: int = 300):
411410
for example in data:
412411
f.write(json.dumps(example) + "\n")
413412

414-
def _run_lora_benchmark(
413+
def _run_lora_benchmark_with_kernels(
415414
self,
416-
lora_fine_tuning_fn,
417415
config: Dict[str, Any],
418-
implementation_name: str
416+
implementation_name: str,
417+
evolved_kernels: Optional[Dict] = None
419418
) -> Dict[str, Union[float, str]]:
420-
"""Run LoRA fine-tuning benchmark."""
419+
"""Run LoRA fine-tuning benchmark with optional evolved kernel injection."""
421420

422421
print(f" 🧪 Running {implementation_name} LoRA fine-tuning...")
423422

@@ -426,12 +425,21 @@ def _run_lora_benchmark(
426425
memory_before = get_memory_usage()
427426
start_time = time.perf_counter()
428427

429-
# Run LoRA fine-tuning
430-
final_loss, metrics = lora_fine_tuning_fn(
428+
# Import and run the training function
429+
import sys
430+
import os
431+
current_dir = os.path.dirname(os.path.abspath(__file__))
432+
sys.path.insert(0, current_dir)
433+
434+
from initial_program import standard_lora_fine_tuning_with_kernels
435+
436+
# Run training with or without evolved kernels
437+
final_loss, metrics = standard_lora_fine_tuning_with_kernels(
431438
model_name=config['model'],
432439
train_data_path=config['data'],
433440
config=config,
434-
adapter_save_path=config['adapter_path']
441+
adapter_save_path=config['adapter_path'],
442+
evolved_kernels=evolved_kernels
435443
)
436444

437445
# Timing and memory
@@ -451,6 +459,7 @@ def _run_lora_benchmark(
451459
print(f" Final loss: {final_loss:.4f}")
452460
print(f" Training time: {training_time:.2f}s")
453461
print(f" Memory delta: {memory_delta:.1f} MB")
462+
print(f" Used evolved kernels: {evolved_kernels is not None}")
454463

455464
return {
456465
'final_loss': float(final_loss),
@@ -561,28 +570,20 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
561570
"error": "Missing baseline_lora_kernels function"
562571
}
563572

564-
# Get LoRA implementations
573+
# Get evolved kernels
565574
evolved_kernels = evolved_program.evolved_lora_kernels()
566-
baseline_kernels = evolved_program.baseline_lora_kernels()
567-
568-
# Check required kernels
569-
required_key = 'optimized_lora_fine_tuning'
570-
if required_key not in evolved_kernels or required_key not in baseline_kernels:
571-
return {
572-
"overall_score": 0.0,
573-
"error": f"Missing kernel: {required_key}"
574-
}
575+
baseline_kernels = evolved_program.baseline_lora_kernels() # Returns None
575576

576-
print(f"✅ LoRA implementations loaded successfully")
577+
print(f"✅ Evolved kernels loaded: {list(evolved_kernels.keys())}")
578+
print(f"✅ Baseline kernels: {baseline_kernels} (standard MLX-LM)")
577579

578580
# Setup benchmark
579581
benchmark = MLXLoRABenchmark()
580582

581583
# Run comparison
582584
comparison_results = benchmark.compare_implementations(
583-
baseline_kernels=baseline_kernels,
584585
evolved_kernels=evolved_kernels,
585-
num_trials=5
586+
num_trials=5
586587
)
587588

588589
if 'error' in comparison_results:

0 commit comments

Comments
 (0)