@@ -196,15 +196,31 @@ def detect_loss_scaling_hacks(optimization_results: Dict[str, Any]) -> Tuple[boo
196196 # This is harder to detect directly, but we can look for patterns
197197 opt_final_loss = optimization_results .get ("final_loss" , 999.0 )
198198
199- # Check if loss is a simple fraction that suggests artificial scaling
200- # Common hack: loss / accumulation_steps where accumulation_steps > 1
201- COMMON_SCALE_FACTORS = [2 , 4 , 8 , 16 , 32 ] # Common accumulation step values
202-
203- for scale_factor in COMMON_SCALE_FACTORS :
204- scaled_loss = opt_final_loss * scale_factor
205- # If scaling by a common factor gives us a "normal" looking loss (1-5 range)
206- if 1.0 <= scaled_loss <= 5.0 :
207- return False , f"Loss appears artificially scaled: { opt_final_loss :.4f} * { scale_factor } = { scaled_loss :.4f} (possible gradient accumulation hack)"
199+ # FIXED: Only flag extremely suspicious patterns, not normal losses
200+ # A loss between 0.5 and 10.0 is reasonable for language modeling
201+ REASONABLE_LOSS_RANGE = (0.1 , 15.0 ) # Expanded reasonable range
202+
203+ if not (REASONABLE_LOSS_RANGE [0 ] <= opt_final_loss <= REASONABLE_LOSS_RANGE [1 ]):
204+ # Only check for scaling hacks if the loss is outside reasonable range
205+ COMMON_SCALE_FACTORS = [2 , 4 , 8 , 16 , 32 ] # Common accumulation step values
206+
207+ for scale_factor in COMMON_SCALE_FACTORS :
208+ scaled_loss = opt_final_loss * scale_factor
209+ # If scaling by a common factor gives us a "normal" looking loss (1-5 range)
210+ # AND the original loss was suspiciously low (< 0.1), then flag it
211+ if opt_final_loss < 0.1 and 1.0 <= scaled_loss <= 5.0 :
212+ return False , f"Loss appears artificially scaled: { opt_final_loss :.4f} * { scale_factor } = { scaled_loss :.4f} (possible gradient accumulation hack)"
213+
214+ # Additional check: Flag exact multiples that suggest division hacks
215+ # But only if the loss is suspiciously low to begin with
216+ if opt_final_loss < 0.05 : # Only very low losses
217+ for scale_factor in [2 , 4 , 8 , 16 ]:
218+ scaled_loss = opt_final_loss * scale_factor
219+ # Check if scaled loss is very close to a "normal" value
220+ normal_targets = [1.0 , 1.5 , 2.0 , 2.5 , 3.0 , 3.5 , 4.0 , 4.5 , 5.0 ]
221+ for target in normal_targets :
222+ if abs (scaled_loss - target ) < 0.01 : # Very close match
223+ return False , f"Suspiciously exact loss scaling: { opt_final_loss :.4f} * { scale_factor } ≈ { target :.1f} "
208224
209225 return True , "No obvious loss scaling detected"
210226
0 commit comments