Skip to content

Commit 587a5cd

Browse files
authored
Fix EMA state loading and the model_meta save path (#11203)
1 parent 32b9a2e commit 587a5cd

File tree

2 files changed

+27
-27
lines changed

2 files changed

+27
-27
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -784,27 +784,27 @@ def get_metadata_file_name(path):
784784
metadata = paddle.load(metadata_file)
785785
state_dict_metadata.update(metadata.state_dict_metadata)
786786

787-
init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
787+
if not self.args.sharded_model_from_ema:
788+
init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
788789

789-
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
790+
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
790791

791-
opt_states = {}
792-
master_weights = {}
793-
for k, v in optimizer_sharded_state_dict.items():
794-
if k.endswith(".w_0"):
795-
master_weights[k] = v
796-
else:
797-
opt_states[k] = v
792+
opt_states = {}
793+
master_weights = {}
794+
for k, v in optimizer_sharded_state_dict.items():
795+
if k.endswith(".w_0"):
796+
master_weights[k] = v
797+
else:
798+
opt_states[k] = v
798799

799-
dist.load_state_dict(
800-
opt_states,
801-
opt_states_path,
802-
aoa_config=self.args.aoa_config,
803-
offload=self.args.load_via_cpu,
804-
comm_method=self.args.flex_ckpt_comm_method,
805-
)
800+
dist.load_state_dict(
801+
opt_states,
802+
opt_states_path,
803+
aoa_config=self.args.aoa_config,
804+
offload=self.args.load_via_cpu,
805+
comm_method=self.args.flex_ckpt_comm_method,
806+
)
806807

807-
if not self.args.sharded_model_from_ema:
808808
dist.load_state_dict(
809809
master_weights,
810810
master_weights_path,
@@ -819,12 +819,8 @@ def get_metadata_file_name(path):
819819
ema_states_path = os.path.join(resume_from_checkpoint, EMA_STATE_DIC, f"{dist.get_rank()}_0.distcp")
820820
ema_state_dict = paddle.load(ema_states_path)
821821
ema_master_weights = ema_state_dict.pop("master_weights", None)
822-
opt_master_weights = self.optimizer.state_dict()["master_weights"]
823-
for k, v in opt_master_weights.items():
824-
assert (
825-
k in ema_master_weights
826-
), f"{k} not in ema_master_weights, emas_master_weight keys {ema_master_weights.keys()}"
827-
paddle.assign(ema_master_weights[k], opt_master_weights[k])
822+
opt_state_dict = {"master_weights": ema_master_weights}
823+
self.optimizer.set_state_dict(opt_state_dict)
828824

829825
self.model.set_state_dict(ema_state_dict)
830826
else:

paddlenlp/trainer/utils/zero_cost_checkpoint.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,7 @@ def __init__(self, resume_from_checkpoint, args, offload=True, hcg=None, model=N
12531253
self.optimizer = optimizer
12541254
self.dist_info_collector_and_validator = DistInfoCollectorValidator(args, hcg)
12551255

1256+
self.device_id = int(os.getenv("FLAGS_selected_gpus"))
12561257
super().__init__(resume_from_checkpoint, args, offload)
12571258

12581259
def _get_model_meta(self):
@@ -1262,7 +1263,7 @@ def _ema_path(self, base_path):
12621263
return os.path.join(base_path, "ema_state", f"{dist.get_rank()}_0.distcp")
12631264

12641265
def _check_consistent_dist_strategy(self, resume_from_checkpoint):
1265-
return self.dist_info_collector_and_validator.check_same_strategy(os.path.dirname(resume_from_checkpoint))
1266+
return self.dist_info_collector_and_validator.check_same_strategy(resume_from_checkpoint)
12661267

12671268
def _get_model_state(self):
12681269
assert self.model is not None, "expected model is not None"
@@ -1274,9 +1275,12 @@ def _get_master_weight(self):
12741275

12751276
def save(self, global_step):
12761277
model_meta_content = self._get_model_meta()
1277-
model_meta_path = os.path.join(self.args.output_dir, MODEL_META_NAME)
1278-
with open(model_meta_path, "w") as f:
1279-
json.dump(model_meta_content, f)
1278+
base_path = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
1279+
os.makedirs(base_path, exist_ok=True)
1280+
model_meta_path = os.path.join(base_path, MODEL_META_NAME)
1281+
if self.device_id == 0:
1282+
with open(model_meta_path, "w") as f:
1283+
json.dump(model_meta_content, f)
12801284

12811285
super().save(global_step)
12821286

0 commit comments

Comments (0)