Skip to content

Commit a42bfd3

Browse files
committed
Update evaluator.py
1 parent dba614c commit a42bfd3

File tree

1 file changed

+59
-8
lines changed

1 file changed

+59
-8
lines changed

examples/mlx_finetuning_optimization/evaluator.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ def evaluate_optimization_patterns(program, baseline_results: Dict[str, Any]) ->
188188
"training_speed": 0.0,
189189
"memory_improvement": 0.0,
190190
"speed_improvement": 0.0,
191+
"final_loss": 999.0, # Very bad loss
192+
"loss_ratio": 999.0,
191193
"overall_fitness": 0.0,
192194
"error": f"Invalid configuration: {validation_message}"
193195
}
@@ -203,6 +205,8 @@ def evaluate_optimization_patterns(program, baseline_results: Dict[str, Any]) ->
203205
"training_speed": 0.0,
204206
"memory_improvement": 0.0,
205207
"speed_improvement": 0.0,
208+
"final_loss": 999.0,
209+
"loss_ratio": 999.0,
206210
"overall_fitness": 0.0,
207211
"error": optimization_results["error"]
208212
}
@@ -212,58 +216,105 @@ def evaluate_optimization_patterns(program, baseline_results: Dict[str, Any]) ->
212216
baseline_memory_efficiency = baseline_results.get("memory_efficiency", 0.001)
213217
baseline_peak_memory = baseline_results.get("peak_memory_mb", 1000.0)
214218
baseline_total_time = baseline_results.get("total_time", 100.0)
219+
baseline_final_loss = baseline_results.get("final_loss", 2.0) # CRITICAL: Add final loss
215220

216221
opt_tokens_per_sec = optimization_results.get("tokens_per_second", 0.0)
217222
opt_memory_efficiency = optimization_results.get("memory_efficiency", 0.0)
218223
opt_peak_memory = optimization_results.get("peak_memory_mb", float('inf'))
219224
opt_total_time = optimization_results.get("total_time", float('inf'))
225+
opt_final_loss = optimization_results.get("final_loss", 999.0) # CRITICAL: Add final loss
226+
227+
# Calculate loss ratio (optimized loss / baseline loss)
228+
loss_ratio = opt_final_loss / baseline_final_loss if baseline_final_loss > 0 else 999.0
229+
230+
# CRITICAL CONSTRAINT: Reject if final loss is significantly worse
231+
MAX_LOSS_DEGRADATION = 1.20 # Allow max 20% worse loss
232+
if loss_ratio > MAX_LOSS_DEGRADATION:
233+
print(f"❌ REJECTING optimization: Final loss too high!")
234+
print(f" Baseline loss: {baseline_final_loss:.4f}")
235+
print(f" Optimized loss: {opt_final_loss:.4f}")
236+
print(f" Loss ratio: {loss_ratio:.2f} (max allowed: {MAX_LOSS_DEGRADATION})")
237+
238+
return {
239+
"memory_efficiency": 0.0,
240+
"training_speed": 0.0,
241+
"memory_improvement": -1.0,
242+
"speed_improvement": -1.0,
243+
"final_loss": float(opt_final_loss),
244+
"loss_ratio": float(loss_ratio),
245+
"overall_fitness": -10.0, # Heavy penalty
246+
"error": f"Final loss degraded too much: {loss_ratio:.2f}x vs baseline"
247+
}
220248

221249
# Calculate percentage improvements
222250
speed_improvement = (opt_tokens_per_sec - baseline_tokens_per_sec) / baseline_tokens_per_sec if baseline_tokens_per_sec > 0 else 0.0
223251
memory_efficiency_improvement = (opt_memory_efficiency - baseline_memory_efficiency) / baseline_memory_efficiency if baseline_memory_efficiency > 0 else 0.0
224252
memory_usage_improvement = (baseline_peak_memory - opt_peak_memory) / baseline_peak_memory if baseline_peak_memory > 0 else 0.0
225253
time_improvement = (baseline_total_time - opt_total_time) / baseline_total_time if baseline_total_time > 0 else 0.0
226254

