Resume recovery loss continuity#563

Open
Sualeh77 wants to merge 14 commits into staging from resume_recovery-loss_continuity

Conversation


@Sualeh77 Sualeh77 commented Mar 8, 2026

Description

Implements LossContinuityGuard — a lightweight guard that detects loss discontinuities after checkpoint resume. Loss jumps after resume indicate that some training state (optimizer moments, LR scheduler, model weights) was not correctly restored. This guard records a rolling window of optimizer-step losses before each checkpoint, saves the window statistics into the checkpoint metadata (client_state), and automatically verifies continuity for the first window_size steps after resume.

Changes:

  • src/llm/loss_continuity_guard.py (new): LossContinuityGuard class with observe(), state_dict(), restore(), and verify(). Uses a dual σ-based + relative-difference (20%) threshold to tolerate natural loss noise while catching real jumps. Handles distributed training via dist.all_reduce (guarded by dist.is_initialized()).
  • src/llm/pretrainer.py (modified): Integrates the guard with minimal surface area — instantiated in init, restored in _resume(), observed after _optimizer_step(), and persisted in _save_checkpoint() via client_state["loss_guard"].
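A minimal sketch of what such a guard could look like. The method names (observe(), state_dict(), restore(), verify()), the rolling window, and the dual σ-based plus 20%-relative threshold come from the description above; the default window_size, the sigma factor, the exact way the two criteria combine, and all method bodies are assumptions for illustration, not the PR's actual code (the real version also all-reduces across ranks, omitted here):

```python
# Illustrative sketch only; defaults and threshold combination are assumptions.
from collections import deque
import statistics


class LossContinuityGuard:
    def __init__(self, window_size=50, sigma_factor=3.0, rel_threshold=0.20):
        self.window_size = window_size
        self.sigma_factor = sigma_factor
        self.rel_threshold = rel_threshold
        self.window = deque(maxlen=window_size)  # rolling optimizer-step losses
        self.saved_mean = None                   # pre-resume statistics
        self.saved_std = None
        self.steps_since_resume = None           # None => not in verification mode

    def observe(self, loss: float):
        """Record one loss; auto-verify after window_size post-resume steps."""
        self.window.append(float(loss))
        if self.steps_since_resume is not None:
            self.steps_since_resume += 1
            if self.steps_since_resume >= self.window_size:
                self.steps_since_resume = None
                return self.verify()  # True = continuous, False = jump detected
        return None

    def state_dict(self):
        """Window statistics to persist into checkpoint client_state."""
        if len(self.window) >= 2:
            return {"mean": statistics.mean(self.window),
                    "std": statistics.stdev(self.window)}
        return {}

    def restore(self, state):
        """Load pre-resume statistics; tolerates an empty dict (fresh run)."""
        if state:
            self.saved_mean = state["mean"]
            self.saved_std = state["std"]
            self.window.clear()
            self.steps_since_resume = 0

    def verify(self):
        """Compare post-resume mean against the saved pre-resume statistics."""
        if self.saved_mean is None or len(self.window) < 2:
            return True
        jump = statistics.mean(self.window) - self.saved_mean
        # Flag only if the jump exceeds BOTH the sigma-based bound AND the
        # 20% relative bound, so natural loss noise is tolerated.
        sigma_ok = jump <= self.sigma_factor * max(self.saved_std, 1e-8)
        rel_ok = jump <= self.rel_threshold * abs(self.saved_mean)
        return sigma_ok or rel_ok
```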

How it works at resume:

1. guard.restore(client_state["loss_guard"])   ← loads pre-resume mean/std
2. training loop: guard.observe(loss) × N      ← collects post-resume losses
3. after window_size steps: guard.verify()     ← auto-triggers, logs WARNING if jump detected

Negligible latency impact: observe() is a single Python list.append(float), which is invisible next to GPU-bound fused Triton kernel time.
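As a rough, machine-dependent sanity check of that overhead claim (not part of the PR), a single float append costs on the order of tens of nanoseconds:

```python
# Back-of-envelope timing of an observe()-style list append.
import timeit

# One million appends to a plain Python list.
cost = timeit.timeit("w.append(1.234)", setup="w = []", number=1_000_000)
print(f"~{cost / 1_000_000 * 1e9:.0f} ns per append")
```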


Checklist

  • I have added tests that prove my fix is effective or that my feature works.
    • tests/test_loss_continuity_guard.py — 11 unit tests covering normal resume, optimizer reset detection, LR reset detection, RNG drift tolerance, edge cases (empty restore, window rollover, relative-check fallback)
    • tests/test_loss_continuity_guard_integration.py — 2 end-to-end integration tests using gpt2 + wikitext-2-raw-v1 via HuggingFace, verifying a clean resume passes and a corrupted-weights resume is detected
  • I have added necessary documentation (if applicable).
  • LossContinuityGuard is fully documented with docstrings, including a usage example, argument descriptions, and inline comments explaining the threshold choices
  • My code follows the style guidelines, gitflow branching strategy, and naming conventions of this project's Contribution Guidelines
    • Branch: resume_recovery-loss_continuity
    • New module placed in src/llm/ alongside existing modules (loss_spike_recovery.py, etc.)
    • Integration follows the client_state dict pattern already established in checkpoint.py and pretrainer.py

