
Commit cf5e451

committed: f

1 parent ad13d3e commit cf5e451

File tree

2 files changed: +201 -132 lines changed

examples/mlx_fine_tuning_kernels/evaluator.py

Lines changed: 118 additions & 87 deletions
@@ -2,7 +2,7 @@
 MLX LoRA Fine-tuning Optimization Evaluator
 
 This evaluator performs real LoRA fine-tuning benchmarks using the mlx-lm library,
-comparing evolved implementations against standard MLX-LM LoRA implementations.
+comparing standard MLX-LM against MLX-LM with evolved kernels injected.
 The goal is to achieve the same training loss with improved memory efficiency and/or speed.
 """

@@ -62,8 +62,8 @@ def clear_mlx_cache_and_gc():
 
 class MLXLoRABenchmark:
     """
-    Benchmark for comparing MLX-LM LoRA fine-tuning implementations.
-    Measures training loss convergence, speed, and memory usage using real mlx-lm.
+    Benchmark for comparing standard MLX-LM vs MLX-LM with evolved kernels.
+    Uses proper sequential evaluation to avoid monkey patching interference.
     """
 
     def __init__(self, model_name: str = "mlx-community/Qwen2.5-0.5B-Instruct-4bit"):
@@ -78,13 +78,6 @@ def cleanup(self):
             except:
                 pass
         self.temp_dirs.clear()
-
-        # Also run general cleanup
-        try:
-            from cleanup import cleanup_temp_files
-            cleanup_temp_files()
-        except ImportError:
-            pass
 
     def create_test_config(self, data_dir: str, adapter_dir: str) -> Dict[str, Any]:
         """Create test configuration for LoRA fine-tuning with all MLX-LM expected attributes."""
@@ -96,21 +89,21 @@ def create_test_config(self, data_dir: str, adapter_dir: str) -> Dict[str, Any]:
             "optimizer_config": {"adam": {}},
             "data": data_dir,
             "seed": 42,
-            "num_layers": 2,  # Small for fast testing
-            "batch_size": 1,  # Small for memory efficiency
-            "iters": 10,  # More iterations for larger dataset
-            "val_batches": 5,
+            "num_layers": 4,  # More layers for comprehensive evaluation
+            "batch_size": 2,  # Reasonable batch size for larger dataset
+            "iters": 25,  # More iterations for larger dataset
+            "val_batches": 10,
             "learning_rate": 1e-4,
-            "steps_per_report": 5,
-            "steps_per_eval": 20,
+            "steps_per_report": 10,
+            "steps_per_eval": 50,
             "adapter_path": adapter_dir,
             "save_every": 100,
-            "max_seq_length": 256,  # Shorter sequences
-            "lora_parameters": {"rank": 8, "dropout": 0.0, "scale": 16.0},  # Smaller rank
+            "max_seq_length": 512,  # Full sequence length
+            "lora_parameters": {"rank": 16, "dropout": 0.0, "scale": 16.0},  # Standard rank
             "mask_prompt": False,
             # Additional MLX-LM expected attributes
             "test": True,
-            "test_batches": 5,
+            "test_batches": 10,
             "resume_adapter_file": None,
             "config": None,
             "grad_checkpoint": False,
@@ -121,90 +114,130 @@ def create_test_config(self, data_dir: str, adapter_dir: str) -> Dict[str, Any]:
     def compare_implementations(
         self,
         evolved_kernels: Dict,
-        num_trials: int = 1
+        num_trials: int = 3
     ) -> Dict[str, Any]:
-        """Compare standard MLX-LM vs MLX-LM with evolved kernels injected."""
+        """
+        Compare standard MLX-LM vs MLX-LM with evolved kernels.
+
+        PROPER EVALUATION STRUCTURE:
+        1. Run ALL baseline trials first (no patching)
+        2. Calculate baseline metrics
+        3. Apply evolved kernels patching ONCE
+        4. Run ALL evolved trials
+        5. Calculate evolved metrics
+        6. Compare results
+
+        This avoids monkey patching interference between trials.
+        """
 
         if not MLX_LM_AVAILABLE:
             return {"error": "MLX-LM not available for real benchmarking"}
 
-        print(f"\n📊 MLX-LM LORA FINE-TUNING COMPARISON")
+        print(f"\n📊 MLX-LM LORA KERNEL COMPARISON")
         print(f" Model: {self.model_name}")
-        print(f" Trials: {num_trials}")
+        print(f" Trials per implementation: {num_trials}")
+        print(f" Evaluation strategy: Sequential (baseline first, then evolved)")
 
-        results = {
-            'baseline': [],
-            'evolved': []
-        }
+        baseline_results = []
+        evolved_results = []
+
+        # ========================================
+        # PHASE 1: Run ALL baseline trials first
+        # ========================================
+        print(f"\n🔬 PHASE 1: Running {num_trials} BASELINE trials (standard MLX-LM)")
 
         for trial in range(num_trials):
-            print(f"\n--- Trial {trial + 1}/{num_trials} ---")
+            print(f"\n--- Baseline Trial {trial + 1}/{num_trials} ---")
 
             # Create temporary directories for this trial
             baseline_data_dir = tempfile.mkdtemp(prefix="baseline_data_")
             baseline_adapter_dir = tempfile.mkdtemp(prefix="baseline_adapters_")
-            evolved_data_dir = tempfile.mkdtemp(prefix="evolved_data_")
-            evolved_adapter_dir = tempfile.mkdtemp(prefix="evolved_adapters_")
-
-            self.temp_dirs.extend([
-                baseline_data_dir, baseline_adapter_dir,
-                evolved_data_dir, evolved_adapter_dir
-            ])
+            self.temp_dirs.extend([baseline_data_dir, baseline_adapter_dir])
 
-            # Test baseline implementation (standard MLX-LM)
             try:
-                print("🔬 Testing BASELINE implementation (standard MLX-LM)...")
-
                 # Create test dataset
                 self._create_test_dataset(baseline_data_dir)
                 baseline_config = self.create_test_config(baseline_data_dir, baseline_adapter_dir)
 
                 clear_mlx_cache_and_gc()
-                baseline_result = self._run_lora_benchmark_with_kernels(
+
+                # Run baseline (standard MLX-LM)
+                baseline_result = self._run_single_trial(
                     baseline_config,
-                    "BASELINE",
-                    evolved_kernels=None  # No evolved kernels = standard MLX-LM
+                    f"BASELINE-{trial+1}",
+                    evolved_kernels=None  # No kernels = standard MLX-LM
                 )
-                results['baseline'].append(baseline_result)
+                baseline_results.append(baseline_result)
+
+                # Early exit if first baseline trial fails
+                if trial == 0 and 'error' in baseline_result:
+                    print(" 🚨 First baseline trial failed - stopping evaluation")
+                    return {"error": f"First baseline trial failed: {baseline_result['error']}"}
 
             except Exception as e:
-                print(f" ❌ Baseline trial failed: {e}")
-                results['baseline'].append({"error": str(e)})
-                # FAIL FAST: If first trial fails, don't continue
+                print(f" ❌ Baseline trial {trial+1} failed: {e}")
+                baseline_results.append({"error": str(e)})
+
+                # Early exit if first trial fails
                 if trial == 0:
-                    print(" 🚨 First trial failed - stopping evaluation early")
-                    return {"error": f"First trial failed: {e}"}
+                    print(" 🚨 First baseline trial failed - stopping evaluation")
+                    return {"error": f"First baseline trial failed: {e}"}
+
+        # ========================================
+        # PHASE 2: Run ALL evolved trials
+        # ========================================
+        print(f"\n🚀 PHASE 2: Running {num_trials} EVOLVED trials (MLX-LM + evolved kernels)")
+
+        for trial in range(num_trials):
+            print(f"\n--- Evolved Trial {trial + 1}/{num_trials} ---")
+
+            # Create temporary directories for this trial
+            evolved_data_dir = tempfile.mkdtemp(prefix="evolved_data_")
+            evolved_adapter_dir = tempfile.mkdtemp(prefix="evolved_adapters_")
+            self.temp_dirs.extend([evolved_data_dir, evolved_adapter_dir])
 
-            # Test evolved implementation
             try:
-                print("🚀 Testing EVOLVED implementation...")
-
                 # Create test dataset (same as baseline)
                 self._create_test_dataset(evolved_data_dir)
                 evolved_config = self.create_test_config(evolved_data_dir, evolved_adapter_dir)
 
                 clear_mlx_cache_and_gc()
-                evolved_result = self._run_lora_benchmark_with_kernels(
+
+                # Run evolved (MLX-LM + evolved kernels)
+                evolved_result = self._run_single_trial(
                     evolved_config,
-                    "EVOLVED",
+                    f"EVOLVED-{trial+1}",
                     evolved_kernels=evolved_kernels  # Inject evolved kernels
                 )
-                results['evolved'].append(evolved_result)
+                evolved_results.append(evolved_result)
+
+                # Early exit if first evolved trial fails
+                if trial == 0 and 'error' in evolved_result:
+                    print(" 🚨 First evolved trial failed - stopping evaluation")
+                    return {"error": f"First evolved trial failed: {evolved_result['error']}"}
 
             except Exception as e:
-                print(f" ❌ Evolved trial failed: {e}")
-                results['evolved'].append({"error": str(e)})
-                # FAIL FAST: If first trial fails, don't continue
+                print(f" ❌ Evolved trial {trial+1} failed: {e}")
+                evolved_results.append({"error": str(e)})
+
+                # Early exit if first trial fails
                 if trial == 0:
-                    print(" 🚨 First trial failed - stopping evaluation early")
-                    return {"error": f"First trial failed: {e}"}
+                    print(" 🚨 First evolved trial failed - stopping evaluation")
+                    return {"error": f"First evolved trial failed: {e}"}
 
-        # Cleanup after all trials
+        # ========================================
+        # PHASE 3: Analyze and compare results
+        # ========================================
         self.cleanup()
 
+        results = {
+            'baseline': baseline_results,
+            'evolved': evolved_results
+        }
+
         return self._analyze_results(results)
 
-    def _create_test_dataset(self, output_dir: str, num_samples: int = 300):
+    def _create_test_dataset(self, output_dir: str, num_samples: int = 500):
         """Create a comprehensive test dataset for LoRA fine-tuning with diverse examples."""
         examples = [
             # AI and Machine Learning
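
Note on the hunk above: the evaluation pattern it introduces can be condensed into a minimal standalone sketch. The run_trial helper and its return keys below are hypothetical placeholders for illustration, not functions from this file; fail-fast handling and temp-dir management are omitted.

    def run_trial(kernels):
        # Stand-in for one real LoRA training run; returns fake metrics.
        return {"final_loss": 1.0, "used_evolved_kernels": kernels is not None}

    def sequential_comparison(evolved_kernels, num_trials=3):
        # Phase 1: all baseline trials first, with no patching applied.
        baseline = [run_trial(kernels=None) for _ in range(num_trials)]
        # Phase 2: all evolved trials, with the kernels injected.
        evolved = [run_trial(kernels=evolved_kernels) for _ in range(num_trials)]
        # Phase 3: compare the two sets of results.
        return {"baseline": baseline, "evolved": evolved}

Running every baseline trial before any patching happens is what keeps the monkey patching from leaking into the reference measurements.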
@@ -378,47 +411,42 @@ def _create_test_dataset(self, output_dir: str, num_samples: int = 300):
             {"text": "What is dietetics?\nDietetics applies nutrition science to promote health and treat disease through proper food and eating habits."},
         ]
 
-        # Ensure we have enough diverse examples
+        # Use smaller dataset for faster evaluation
         if num_samples > len(examples):
-            # Cycle through examples to reach desired number
             dataset = []
             for i in range(num_samples):
                 dataset.append(examples[i % len(examples)])
         else:
-            # Use subset if we have more examples than needed
             dataset = examples[:num_samples]
 
-        # Create balanced splits with sufficient validation data
+        # Create balanced splits with minimum sizes
         train_size = max(10, int(0.7 * num_samples))
         val_size = max(5, int(0.2 * num_samples))
-        test_size = num_samples - train_size - val_size
-        if test_size < 3:
-            test_size = 3
-            val_size = num_samples - train_size - test_size
+        test_size = max(3, num_samples - train_size - val_size)
 
         train_data = dataset[:train_size]
         val_data = dataset[train_size:train_size + val_size]
         test_data = dataset[train_size + val_size:train_size + val_size + test_size]
 
-        print(f"📊 Creating comprehensive dataset: {len(train_data)} train, {len(val_data)} valid, {len(test_data)} test examples")
+        print(f"📊 Dataset: {len(train_data)} train, {len(val_data)} valid, {len(test_data)} test examples")
 
-        # Write datasets - CRITICAL: Use "valid" not "val" for MLX-LM
+        # Write datasets - Use "valid" not "val" for MLX-LM
         os.makedirs(output_dir, exist_ok=True)
         for split, data in [("train", train_data), ("valid", val_data), ("test", test_data)]:
             file_path = os.path.join(output_dir, f"{split}.jsonl")
             with open(file_path, "w") as f:
                 for example in data:
                     f.write(json.dumps(example) + "\n")
 
-    def _run_lora_benchmark_with_kernels(
+    def _run_single_trial(
         self,
         config: Dict[str, Any],
-        implementation_name: str,
+        trial_name: str,
         evolved_kernels: Optional[Dict] = None
     ) -> Dict[str, Union[float, str]]:
-        """Run LoRA fine-tuning benchmark with optional evolved kernel injection."""
+        """Run a single LoRA fine-tuning trial."""
 
-        print(f" 🧪 Running {implementation_name} LoRA fine-tuning...")
+        print(f" 🧪 Running {trial_name}...")
 
         try:
             # Memory before
@@ -452,14 +480,14 @@ def _run_lora_benchmark_with_kernels(
             # Extract additional metrics
             training_time = metrics.get('training_time', total_time)
 
-            # Calculate approximate tokens/second (rough estimate)
+            # Calculate approximate tokens/second
            estimated_tokens = config['iters'] * config['batch_size'] * config['max_seq_length']
             tokens_per_second = estimated_tokens / training_time if training_time > 0 else 0
 
             print(f" Final loss: {final_loss:.4f}")
             print(f" Training time: {training_time:.2f}s")
             print(f" Memory delta: {memory_delta:.1f} MB")
-            print(f" Used evolved kernels: {evolved_kernels is not None}")
+            print(f" Evolved kernels: {evolved_kernels is not None}")
 
             return {
                 'final_loss': float(final_loss),
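
As a quick sanity check on the estimate in this hunk: with the updated config from this commit (iters=25, batch_size=2, max_seq_length=512), estimated_tokens = 25 * 2 * 512 = 25,600, so a run whose training_time came out to, say, 60 seconds (an illustrative figure, not a measured one) would report roughly 427 tokens/second.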
@@ -538,13 +566,16 @@ def _analyze_results(self, results: Dict[str, List[Dict]]) -> Dict[str, Any]:
 
 def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
     """
-    Evaluate MLX-LM LoRA fine-tuning optimization program.
+    Evaluate MLX-LM LoRA kernel optimization program.
+
+    Uses sequential evaluation approach:
+    1. Run ALL baseline trials (standard MLX-LM)
+    2. Run ALL evolved trials (MLX-LM + evolved kernels)
+    3. Compare results
 
-    Performs real LoRA fine-tuning comparison using mlx-lm library between
-    baseline and evolved implementations. Success metric: achieve same training
-    loss with efficiency improvements.
+    This avoids monkey patching interference between trials.
     """
-    print(f"🚀 Evaluating MLX-LM LoRA Fine-tuning Optimization: {program_path}")
+    print(f"🚀 Evaluating MLX LoRA Kernel Optimization: {program_path}")
 
     if not MLX_LM_AVAILABLE:
         return {
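
A hypothetical usage sketch of the reworked evaluate() entry point, mirroring the __main__ block further down in this diff. The import path is an assumption based on the file's location; it is not prescribed by the commit.

    # Assumes the working directory is examples/mlx_fine_tuning_kernels/.
    from evaluator import evaluate

    results = evaluate("initial_program.py")
    print(results.get("evaluation_type"), results.get("target_achieved"))

Both keys appear in the metrics dict this evaluator returns; when mlx-lm is not installed, the function returns an error payload instead of running any trials.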
@@ -575,15 +606,15 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
     baseline_kernels = evolved_program.baseline_lora_kernels()  # Returns None
 
     print(f"✅ Evolved kernels loaded: {list(evolved_kernels.keys())}")
-    print(f"✅ Baseline kernels: {baseline_kernels} (standard MLX-LM)")
+    print(f"✅ Baseline: Standard MLX-LM (no custom kernels)")
 
     # Setup benchmark
     benchmark = MLXLoRABenchmark()
 
-    # Run comparison
+    # Run sequential comparison (baseline first, then evolved)
     comparison_results = benchmark.compare_implementations(
         evolved_kernels=evolved_kernels,
-        num_trials=5
+        num_trials=5
     )
 
     if 'error' in comparison_results:
@@ -606,7 +637,7 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
     baseline_avg = comparison_results['baseline_avg']
     evolved_avg = comparison_results['evolved_avg']
 
-    print(f"\n📊 MLX-LM LORA FINE-TUNING OPTIMIZATION RESULTS:")
+    print(f"\n📊 MLX LORA KERNEL OPTIMIZATION RESULTS:")
     print(f" Loss Convergence: {'✅' if loss_convergence_ok else '❌'} (diff: {loss_difference:.4f})")
     print(f" Speed Improvement: {speed_improvement:.2f}x")
     print(f" Memory Improvement: {memory_improvement:.2f}x")
@@ -664,7 +695,7 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
         "successful_evolved_trials": comparison_results['successful_trials']['evolved'],
 
         # Metadata
-        "evaluation_type": "mlx_lm_lora_finetuning",
+        "evaluation_type": "mlx_lora_kernel_optimization",
         "achieves_convergence": bool(loss_convergence_ok),
         "has_efficiency_improvements": bool(speed_improvement > 1.05 or memory_improvement > 1.05),
         "target_achieved": bool(loss_convergence_ok and (speed_improvement > 1.1 or memory_improvement > 1.1)),
@@ -683,7 +714,7 @@ def evaluate(program_path: str) -> Dict[str, Union[bool, float, str, int]]:
 
 
 if __name__ == "__main__":
-    print("Testing MLX-LM LoRA Fine-tuning Optimization Evaluator...")
+    print("Testing MLX LoRA Kernel Optimization Evaluator...")
 
     initial_program_path = os.path.join(os.path.dirname(__file__), "initial_program.py")
 