@@ -270,7 +270,7 @@ def get_group_ids(self):
class ShardingIO:
-    def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=False):
+    def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=False, is_ema=False):
        self.args = args
        self.model = model
        self.optimizer = optimizer
@@ -282,6 +282,7 @@ def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=F
        self.remap_parameter_name = remap_parameter_name
        self.remapper = None
+        self.is_ema = is_ema

    def _get_remapper(self, checkpoint):
        if not self.remap_parameter_name:
@@ -395,24 +396,33 @@ def _load_one_state_dict_from_checkpoint(self, resume_from_checkpoint, base_weig
        """
        load state_dict of one shard from_checkpoint, Only load model state dict.
        """
+        if self.is_ema:
+            base_weight_name = base_weight_name.replace("model_state", "ema").replace("pdparams", "pdopt")
        file_path = os.path.join(resume_from_checkpoint, _add_variant(base_weight_name, weight_name_suffix))
        if not os.path.isfile(file_path):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}, no {file_path}")

        logger.info(f"Loading model from {resume_from_checkpoint} .")
        # We load the model state dict on the CPU to avoid an OOM error.
        state_dict = paddle.load(file_path, return_numpy=True)
+        if self.is_ema:
+            state_dict.pop("master_weights", None)
        state_dict = self._remap_parameter_name(resume_from_checkpoint, state_dict, is_opt=False)
        return state_dict

    def _load_optimizer_state_of_one_shard(self, checkpoint, base_opt_name, optimizer_name_suffix, group_getter=None):
+        if self.is_ema:
+            base_opt_name = base_opt_name.replace("optimizer", "ema")
        optimizer_name = _add_variant(base_opt_name, optimizer_name_suffix)
        path = os.path.join(checkpoint, optimizer_name)
        logger.info(f"load optimizer state from {path}")
        if os.path.isfile(path):
+            opt_state = paddlenlp_load(path, map_location="cpu")
+            if self.is_ema:
+                opt_state = {"master_weights": opt_state.get("master_weights", {})}
            return self._remap_parameter_name(
                checkpoint,
-                self._modify_ckpt_for_compatibility(paddlenlp_load(path, map_location="cpu")),
+                self._modify_ckpt_for_compatibility(opt_state),
                is_opt=True,
            )
        logger.info(f"{path} not exists")
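A minimal usage sketch of the new flag, for context: with is_ema=True, ShardingIO points the model and optimizer loaders at the ema* checkpoint files and restricts the optimizer state to master_weights. The surrounding variables (args, model, optimizer, hcg) are assumed to come from the caller's trainer setup and are not part of this diff.

# Hypothetical sketch (not part of this PR); args/model/optimizer/hcg are assumed inputs.
# Construct ShardingIO in EMA mode so the loaders read the ema* files of a sharded checkpoint.
sharding_io = ShardingIO(args, model, optimizer=optimizer, hcg=hcg, is_ema=True)

# With is_ema=True:
#   * _load_one_state_dict_from_checkpoint looks for "ema*.pdopt" instead of
#     "model_state*.pdparams" and drops any "master_weights" entry from the loaded dict;
#   * _load_optimizer_state_of_one_shard reads "ema*" instead of "optimizer*" and keeps
#     only the "master_weights" sub-dict.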