
Commit 2e7fbc3

jeremymanning and claude committed
Fix checkpoint loading: Move RNG states to CPU before restoration
Root cause: torch.load() with map_location='cuda:N' moves tensors to CUDA, but torch.set_rng_state() requires CPU tensors.

Changes:
- Move torch_random_state to CPU before calling set_rng_state()
- Move cuda_random_state elements to CPU (set_rng_state_all handles placement)
- Add try-except fallback: if RNG restoration fails, log warning and continue
- This allows training to resume without losing model/optimizer state

Benefits:
- Resume mode now works correctly
- Graceful fallback if RNG restoration fails for any reason
- Preserves deterministic training when possible

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
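The save/restore cycle the commit message describes can be sketched in isolation. This is a minimal illustration, not the project's actual code (the real change lives in code/model_utils.py); the helper names `save_rng_states` and `restore_rng_states` are hypothetical:

```python
# Hypothetical sketch of checkpointing RNG state and restoring it safely.
import torch

def save_rng_states():
    # Capture RNG states the way a checkpoint dict might store them.
    state = {"torch_random_state": torch.get_rng_state()}
    if torch.cuda.is_available():
        state["cuda_random_state"] = torch.cuda.get_rng_state_all()
    return state

def restore_rng_states(training_state):
    # torch.set_rng_state() only accepts a CPU ByteTensor, so move the
    # state back to CPU in case torch.load(map_location='cuda:N')
    # relocated it to a GPU.
    rng_state = training_state["torch_random_state"]
    if rng_state.device.type != "cpu":
        rng_state = rng_state.cpu()
    torch.set_rng_state(rng_state)
    if "cuda_random_state" in training_state and torch.cuda.is_available():
        cuda_states = [s.cpu() for s in training_state["cuda_random_state"]]
        torch.cuda.set_rng_state_all(cuda_states)

# Round trip: restoring the saved state reproduces the same random draw,
# which is what "preserves deterministic training" means in practice.
state = save_rng_states()
first = torch.rand(3)
restore_rng_states(state)
second = torch.rand(3)
assert torch.equal(first, second)
```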
1 parent 66f0e4c commit 2e7fbc3

File tree

1 file changed (+19, -4 lines changed)


code/model_utils.py

Lines changed: 19 additions & 4 deletions
```diff
@@ -69,12 +69,27 @@ def load_checkpoint(model_class, model_name, device):
         logger.info("Restored NumPy random state")
 
     if "torch_random_state" in training_state:
-        torch.set_rng_state(training_state["torch_random_state"])
-        logger.info("Restored PyTorch random state")
+        try:
+            # torch.set_rng_state() requires CPU tensor
+            rng_state = training_state["torch_random_state"]
+            if rng_state.device.type != 'cpu':
+                rng_state = rng_state.cpu()
+            torch.set_rng_state(rng_state)
+            logger.info("Restored PyTorch random state")
+        except Exception as e:
+            logger.warning(f"Could not restore PyTorch RNG state: {e}. Continuing with random initialization.")
 
     if "cuda_random_state" in training_state and torch.cuda.is_available():
-        torch.cuda.set_rng_state_all(training_state["cuda_random_state"])
-        logger.info("Restored CUDA random state")
+        try:
+            # Ensure CUDA RNG states are on correct devices
+            cuda_states = training_state["cuda_random_state"]
+            if isinstance(cuda_states, list):
+                # Move each state to CPU if needed (set_rng_state_all handles device placement)
+                cuda_states = [s.cpu() if hasattr(s, 'cpu') and s.device.type != 'cpu' else s for s in cuda_states]
+            torch.cuda.set_rng_state_all(cuda_states)
+            logger.info("Restored CUDA random state")
+        except Exception as e:
+            logger.warning(f"Could not restore CUDA RNG state: {e}. Continuing with random initialization.")
 
     logger.info(
         f"Checkpoint loaded for {model_name} from epochs_completed={epochs_completed}"
```
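The graceful-fallback behavior this diff adds can be exercised on its own: a failed RNG restore should log a warning and let checkpoint loading continue. A minimal sketch, assuming a hypothetical `restore_torch_rng` helper (not the project's actual function):

```python
# Demonstrates the try-except fallback pattern from the commit.
import logging
import torch

logger = logging.getLogger("rng_restore_demo")

def restore_torch_rng(training_state):
    # A failed RNG restore logs a warning instead of aborting, so the
    # already-loaded model/optimizer state is not thrown away.
    try:
        rng_state = training_state["torch_random_state"]
        if rng_state.device.type != "cpu":
            rng_state = rng_state.cpu()
        torch.set_rng_state(rng_state)
        return True
    except Exception as e:
        logger.warning(
            f"Could not restore PyTorch RNG state: {e}. "
            "Continuing with random initialization."
        )
        return False

# A valid state restores; a corrupted one (wrong dtype) degrades gracefully.
ok = restore_torch_rng({"torch_random_state": torch.get_rng_state()})
bad = restore_torch_rng({"torch_random_state": torch.zeros(3)})
assert ok is True and bad is False
```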

0 commit comments