255+
# Loss improvement: positive when the optimized loss is lower than baseline (lower loss is better)
256+
loss_improvement = (baseline_final_loss - opt_final_loss) / baseline_final_loss if baseline_final_loss > 0 else 0.0
257+
227258
# Ensure improvements are reasonable (cap at 10x improvement to avoid outliers)
228259
speed_improvement = max(-0.9, min(speed_improvement, 10.0))
229260
memory_efficiency_improvement = max(-0.9, min(memory_efficiency_improvement, 10.0))
230261
memory_usage_improvement = max(-0.9, min(memory_usage_improvement, 0.9)) # Max 90% memory reduction
231262
time_improvement = max(-0.9, min(time_improvement, 0.9)) # Max 90% time reduction
263+
loss_improvement = max(-2.0, min(loss_improvement, 2.0))  # Clamp: floor -2.0 allows loss up to 3x baseline; cap 2.0 bounds the improvement bonus
232264

233-
# Calculate overall fitness with emphasis on memory efficiency (key constraint for Mac users)
234-
# Positive improvements should increase fitness, negative should decrease it
265+
# Calculate overall fitness with LOSS AS PRIMARY FACTOR
235266
fitness_components = {
236-
"memory_efficiency_score": memory_efficiency_improvement * 0.4, # 40% weight
237-
"speed_score": speed_improvement * 0.25, # 25% weight
238-
"memory_usage_score": memory_usage_improvement * 0.25, # 25% weight
239-
"time_score": time_improvement * 0.1 # 10% weight
267+
"loss_quality_score": loss_improvement * 0.5, # 50% weight - MOST IMPORTANT
268+
"memory_efficiency_score": memory_efficiency_improvement * 0.2, # 20% weight
269+
"speed_score": speed_improvement * 0.2, # 20% weight
270+
"memory_usage_score": memory_usage_improvement * 0.1, # 10% weight
240271
}
241272

242273
overall_fitness = sum(fitness_components.values())
243274

244275
# Add stability bonus/penalty
245-
if opt_peak_memory < float('inf') and opt_tokens_per_sec > 0:
276+
if opt_peak_memory < float('inf') and opt_tokens_per_sec > 0 and opt_final_loss < 50.0:
246277
stability_bonus = 0.1
247278
else:
248279
stability_bonus = -0.5 # Heavy penalty for failed runs
249280

250281
overall_fitness += stability_bonus
251282

283+
# Add loss quality bonus for maintaining good learning
284+
if loss_ratio <= 1.05: # Within 5% of baseline loss
285+
loss_quality_bonus = 0.2 # Bonus for maintaining learning quality
286+
elif loss_ratio <= 1.10: # Within 10%
287+
loss_quality_bonus = 0.1
288+
else:
289+
loss_quality_bonus = 0.0
290+
291+
overall_fitness += loss_quality_bonus
292+
252293
# Normalize fitness to reasonable range
253-
overall_fitness = max(-1.0, min(overall_fitness, 5.0))
294+
overall_fitness = max(-10.0, min(overall_fitness, 5.0))
295+
296+
print(f"✅ Optimization ACCEPTED:")
297+
print(f" Final loss: {opt_final_loss:.4f} vs baseline {baseline_final_loss:.4f} (ratio: {loss_ratio:.2f})")
298+
print(f" Speed: {speed_improvement:.1%} improvement")
299+
print(f" Memory efficiency: {memory_efficiency_improvement:.1%} improvement")
300+
print(f" Overall fitness: {overall_fitness:.4f}")
254301

255302
return {
256303
"memory_efficiency": float(opt_memory_efficiency),
257304
"training_speed": float(opt_tokens_per_sec),
258305
"peak_memory_mb": float(opt_peak_memory),
259306
"total_time": float(opt_total_time),
307+
"final_loss": float(opt_final_loss),
308+
"loss_ratio": float(loss_ratio),
260309
"speed_improvement": float(speed_improvement),
261310
"memory_efficiency_improvement": float(memory_efficiency_improvement),
262311
"memory_usage_improvement": float(memory_usage_improvement),
263312
"time_improvement": float(time_improvement),
313+
"loss_improvement": float(loss_improvement),
264314
"overall_fitness": float(overall_fitness),
265315
"baseline_tokens_per_sec": float(baseline_tokens_per_sec),
266316
"baseline_memory_efficiency": float(baseline_memory_efficiency),
317+
"baseline_final_loss": float(baseline_final_loss),
267318
"config_valid": True,
268319
"fitness_components": fitness_components
269320
}

0 commit comments

Comments
 (0)