Skip to content

Commit 0d7fc35

Browse files
committed
Update evaluator.py
1 parent 8155993 commit 0d7fc35

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

examples/mlx_finetuning_optimization/evaluator.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -196,15 +196,31 @@ def detect_loss_scaling_hacks(optimization_results: Dict[str, Any]) -> Tuple[bool, str]:
196196
# This is harder to detect directly, but we can look for patterns
197197
opt_final_loss = optimization_results.get("final_loss", 999.0)
198198

199-
# Check if loss is a simple fraction that suggests artificial scaling
200-
# Common hack: loss / accumulation_steps where accumulation_steps > 1
201-
COMMON_SCALE_FACTORS = [2, 4, 8, 16, 32] # Common accumulation step values
202-
203-
for scale_factor in COMMON_SCALE_FACTORS:
204-
scaled_loss = opt_final_loss * scale_factor
205-
# If scaling by a common factor gives us a "normal" looking loss (1-5 range)
206-
if 1.0 <= scaled_loss <= 5.0:
207-
return False, f"Loss appears artificially scaled: {opt_final_loss:.4f} * {scale_factor} = {scaled_loss:.4f} (possible gradient accumulation hack)"
199+
# FIXED: Only flag extremely suspicious patterns, not normal losses
200+
# A loss between 0.5 and 10.0 is reasonable for language modeling
201+
REASONABLE_LOSS_RANGE = (0.1, 15.0) # Expanded reasonable range
202+
203+
if not (REASONABLE_LOSS_RANGE[0] <= opt_final_loss <= REASONABLE_LOSS_RANGE[1]):
204+
# Only check for scaling hacks if the loss is outside reasonable range
205+
COMMON_SCALE_FACTORS = [2, 4, 8, 16, 32] # Common accumulation step values
206+
207+
for scale_factor in COMMON_SCALE_FACTORS:
208+
scaled_loss = opt_final_loss * scale_factor
209+
# If scaling by a common factor gives us a "normal" looking loss (1-5 range)
210+
# AND the original loss was suspiciously low (< 0.1), then flag it
211+
if opt_final_loss < 0.1 and 1.0 <= scaled_loss <= 5.0:
212+
return False, f"Loss appears artificially scaled: {opt_final_loss:.4f} * {scale_factor} = {scaled_loss:.4f} (possible gradient accumulation hack)"
213+
214+
# Additional check: Flag exact multiples that suggest division hacks
215+
# But only if the loss is suspiciously low to begin with
216+
if opt_final_loss < 0.05: # Only very low losses
217+
for scale_factor in [2, 4, 8, 16]:
218+
scaled_loss = opt_final_loss * scale_factor
219+
# Check if scaled loss is very close to a "normal" value
220+
normal_targets = [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0]
221+
for target in normal_targets:
222+
if abs(scaled_loss - target) < 0.01: # Very close match
223+
return False, f"Suspiciously exact loss scaling: {opt_final_loss:.4f} * {scale_factor} ≈ {target:.1f}"
208224

209225
return True, "No obvious loss scaling detected"
210226

0 commit comments

Comments
 (0)