Skip to content

Commit b63efc3

Browse files
authored
add PyTorch's DistributedDataParallel training. (#3940)
* add PyTorch's DistributedDataParallel training. * fix checkpoint loader ofr DDP training. * invoke constrain_orthornomral via pre-forward hooks. * fix a typo in learning rate scheduler.
1 parent 342adfa commit b63efc3

File tree

9 files changed

+552
-62
lines changed

9 files changed

+552
-62
lines changed

egs/aishell/s10/chain/common.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,19 @@ def load_checkpoint(filename, model):
4343
for k in keys:
4444
assert k in checkpoint
4545

46-
model.load_state_dict(checkpoint['state_dict'])
46+
if not list(model.state_dict().keys())[0].startswith('module.') \
47+
and list(checkpoint['state_dict'])[0].startswith('module.'):
48+
# the checkpoint was saved by DDP
49+
logging.info('load checkpoint from DDP')
50+
dst_state_dict = model.state_dict()
51+
src_state_dict = checkpoint['state_dict']
52+
for key in dst_state_dict.keys():
53+
src_key = '{}.{}'.format('module', key)
54+
dst_state_dict[key] = src_state_dict.pop(src_key)
55+
assert len(src_state_dict) == 0
56+
model.load_state_dict(dst_state_dict)
57+
else:
58+
model.load_state_dict(checkpoint['state_dict'])
4759

4860
epoch = checkpoint['epoch']
4961
learning_rate = checkpoint['learning_rate']
@@ -52,7 +64,10 @@ def load_checkpoint(filename, model):
5264
return epoch, learning_rate, objf
5365

5466

55-
def save_checkpoint(filename, model, epoch, learning_rate, objf):
67+
def save_checkpoint(filename, model, epoch, learning_rate, objf, local_rank=0):
68+
if local_rank != 0:
69+
return
70+
5671
logging.info('Save checkpoint to {filename}: epoch={epoch}, '
5772
'learning_rate={learning_rate}, objf={objf}'.format(
5873
filename=filename,
@@ -68,8 +83,17 @@ def save_checkpoint(filename, model, epoch, learning_rate, objf):
6883
torch.save(checkpoint, filename)
6984

7085

71-
def save_training_info(filename, model_path, current_epoch, learning_rate, objf,
72-
best_objf, best_epoch):
86+
def save_training_info(filename,
87+
model_path,
88+
current_epoch,
89+
learning_rate,
90+
objf,
91+
best_objf,
92+
best_epoch,
93+
local_rank=0):
94+
if local_rank != 0:
95+
return
96+
7397
with open(filename, 'w') as f:
7498
f.write('model_path: {}\n'.format(model_path))
7599
f.write('epoch: {}\n'.format(current_epoch))

0 commit comments

Comments
 (0)