@@ -43,7 +43,19 @@ def load_checkpoint(filename, model):
     for k in keys:
         assert k in checkpoint

-    model.load_state_dict(checkpoint['state_dict'])
+    if not list(model.state_dict().keys())[0].startswith('module.') \
+            and list(checkpoint['state_dict'])[0].startswith('module.'):
+        # the checkpoint was saved by DDP
+        logging.info('load checkpoint from DDP')
+        dst_state_dict = model.state_dict()
+        src_state_dict = checkpoint['state_dict']
+        for key in dst_state_dict.keys():
+            src_key = '{}.{}'.format('module', key)
+            dst_state_dict[key] = src_state_dict.pop(src_key)
+        assert len(src_state_dict) == 0
+        model.load_state_dict(dst_state_dict)
+    else:
+        model.load_state_dict(checkpoint['state_dict'])

     epoch = checkpoint['epoch']
     learning_rate = checkpoint['learning_rate']
@@ -52,7 +64,10 @@ def load_checkpoint(filename, model):
     return epoch, learning_rate, objf


-def save_checkpoint(filename, model, epoch, learning_rate, objf):
+def save_checkpoint(filename, model, epoch, learning_rate, objf, local_rank=0):
+    if local_rank != 0:
+        return
+
     logging.info('Save checkpoint to {filename}: epoch={epoch}, '
                  'learning_rate={learning_rate}, objf={objf}'.format(
                      filename=filename,
@@ -68,8 +83,17 @@ def save_checkpoint(filename, model, epoch, learning_rate, objf):
     torch.save(checkpoint, filename)


-def save_training_info(filename, model_path, current_epoch, learning_rate, objf,
-                       best_objf, best_epoch):
+def save_training_info(filename,
+                       model_path,
+                       current_epoch,
+                       learning_rate,
+                       objf,
+                       best_objf,
+                       best_epoch,
+                       local_rank=0):
+    if local_rank != 0:
+        return
+
     with open(filename, 'w') as f:
         f.write('model_path: {}\n'.format(model_path))
         f.write('epoch: {}\n'.format(current_epoch))
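Background on the first hunk (not part of the commit): torch.nn.parallel.DistributedDataParallel wraps the model, so every key in the wrapped model's state_dict() gains a 'module.' prefix, and a checkpoint saved from it cannot be loaded directly into the bare model. Below is a minimal sketch of the key renaming that load_checkpoint now performs; the Linear model is purely illustrative:

    import torch

    # Simulate a checkpoint written by a DDP-wrapped model: every key
    # carries a 'module.' prefix (illustrative, not the commit's code).
    model = torch.nn.Linear(4, 2)
    ddp_state = {'module.' + k: v for k, v in model.state_dict().items()}

    # Strip the prefix by renaming keys, as load_checkpoint does above.
    dst_state_dict = model.state_dict()
    for key in dst_state_dict.keys():
        dst_state_dict[key] = ddp_state.pop('module.' + key)
    assert len(ddp_state) == 0  # every checkpoint key was consumed
    model.load_state_dict(dst_state_dict)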
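The local_rank guard added to save_checkpoint and save_training_info makes only rank 0 touch the filesystem, so the worker processes spawned for DDP training do not race to write the same files. A hedged sketch of a call site follows; obtaining the rank via torch.distributed is an assumption, the commit leaves that to the caller (on a single machine the global rank coincides with the local rank):

    import torch.distributed as dist

    # Assumed call site inside a DDP training loop.
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    save_checkpoint('exp/epoch-10.pt', model, epoch=10,
                    learning_rate=1e-4, objf=0.25, local_rank=local_rank)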