Commit 85295b6

bo-ke and Meiyim authored
[fea] support dp-moe for zcc and global_expert_id (#11050)
* support dp-moe for zcc (#10539)
  * [fix] zcc ema under non-pp when `acc=1`
  * zcc-ema: fix load-state-dict when dp-moe
  * [BugFix] shard-reshard: `is_matched` could differ across dp ranks; take dp0's value as authoritative (#9404)
* [mod] Sharding IO: Added DP-Meta gather for MoE

Co-authored-by: Ferrebo <[email protected]>
Co-authored-by: Meiyim <[email protected]>
1 parent 509f005 commit 85295b6

2 files changed (+39, -27)


paddlenlp/trainer/utils/sharding_io.py

Lines changed: 10 additions & 1 deletion
@@ -402,7 +402,7 @@ def _load_one_state_dict_from_checkpoint(self, resume_from_checkpoint, base_weig
         if not os.path.isfile(file_path):
             raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}, no {file_path}")

-        logger.info(f"Loading model from {resume_from_checkpoint} .")
+        logger.info(f"Loading model from {file_path}.")
         # We load the model state dict on the CPU to avoid an OOM error.
         state_dict = paddle.load(file_path, return_numpy=True)
         if self.is_ema:

@@ -524,6 +524,11 @@ def load_optimizer_state_with_reshard(self, checkpoint, base_opt_name, model_wra
             is_matched = reshard_util.sharding_v2.is_matched_optimizer_state_dict(
                 one_shard_opt_state_dict, self.optimizer, model_wrapped
             )
+            is_matched = paddle.to_tensor([is_matched], dtype=paddle.int32)
+            dp_group = fleet.get_hybrid_communicate_group().get_data_parallel_group()
+            dp_src_rank = fleet.get_hybrid_communicate_group().get_data_parallel_group_src_rank()
+            dist.broadcast(is_matched, src=dp_src_rank, group=dp_group)
+            is_matched = bool(is_matched[0])
         else:
             is_matched = True
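The added broadcast makes data-parallel rank 0 the single source of truth for `is_matched`, so every DP rank takes the same load-vs-reshard path. A minimal standalone sketch of the same pattern, assuming the fleet hybrid communicate group has already been initialized (the helper name is illustrative):

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

def broadcast_flag_from_dp0(flag: bool) -> bool:
    # Pack the local boolean into an int32 tensor so it can be broadcast.
    flag_t = paddle.to_tensor([int(flag)], dtype=paddle.int32)
    hcg = fleet.get_hybrid_communicate_group()
    dp_group = hcg.get_data_parallel_group()
    dp_src_rank = hcg.get_data_parallel_group_src_rank()
    # After the broadcast, every DP rank holds dp0's decision.
    dist.broadcast(flag_t, src=dp_src_rank, group=dp_group)
    return bool(flag_t[0])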

@@ -904,6 +909,10 @@ def _gather_sharding_metas(self):
         sharding_meta["param_meta_keys"] = ["shape", "dtype", "is_distributed", "no_sync"]
         sharding_meta["sharding_strategy"] = sharding_strategy
         sharding_meta["enable_overlap"] = pp_overlap
+        dp_metas_list = self._all_gather_simple_object(sharding_meta, self.hcg.get_data_parallel_group())
+        for e in dp_metas_list:
+            for key in ["structure_name_mapping", "param_meta"]:
+                sharding_meta[key].update(e[key])
         suffix = self._sharding_meta_suffix()
         sharding_metas[suffix] = sharding_meta
         sharding_metas_list = self._all_gather_simple_object(sharding_metas, self.hcg.get_model_parallel_group())

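With expert parallelism, each DP rank only knows the structure names and parameter metadata of its own experts, so the sharding meta is gathered across the DP group and merged into a union before being written out. A standalone sketch of the merge step (the gathered list stands in for the result of `_all_gather_simple_object`; the function name and arguments are illustrative):

from typing import Dict, List

def merge_dp_sharding_metas(local_meta: Dict, gathered_metas: List[Dict]) -> Dict:
    """Merge per-DP-rank sharding metas into one dict holding the union."""
    for remote_meta in gathered_metas:
        for key in ("structure_name_mapping", "param_meta"):
            # dict.update keeps shared (replicated) parameters unchanged and
            # adds the expert parameters that only the remote rank knows about.
            local_meta[key].update(remote_meta[key])
    return local_meta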

paddlenlp/trainer/utils/zero_cost_checkpoint.py

Lines changed: 29 additions & 26 deletions
@@ -220,36 +220,34 @@ def ema_state_dict(self):
             ema_state_dict[k] = tensor
         ema_state_dict_master_weights = {}
         for k, meta in self.optimizer_fusion_storage_helper.master_weights_meta.items():
-            t = self.ema_buffer._slice(
-                meta["start"] - self.master_min_offset, meta["end"] - self.master_min_offset
-            ).clone()
+            s = meta["start"] - self.master_min_offset
+            e = meta["end"] - self.master_min_offset
+            t = self.ema_buffer._slice(s, e).clone()
             t.get_tensor()._set_dims(meta["shape"])
             t.name = meta["name"]
             ema_state_dict_master_weights[k] = t
         ema_state_dict["master_weights"] = ema_state_dict_master_weights
         return ema_state_dict

-    def load_ema_state_dict(self, path):
-        with device_guard("cpu"):
-            logger.info(f"[ZCC EMA] load state dict from {path}")
-            state_dict = paddle.load(path)
-            for k, tensor_meta in self.param_fusion_storage_helper.model_weights_metas.items():
-                logger.info(f"[ZCC EMA] load model weight key={k}")
-                start = tensor_meta["start"]
-                end = tensor_meta["end"]
-                if tensor_meta["buffer_index"] not in self.ema_buffer_model_params:
-                    continue  # non fp32 has no `self.ema_buffer_model_params`
+    def load_ema_state_dict(self, state_dict):
+        for k, tensor_meta in self.param_fusion_storage_helper.model_weights_metas.items():
+            logger.info(f"[ZCC EMA] load model weight key={k}")
+            start = tensor_meta["start"]
+            end = tensor_meta["end"]
+            if tensor_meta["buffer_index"] not in self.ema_buffer_model_params:
+                continue  # non fp32 has no `self.ema_buffer_model_params`
+            if k in state_dict:
                 cpu_buffer = self.ema_buffer_model_params[tensor_meta["buffer_index"]]
                 tensor = state_dict[k].flatten()
                 cpu_buffer[start:end] = tensor

-            ema_master = state_dict["master_weights"]
-            for k, meta in self.optimizer_fusion_storage_helper.master_weights_meta.items():
-                logger.info(f"[ZCC EMA] load optimizer weight key={k}")
-                s = meta["start"] - self.master_min_offset
-                e = meta["end"] - self.master_min_offset
-                self.ema_buffer[s:e] = ema_master[k]
-            logger.info("[ZCC EMA] done loading")
+        ema_master = state_dict["master_weights"]
+        for k, meta in self.optimizer_fusion_storage_helper.master_weights_meta.items():
+            logger.info(f"[ZCC EMA] load optimizer weight key={k}")
+            s = meta["start"] - self.master_min_offset
+            e = meta["end"] - self.master_min_offset
+            if k in ema_master:  # state-dict is filtered
+                self.ema_buffer[s:e] = ema_master[k].flatten()


 class ParamFusionStorageHelper:

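`load_ema_state_dict` now receives an already-loaded (and possibly filtered) state dict instead of a path, and the new `if k in state_dict` / `if k in ema_master` guards skip keys that were stripped out before the call. The core operation is copying each flattened tensor into its slot of a fused 1-D buffer; a self-contained sketch of that pattern with made-up offsets (numpy stands in for the Paddle CPU buffers):

import numpy as np

# Hypothetical fused buffer and per-parameter offsets, for illustration only.
fused_buffer = np.zeros(16, dtype=np.float32)
weights_meta = {
    "linear.weight": {"start": 0, "end": 12, "shape": (3, 4)},
    "linear.bias": {"start": 12, "end": 16, "shape": (4,)},
}

def load_into_fused_buffer(state_dict):
    for name, meta in weights_meta.items():
        if name not in state_dict:
            # Keys may be absent when the checkpoint was filtered
            # (e.g. entries removed by the MoE filter for dp_rank > 0).
            continue
        fused_buffer[meta["start"]:meta["end"]] = state_dict[name].flatten()

load_into_fused_buffer({"linear.weight": np.ones((3, 4), dtype=np.float32)})
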
@@ -408,11 +406,6 @@ def on_optimizer_begin(self, args, state, control, **kwargs):
             logger.info("[ZCC manager] Synced checkpoints.")

     def on_step_end(self, args, state, control, model, lr_scheduler, optimizer, **kwargs):
-        if not isinstance(model, PipelineLayer):
-            self.manager.zcc_pipeline_hook(0)
-        # logger.info(
-        #     f"check coef: {args.zcc_save_ema_coef} {control.should_save}, {state.global_step}, {self.zcc_ema_interval}"
-        # )
         if not control.should_save:
             if args.zcc_save_ema_coef is not None and state.global_step % self.zcc_ema_interval == 0:
                 self.maybe_update_zcc_worker(args, model, optimizer, state.global_step)

@@ -425,6 +418,8 @@ def on_step_end(self, args, state, control, model, lr_scheduler, optimizer, **kw
             non_cached_objects = (lr_scheduler.state_dict(), state, self.get_rng_states(args))
             self.manager.get_idle_worker_for_saving((save_infos, non_cached_objects))
         self.runtime_timer.stop()
+        if not isinstance(model, PipelineLayer):
+            self.manager.zcc_pipeline_hook(0)

     def get_rng_states(self, args):
         if not args.save_rng_states:

@@ -959,7 +954,15 @@ def run(self):
                         self.optimizer_fusion_storage_helper, self.param_fusion_storage_helper, self.ema_coef
                     )
                     if ema_ckpt_path is not None:  # update ema if needed
-                        self.zcc_ema_processor.load_ema_state_dict(ema_ckpt_path)
+                        logger.info(f"[ZCC EMA] load state dict from {ema_ckpt_path}")
+                        with device_guard("cpu"):
+                            state_dict = paddle.load(ema_ckpt_path)
+                        if self.use_expert_parallel and self.dp_rank > 0:
+                            state_dict = self._filter_moe_no_sync_optimizer_params(
+                                self.model_meta_content, state_dict
+                            )
+                        self.zcc_ema_processor.load_ema_state_dict(state_dict)
+                        logger.info("[ZCC EMA] done loading")
                         ema_ckpt_path = None
                 elif task_type == ZCCTaskType.PREPARE:
                     start_time = time.time()
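For expert-parallel runs, non-zero DP ranks filter the loaded EMA checkpoint through `_filter_moe_no_sync_optimizer_params` before handing it to the EMA processor. That helper's body is not part of this diff, so the sketch below only illustrates the general shape of such a filter: selecting state-dict entries (including the nested master weights) by a per-parameter predicate derived from the model meta. All names here are hypothetical:

def filter_state_dict(state_dict, should_keep):
    """Return a copy of `state_dict` containing only entries `should_keep` approves.

    `should_keep` is a hypothetical predicate over parameter names, e.g. one
    built from a `no_sync` flag recorded in the gathered model meta.
    """
    kept = {k: v for k, v in state_dict.items() if k != "master_weights" and should_keep(k)}
    kept["master_weights"] = {
        k: v for k, v in state_dict.get("master_weights", {}).items() if should_keep(k)
    }
    return kept

Downstream, `load_ema_state_dict` tolerates the missing keys via its `if k in state_dict` and `if k in ema_master` guards.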
