Training stability loss spike recovery#558

Open
Sualeh77 wants to merge 19 commits into staging from
training_stability-loss_spike_recovery
Conversation


@Sualeh77 Sualeh77 commented Mar 6, 2026

Pull Request Template

Description

Implements loss spike detection and automatic recovery for the pretraining pipeline, addressing training stability during LLM pretraining where sudden loss spikes can waste compute or cause divergence.

Key changes:

Two-signal detection:
1. Loss spikes: Sliding window z-score detection after the forward pass — flags when loss > mean + K*std or loss > ratio * mean, with a minimum absolute delta guard
2. Gradient norm explosion: L2 norm threshold check after the backward pass, using DeepSpeed's global grad norm (ZeRO-safe) with a local fallback
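For reviewers skimming without the diff open, here is a minimal sketch of what the sliding-window loss-spike check could look like. Class and parameter names (`SpikeDetector`, `k_std`, `ratio`, `min_delta`) are illustrative assumptions, not the PR's actual API:

```python
from collections import deque
import statistics

class SpikeDetector:
    """Sliding-window loss spike detector (sketch; names are assumptions)."""

    def __init__(self, window=100, k_std=4.0, ratio=1.5, min_delta=0.1):
        self.window = deque(maxlen=window)  # recent non-spike losses
        self.k_std = k_std                  # z-score threshold
        self.ratio = ratio                  # mean-ratio threshold
        self.min_delta = min_delta          # minimum absolute delta guard

    def is_spike(self, loss: float) -> bool:
        spike = False
        if len(self.window) >= 2:
            mean = statistics.mean(self.window)
            std = statistics.pstdev(self.window)
            # Guard: tiny fluctuations never count, whatever the z-score says
            if loss - mean >= self.min_delta:
                spike = loss > mean + self.k_std * std or loss > self.ratio * mean
        if not spike:
            # Spike values are excluded so they don't inflate the statistics
            self.window.append(loss)
        return spike
```

Note the design choice implied by the description: flagged losses should not enter the window, otherwise one spike raises the mean/std and masks the next one.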

Automatic escalating recovery (production-default):

  • spike_count <= 3 → skip batch
  • spike_count <= 10 → reduce LR + skip batch
  • spike_count > 10 → rollback to last checkpoint + skip 200 batches (PaLM-style)
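The escalation policy above is simple enough to state as a pure function; this sketch mirrors the three tiers from the description (function and action names are assumptions, not the PR's identifiers):

```python
def recovery_action(spike_count: int) -> str:
    """Escalating recovery policy (sketch of the PR's three tiers)."""
    if spike_count <= 3:
        return "skip_batch"
    if spike_count <= 10:
        return "reduce_lr_and_skip_batch"
    # PaLM-style: restart from the last checkpoint and skip past the bad data
    return "rollback_checkpoint_and_skip_200_batches"
```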

Supporting mechanisms:

  • Cooldown (50 steps) prevents cascading alerts after any spike action
  • Spike detector window reset on checkpoint rollback (stale statistics invalidated)
  • Embedding weight norm tracking (token_embed, lm_head, Kronecker projection) throttled to every 50 steps
  • Loss all-reduced across ranks before detection to prevent collective deadlocks
  • Interactive stdin mode as opt-in fallback (auto_recover=False), guarded against multi-GPU
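The cooldown mechanism is the piece most likely to hide subtle off-by-one behavior, so a sketch may help frame the review. This is an assumed shape, not the PR's actual class:

```python
class Cooldown:
    """Suppress further spike actions for `steps` training steps
    after any recovery action fires (sketch; names are assumptions)."""

    def __init__(self, steps=50):
        self.steps = steps
        self.last_action_step = None

    def active(self, step: int) -> bool:
        # True while we are within `steps` steps of the last action
        return (self.last_action_step is not None
                and step - self.last_action_step < self.steps)

    def trigger(self, step: int) -> None:
        self.last_action_step = step
```

A reviewer question worth settling: does the cooldown suppress detection entirely, or only the recovery action (so spikes are still logged during cooldown)?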

Reviewers should focus on:

  • Integration points in pretrainer.py — two-stage detection (lines 118-168), while-loop conversion for rollback support
  • Escalation policy and cooldown logic in loss_spike_recovery.py
  • LossSpikeConfig defaults in config.py — are thresholds reasonable for our training regime?

References

Checklist

  • I have added tests that prove my fix is effective or that my feature works.
    • 40 unit tests in llm/tests/test_loss_spike_recovery.py covering detector, cooldown, escalation policy, config defaults, factory, grad norm, and embedding norms
    • Local integration test in llm/tests/local_spike_recovery_test.py — trains a small GPT on wikitext-2 with injected spikes, verified on M1 Mac (MPS)
  • I have added necessary documentation (if applicable)
    • llm/src/llm/LOSS_SPIKE_RECOVERY.md — conceptual design, configuration reference, and code change details

Reviewers

  • Reviewer 1: Rahul Uniyal
  • Reviewer 2: Shyamant Achar

Note: Every pull request requires at least 2 reviewers/approvers before it can be merged.

firekind and others added 19 commits February 27, 2026 09:59
…557)

- Add support for lower sequence length

Co-authored-by: Hemanth Reddy K <h.kamireddy@yuvohealth.com>
…ding norm tracking for the task : Embedding norms tracked separately (embeddings can diverge)
…ining made automatic action-taking the default choice over user intervention
…and action triggering mechanism with wikitext data and small gpt model