
Validation loss: Eval frequency, Eval Loss tracked, Train-Eval gap monitored#565

Open
Sualeh77 wants to merge 3 commits into refactor/consolidation from validation_loss-eval_frequency

Conversation

@Sualeh77 commented Mar 8, 2026

Description

Related Tasks

  • Eval frequency defined (every N steps)
  • Eval loss tracked alongside train loss
  • Train-eval gap monitored (overfitting detection)

Summary

This PR implements mid-training evaluation and overfitting detection in the distributed pre-training pipeline.

Prior state: pretrainer.py already ran a validation pass at the end of every epoch and logged val_loss / val_perplexity. However, there was no intra-epoch evaluation triggered at a configurable step cadence, no smoothed comparison between training and validation loss, and no automated mechanism to detect when the model was overfitting.

Changes made:

llm/src/llm/config.py

  • Added overfit_patience: int = 5 to TrainingConfig — number of consecutive evaluations without improvement before an overfitting alert is raised.
  • Added overfit_threshold: float = 0.0 (tune per training run) to TrainingConfig — the minimum decrease in val_loss required to count as an improvement and reset the counter.
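As a sketch of how the two new fields slot into the config, here is a minimal stand-in for TrainingConfig. Only overfit_patience and overfit_threshold come from this PR; the eval_interval field and the surrounding structure are assumptions about the existing config, not the actual project code.

```python
from dataclasses import dataclass


@dataclass
class TrainingConfig:
    # Existing field (assumed): run evaluation every N training steps.
    eval_interval: int = 5
    # New in this PR: consecutive non-improving evals before an alert fires.
    overfit_patience: int = 5
    # New in this PR: minimum val_loss decrease that counts as an improvement.
    overfit_threshold: float = 0.0


cfg = TrainingConfig()
print(cfg.overfit_patience, cfg.overfit_threshold)  # prints: 5 0.0
```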

llm/src/llm/pretrainer.py

  • Overfitting detection state (_train_loss_accum, _train_loss_count, _best_val_loss, _overfit_strikes) initialised in __init__.
  • Smoothed train loss: Training loss is accumulated across the eval_interval window and averaged at evaluation time, eliminating batch-level noise from the gap signal.
  • Train-eval gap: train_eval_gap = val_loss − smoothed_train_loss is computed and logged at every eval trigger point.
  • Overfitting strike counter: Resets to 0 when val_loss improves beyond overfit_threshold; increments by 1 otherwise.
  • Alert: When overfit_strikes >= overfit_patience, an overfitting_detected: True event is sent to the observability backend, which can be consumed by the Watchdog to pause training.
  • Distributed correctness: _validate() now applies dist.all_reduce(SUM) / world_size to both avg_loss and avg_perplexity tensors, ensuring the reported validation metrics are true global averages across all GPUs — not just the rank-0 shard.
  • Zero latency impact: All new bookkeeping (+=, comparisons, averaging) happens on already-available CPU scalars. No new GPU synchronisations are introduced in the per-step hot path.
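The distributed-correctness point deserves a concrete illustration. The real change uses dist.all_reduce with ReduceOp.SUM followed by a divide; the arithmetic it performs is equivalent to this pure-Python simulation (the per-rank values and world size here are hypothetical, chosen only to show the math):

```python
# Each rank computes a local average over its validation shard.
per_rank_val_loss = [2.0, 2.4, 1.8, 2.2]  # hypothetical per-GPU local averages
world_size = len(per_rank_val_loss)

# all_reduce(SUM) leaves every rank holding the global sum; dividing by
# world_size then yields the true global average on every rank, rather than
# each rank (or just rank 0) reporting only its own shard's average.
global_val_loss = sum(per_rank_val_loss) / world_size
print(global_val_loss)  # prints: 2.1
```

Without this step, a rank whose shard happens to be easier or harder than average would report a biased val_loss, which would in turn distort the train-eval gap and the strike counter.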

llm/tests/test_overfit_detection_integration.py (new file)

End-to-end integration test using gpt2 on wikitext-2-raw-v1 via the project's existing get_dataloaders() factory. Trains for 20 steps with eval_interval=5 entirely on CPU (no DeepSpeed required), exercising the exact same logic path as PreTrainer.run().
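The logic path that test exercises — smoothed train loss, train-eval gap, and the strike counter — can be sketched in isolation. This is an illustrative stand-alone class, not the actual PreTrainer code; the attribute names mirror the PR (_train_loss_accum, _best_val_loss, _overfit_strikes), but the class and its methods are hypothetical.

```python
class OverfitDetector:
    """Illustrative sketch of the PR's overfitting bookkeeping."""

    def __init__(self, patience: int = 5, threshold: float = 0.0):
        self.patience = patience
        self.threshold = threshold
        self._train_loss_accum = 0.0
        self._train_loss_count = 0
        self._best_val_loss = float("inf")
        self._overfit_strikes = 0

    def on_train_step(self, loss: float) -> None:
        # Accumulate per-step train loss; averaged at the next eval trigger.
        self._train_loss_accum += loss
        self._train_loss_count += 1

    def on_eval(self, val_loss: float) -> dict:
        # Smooth train loss over the eval window, then reset the window.
        smoothed = self._train_loss_accum / max(self._train_loss_count, 1)
        self._train_loss_accum, self._train_loss_count = 0.0, 0
        gap = val_loss - smoothed  # train-eval gap, logged at every eval point
        if val_loss < self._best_val_loss - self.threshold:
            self._best_val_loss = val_loss
            self._overfit_strikes = 0  # improvement: reset the counter
        else:
            self._overfit_strikes += 1  # no improvement: one more strike
        return {
            "train_eval_gap": gap,
            "overfitting_detected": self._overfit_strikes >= self.patience,
        }


# Usage: with patience=2, two consecutive non-improving evals trigger the alert.
d = OverfitDetector(patience=2)
d.on_train_step(1.0)
d.on_train_step(1.0)
print(d.on_eval(2.0)["overfitting_detected"])  # prints: False
print(d.on_eval(2.5)["overfitting_detected"])  # prints: False
print(d.on_eval(2.6)["overfitting_detected"])  # prints: True
```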

Reviewers should focus on:

  • The _validate() all_reduce block — verify the averaging is correct for your world-size assumptions.
  • The overfit_patience / overfit_threshold defaults — adjust to match the team's operational preferences.
  • Whether the overfitting_detected event should eventually trigger an automated Watchdog PAUSE via watchdog.py.

Checklist

  • I have added tests that prove my fix is effective or that my feature works.
    • tests/test_overfit_detection_integration.py — 11 assertions covering eval cadence, gap arithmetic, strike counter logic, monotonicity of best_val_loss, and alert firing. All pass (11 passed in ~2m 21s on Mac CPU).
  • I have added necessary documentation (if applicable).
    • Inline docstrings and comments in pretrainer.py and config.py explain each new field and the rationale for smoothing and distributed averaging.
  • My code follows the style guidelines, gitflow branching strategy, and naming conventions of this project's Contribution Guidelines.
    • Changes are on branch validation_loss-eval_frequency.
    • Naming conventions (_overfit_strikes, train_eval_gap, overfit_patience) follow the existing snake_case style in TrainingConfig and PreTrainer.

@Sualeh77 changed the title from "Added simple branching condition for Eval frequency check to run eval…" to "Validation loss: Eval frequency, Eval Loss tracked, Train-Eval gap monitored" on Mar 8, 2026
