Skip to content

Commit d17b10e

Browse files
author
GrzegorzKarchNV
authored
Merge pull request #759 from GrzegorzKarchNV/fix_rng_state
fixing rng_state for compatibility with older checkpoints
2 parents 4a64c5b + 9a6c524 commit d17b10e

File tree

1 file changed

+6
-1
lines changed
  • PyTorch/SpeechSynthesis/Tacotron2

1 file changed

+6
-1
lines changed

PyTorch/SpeechSynthesis/Tacotron2/train.py

Lines changed: 6 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -250,7 +250,12 @@ def load_checkpoint(model, optimizer, epoch, config, amp_run, filepath, local_ra
250250
epoch[0] = checkpoint['epoch']+1
251251
device_id = local_rank % torch.cuda.device_count()
252252
torch.cuda.set_rng_state(checkpoint['cuda_rng_state_all'][device_id])
253-
torch.random.set_rng_state(checkpoint['random_rng_states_all'][device_id])
253+
if 'random_rng_states_all' in checkpoint:
254+
torch.random.set_rng_state(checkpoint['random_rng_states_all'][device_id])
255+
elif 'random_rng_state' in checkpoint:
256+
torch.random.set_rng_state(checkpoint['random_rng_state'])
257+
else:
258+
raise Exception("Model checkpoint must have either 'random_rng_state' or 'random_rng_states_all' key.")
254259
config = checkpoint['config']
255260
model.load_state_dict(checkpoint['state_dict'])
256261
optimizer.load_state_dict(checkpoint['optimizer'])

0 commit comments

Comments (0)