Commit 509f005

[Cherry-pick] Support loading sharded EMA checkpoints (#11075)

* support load sharded EMA checkpoints
* support_ema_loading_no_pdopt
* polish code
1 parent 20a584f commit 509f005

File tree

4 files changed (+19, -3 lines)

paddlenlp/trainer/trainer.py

Lines changed: 1 addition & 0 deletions
@@ -378,6 +378,7 @@ def __init__(
                 self.model,
                 self.optimizer,
                 remap_parameter_name=self.args.load_sharded_model_remap_parameter_name,
+                is_ema=self.args.sharded_model_from_ema,
             )
 
             if self.args.unified_checkpoint:
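
For reference, a condensed sketch of how the new keyword reaches ShardingIO from the training arguments. Only the is_ema keyword is introduced by this commit; the helper name below is hypothetical, and the call simply mirrors the construction shown in the hunk above.

from paddlenlp.trainer.utils.sharding_io import ShardingIO

def build_sharding_io(args, model, optimizer):
    # Mirrors the Trainer.__init__ call above; is_ema routes subsequent
    # sharded loads to the EMA checkpoint files instead of the regular ones.
    return ShardingIO(
        args,
        model,
        optimizer,
        remap_parameter_name=args.load_sharded_model_remap_parameter_name,
        is_ema=args.sharded_model_from_ema,
    )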

paddlenlp/trainer/training_args.py

Lines changed: 5 additions & 0 deletions
@@ -638,6 +638,11 @@ class TrainingArguments:
         metadata={"help": "Whether to remap parameter name when load_sharded_model = true."},
     )
 
+    sharded_model_from_ema: bool = field(
+        default=False,
+        metadata={"help": "Whether to load sharded model from EMA."},
+    )
+
     tensor_parallel_degree: int = field(
         default=-1,
         metadata={
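
A minimal usage sketch for the new argument, assuming the usual PdArgumentParser flow; the output_dir value is a placeholder, and in practice the flag would be combined with the existing load_sharded_model switch referenced in the help string above.

from paddlenlp.trainer import PdArgumentParser, TrainingArguments

# Parse command-line style flags into TrainingArguments;
# --sharded_model_from_ema is the switch added by this commit.
parser = PdArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "./checkpoints", "--sharded_model_from_ema", "true"]
)
print(training_args.sharded_model_from_ema)  # True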

paddlenlp/trainer/utils/reshard/common.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def convert_opt_name_to_tname(tensor_names, opt_names):
                 opt_to_t[t] = t[: -len(s)]
                 _find = True
                 break
-        assert _find
+        assert _find, t
     return opt_to_t
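
The only change here is the assert message: when an optimizer-state name does not end with any known suffix, the failing name t is now included in the AssertionError instead of a bare failure. Below is a standalone illustration of the same suffix-stripping pattern; the suffix list and names are made up for illustration and are not taken from the module.

# Illustrative only: mimics the suffix-stripping loop in convert_opt_name_to_tname.
suffixes = ["_moment1_0", "_moment2_0", "_beta1_pow_acc_0", "_beta2_pow_acc_0"]

def strip_opt_suffix(opt_name):
    for s in suffixes:
        if opt_name.endswith(s):
            return opt_name[: -len(s)]
    # Equivalent effect of "assert _find, t": the offending name is surfaced.
    raise AssertionError(opt_name)

print(strip_opt_suffix("linear_0.w_0_moment1_0"))  # linear_0.w_0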

paddlenlp/trainer/utils/sharding_io.py

Lines changed: 12 additions & 2 deletions
@@ -270,7 +270,7 @@ def get_group_ids(self):
 
 
 class ShardingIO:
-    def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=False):
+    def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=False, is_ema=False):
         self.args = args
         self.model = model
         self.optimizer = optimizer
@@ -282,6 +282,7 @@ def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=F
 
         self.remap_parameter_name = remap_parameter_name
         self.remapper = None
+        self.is_ema = is_ema
 
     def _get_remapper(self, checkpoint):
         if not self.remap_parameter_name:
@@ -395,24 +396,33 @@ def _load_one_state_dict_from_checkpoint(self, resume_from_checkpoint, base_weig
         """
         load state_dict of one shard from_checkpoint, Only load model state dict.
         """
+        if self.is_ema:
+            base_weight_name = base_weight_name.replace("model_state", "ema").replace("pdparams", "pdopt")
         file_path = os.path.join(resume_from_checkpoint, _add_variant(base_weight_name, weight_name_suffix))
         if not os.path.isfile(file_path):
             raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}, no {file_path}")
 
         logger.info(f"Loading model from {resume_from_checkpoint} .")
         # We load the model state dict on the CPU to avoid an OOM error.
         state_dict = paddle.load(file_path, return_numpy=True)
+        if self.is_ema:
+            state_dict.pop("master_weights", None)
         state_dict = self._remap_parameter_name(resume_from_checkpoint, state_dict, is_opt=False)
         return state_dict
 
     def _load_optimizer_state_of_one_shard(self, checkpoint, base_opt_name, optimizer_name_suffix, group_getter=None):
+        if self.is_ema:
+            base_opt_name = base_opt_name.replace("optimizer", "ema")
         optimizer_name = _add_variant(base_opt_name, optimizer_name_suffix)
         path = os.path.join(checkpoint, optimizer_name)
         logger.info(f"load optimizer state from {path}")
         if os.path.isfile(path):
+            opt_state = paddlenlp_load(path, map_location="cpu")
+            if self.is_ema:
+                opt_state = {"master_weights": opt_state.get("master_weights", {})}
             return self._remap_parameter_name(
                 checkpoint,
-                self._modify_ckpt_for_compatibility(paddlenlp_load(path, map_location="cpu")),
+                self._modify_ckpt_for_compatibility(opt_state),
                 is_opt=True,
             )
         logger.info(f"{path} not exists")
