Skip to content

Commit 3188b0f

Browse files
committed
Update config.yaml
1 parent 6ec75d7 commit 3188b0f

File tree

1 file changed

+80
-3
lines changed

1 file changed

+80
-3
lines changed

examples/mlx_finetuning_optimization/config.yaml

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,31 @@ prompt:
244244
# WRONG: Defaulting to zero loss rewards failed computations
245245
loss_value = float(mx.eval(loss) or 0.0) # 0.0 = perfect loss!
246246
247-
# RIGHT: Use reasonable fallback or fail gracefully
247+
# WRONG: Using NaN as fallback creates invalid metrics
248+
if scaled_loss_val is None:
249+
unscaled_loss_val = float('nan') # NaN breaks all metrics!
250+
251+
# RIGHT: Use reasonable fallback that doesn't game metrics
248252
eval_result = mx.eval(loss)
249253
if eval_result is None:
250-
raise ValueError("Loss computation failed - cannot proceed")
251-
loss_value = float(eval_result)
254+
# Use a reasonable fallback loss that doesn't artificially improve metrics
255+
loss_value = 2.0 # Reasonable cross-entropy loss, not suspiciously good
256+
print("Warning: Loss evaluation failed, using reasonable fallback")
257+
else:
258+
loss_value = float(eval_result)
259+
260+
# RIGHT: For scaled/unscaled loss patterns
261+
def safe_eval_loss(loss_tensor, fallback_value=2.0):
262+
try:
263+
result = mx.eval(loss_tensor)
264+
if result is None:
265+
return fallback_value # Reasonable fallback, not reward hacking
266+
return float(result)
267+
except Exception:
268+
return fallback_value # Consistent fallback behavior
269+
270+
scaled_loss_val = safe_eval_loss(scaled_loss, 2.0) # Reasonable fallback
271+
unscaled_loss_val = scaled_loss_val * max(total_accumulation_steps, 1)
252272
```
253273
254274
❌ **Unrealistic Performance Claims**
@@ -323,6 +343,59 @@ prompt:
323343
loss_value = eval_result[0] if eval_result is not None else 0.0
324344
```
325345
346+
❌ **mx.eval() returning None causing NaN losses**
347+
```python
348+
# WRONG: This pattern causes "Scaled loss evaluation returned None"
349+
scaled_loss = loss / total_accumulation_steps
350+
scaled_loss_val = mx.eval(scaled_loss) # Returns None!
351+
if scaled_loss_val is None:
352+
print("Error: Scaled loss evaluation returned None. Reporting NaN unscaled loss.")
353+
unscaled_loss_val = float('nan') # Creates NaN!
354+
355+
# RIGHT: Robust loss evaluation with fallbacks
356+
def safe_eval_loss(loss_tensor, description="loss"):
357+
"""Safely evaluate a loss tensor with proper error handling"""
358+
if loss_tensor is None:
359+
print(f"Warning: {description} tensor is None, using fallback")
360+
return 1.0 # Reasonable fallback loss
361+
362+
try:
363+
# Force evaluation and ensure it's materialized
364+
mx.eval(loss_tensor)
365+
eval_result = mx.eval(loss_tensor)
366+
367+
if eval_result is None:
368+
print(f"Warning: {description} evaluation returned None, using fallback")
369+
return 1.0 # Reasonable fallback
370+
371+
# Handle different return types
372+
if isinstance(eval_result, mx.array):
373+
if eval_result.size == 1:
374+
scalar_val = float(eval_result.item())
375+
else:
376+
scalar_val = float(eval_result.mean()) # Average if multiple values
377+
else:
378+
scalar_val = float(eval_result)
379+
380+
# Check for invalid values
381+
if not isinstance(scalar_val, (int, float)) or scalar_val != scalar_val: # NaN check
382+
print(f"Warning: {description} evaluation returned invalid value: {scalar_val}")
383+
return 1.0 # Reasonable fallback
384+
385+
return scalar_val
386+
387+
except Exception as e:
388+
print(f"Error evaluating {description}: {e}. Using fallback.")
389+
return 1.0 # Reasonable fallback
390+
391+
# Usage:
392+
scaled_loss = loss / max(total_accumulation_steps, 1)
393+
scaled_loss_val = safe_eval_loss(scaled_loss, "scaled loss")
394+
unscaled_loss_val = scaled_loss_val * max(total_accumulation_steps, 1)
395+
396+
return float(unscaled_loss_val), should_update
397+
```
398+
326399
❌ **integer modulo by zero**
327400
```python
328401
# WRONG: if step % accumulation_steps == 0: # accumulation_steps might be 0
@@ -446,10 +519,14 @@ prompt:
446519
- **Report ACTUAL loss values, not scaled or manipulated values**
447520
- **Use REAL timing and memory measurements**
448521
- **Ensure training actually works and learns**
522+
- **Handle mx.eval() None returns with reasonable fallbacks (NOT zero or NaN)**
523+
- **Never use NaN, infinity, or zero as loss fallbacks**
524+
- **Fallback loss values should be realistic (1.0-3.0 for cross-entropy)**
449525
- **Realistic improvement targets: 10-50% speed, 20-40% memory reduction**
450526
- **Loss should remain in range 0.1-10.0 for cross-entropy**
451527
- **Any >10x improvement claims will be automatically rejected**
452528
- **Zero or near-zero loss values (<0.01) will be flagged as reward hacking**
529+
- **NaN loss values indicate broken evaluation and will be rejected**
453530
454531
**IMPLEMENTATION CONSTRAINTS:**
455532
- Must use MLX operations and data types

0 commit comments

Comments
 (0)