Resume recovery - RNG state manager#564

Open
Sualeh77 wants to merge 1 commit into refactor/consolidation from resume_recovery-rng_state

Conversation


@Sualeh77 Sualeh77 commented Mar 8, 2026

Description

RNG State Restoration for Reproducible Training Resume

When training resumes from a checkpoint, random number generator states were not being saved or restored. This caused data shuffling order, dropout masks, and other random operations to diverge from the original run, making training non-reproducible after resume.

This PR adds an RNGStateManager module that captures and restores RNG states across all libraries (Python random, NumPy, PyTorch CPU, PyTorch CUDA) and integrates it into the checkpoint save/load pipeline with minimal changes to existing code.

What changed

New files:

  • llm/src/llm/rng_state_manager.py — RNGStateManager class with capture() and restore() static methods. Handles all 4 RNG sources, guards against CUDA device count mismatch on world-size changes (warns instead of crashing), and gracefully skips CUDA RNG when no GPU is available.
  • llm/tests/test_rng_state_manager.py — 12 unit tests covering capture key structure, round-trip restore for all RNG sources, partial/empty state handling, and checkpoint dict integration.
  • llm/tests/test_rng_state_e2e.py — 2 end-to-end tests that train sshleifer/tiny-gpt2 on wikitext-2-raw-v1, checkpoint the RNG state mid-training, deliberately clobber the live RNG state, restore it from the checkpoint, and verify that losses match the uninterrupted baseline exactly. Includes a negative test proving losses diverge without the restore.
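Based on the description above, the capture()/restore() API might look roughly like the sketch below. The class and method names come from the PR; everything else (key names, warning text, internal structure) is an assumption:

```python
import random
import warnings

import numpy as np
import torch


class RNGStateManager:
    """Capture/restore RNG state for Python, NumPy, and PyTorch (CPU + CUDA)."""

    @staticmethod
    def capture() -> dict:
        state = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "torch_cpu": torch.get_rng_state(),
        }
        if torch.cuda.is_available():
            # One state tensor per visible CUDA device.
            state["torch_cuda"] = torch.cuda.get_rng_state_all()
        return state

    @staticmethod
    def restore(state: dict) -> None:
        if not state:
            return  # nothing captured, e.g. a pre-PR checkpoint
        if "python" in state:
            random.setstate(state["python"])
        if "numpy" in state:
            np.random.set_state(state["numpy"])
        if "torch_cpu" in state:
            torch.set_rng_state(state["torch_cpu"])
        cuda_states = state.get("torch_cuda")
        if cuda_states is not None and torch.cuda.is_available():
            if len(cuda_states) != torch.cuda.device_count():
                # World-size change: warn and skip instead of crashing.
                warnings.warn(
                    "Saved CUDA RNG states do not match current device count; "
                    "skipping CUDA RNG restore."
                )
            else:
                torch.cuda.set_rng_state_all(cuda_states)
```

A round trip (capture, draw, restore, draw again) should yield identical values from all three CPU-side sources, which is what the unit tests exercise.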

Modified files:

  • llm/src/llm/pretrainer.py — 2 minimal changes:
      • _save_checkpoint(): adds "rng_state": RNGStateManager.capture() to client_state
      • _resume(): calls RNGStateManager.restore(rng_state) before training resumes
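The shape of that integration might look like the standalone sketch below. The real hooks live inside the DeepSpeed checkpoint pipeline, which is not shown in the PR; here a plain torch.save/torch.load round trip stands in for it, and the function names and fields are illustrative:

```python
import random

import torch


def save_checkpoint(path: str) -> None:
    client_state = {
        # ...existing client_state fields would go here unchanged...
        "rng_state": {  # new key inside the same dict: no checkpoint format change
            "python": random.getstate(),
            "torch_cpu": torch.get_rng_state(),
        },
    }
    torch.save(client_state, path)


def resume(path: str) -> None:
    client_state = torch.load(path, weights_only=False)
    rng = client_state.get("rng_state")
    if rng is not None:  # tolerate checkpoints written before this change
        random.setstate(rng["python"])
        torch.set_rng_state(rng["torch_cpu"])
```

Because the state rides inside client_state, old checkpoints without the key still load; the restore is simply skipped.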

Key design decisions

  • Static methods, no instance state — drop-in single-call API
  • RNG state stored inside existing client_state dict — no checkpoint format changes
  • Per-rank by default since DeepSpeed saves client_state per-rank in each shard
  • Restore happens in _resume() before the training loop iterates, which is correct because both DistributedSampler and DataLoader(shuffle=True) consume RNG at iterator creation time, not at construction
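The last point can be checked in isolation: a shuffle=True DataLoader draws its permutation seed from the global torch RNG when the iterator is created, so restoring the captured state before iterating reproduces the original order. A minimal demo, independent of the PR code:

```python
import torch
from torch.utils.data import DataLoader

data = list(range(8))
loader = DataLoader(data, batch_size=1, shuffle=True)

torch.manual_seed(0)
state = torch.get_rng_state()        # "checkpointed" CPU RNG state
order_a = [int(b) for b in loader]   # iter() draws the shuffle seed here

order_b = [int(b) for b in loader]   # fresh draw, so (almost surely) a new order

torch.set_rng_state(state)           # restore, as _resume() would
order_c = [int(b) for b in loader]   # same permutation as order_a
```

Had the restore happened after the first iterator was created, the shuffle seed would already have been consumed and the orders would diverge.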
