
Commit f22279a

[Unified Checkpoint] Support deepep (#2623)
1 parent e54947d commit f22279a

File tree (4 files changed: +112 -36 lines)

paddleformers/trainer/training_args.py
paddleformers/trainer/unified_checkpoint/unified_checkpoint.py
paddleformers/trainer/unified_checkpoint/utils.py
paddleformers/transformers/moe_layer.py

paddleformers/trainer/training_args.py

Lines changed: 14 additions & 1 deletion
@@ -2231,6 +2231,17 @@ def expert_parallel_rank(self):
         else:
             return 0

+    @property
+    def moe_sharding_parallel_rank(self):
+        if self.use_hybrid_parallel:
+            hcg = fleet.get_hybrid_communicate_group()
+            if hasattr(hcg, "get_moe_sharding_parallel_group"):
+                return max(hcg.get_moe_sharding_parallel_group().rank, 0)
+            else:
+                return 0
+        else:
+            return 0
+
     def _format_name(self, prefix, rank, degree):
         size = 2
         return f"{prefix}{rank:0>{size}d}"

@@ -2390,7 +2401,9 @@ def should_save_model_state(self):
             return True
         elif self.use_hybrid_parallel:
             # save on dataset rank 0
-            return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
+            return (
+                self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
+            ) or (self.expert_parallel_degree > 1 and self.moe_sharding_parallel_rank == 0)
         else:
             return self.process_index == 0 or self.use_expert_parallel
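With deepep enabled (expert_parallel_degree > 1), rank 0 of each moe-sharding group now also writes model state. A minimal standalone sketch of that decision, using hypothetical rank values rather than the real fleet groups:

# Standalone sketch of the updated should_save_model_state rule under hybrid
# parallelism. The rank/degree inputs are hypothetical; in the library they come
# from TrainingArguments properties such as moe_sharding_parallel_rank.
def should_save_model_state(
    sharding_parallel_rank: int,
    data_parallel_rank: int,
    moe_sharding_parallel_rank: int,
    use_expert_parallel: bool,
    expert_parallel_degree: int,
) -> bool:
    base = sharding_parallel_rank == 0 and (data_parallel_rank == 0 or use_expert_parallel)
    # New in this commit: under deepep, rank 0 of the moe-sharding group also saves.
    return base or (expert_parallel_degree > 1 and moe_sharding_parallel_rank == 0)

print(should_save_model_state(1, 1, 0, True, 8))  # True: moe-sharding rank 0 under deepep
print(should_save_model_state(1, 1, 1, True, 8))  # False: not a saving rank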

paddleformers/trainer/unified_checkpoint/unified_checkpoint.py

Lines changed: 34 additions & 11 deletions
@@ -232,10 +232,21 @@ def save_non_merge_optimizer(
         for key in list(master_weights.keys()):
             master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)

-        if self.args.use_expert_parallel:
-            model_state_dict = get_expected_state_dict(model)
-            filter_sync_parameters(model_state_dict, optim_state_dict, is_model_weight=False)
-            filter_sync_parameters(model_state_dict, master_weights, is_model_weight=False)
+        model_state_dict = get_expected_state_dict(model)
+        filter_sync_parameters(
+            model_state_dict,
+            optim_state_dict,
+            is_model_weight=False,
+            use_expert_parallel=self.args.use_expert_parallel,
+            expert_parallel_degree=self.args.expert_parallel_degree,
+        )
+        filter_sync_parameters(
+            model_state_dict,
+            master_weights,
+            is_model_weight=False,
+            use_expert_parallel=self.args.use_expert_parallel,
+            expert_parallel_degree=self.args.expert_parallel_degree,
+        )

         optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
         master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix)

@@ -607,9 +618,12 @@ def unified_checkpoint_into_shards(

     config_to_save = copy.deepcopy(model_to_save.config)

-    if args.use_expert_parallel:
-        # ignore saving `no_sync=False` tensors when using expert_parallel under dp_rank > 0.
-        filter_sync_parameters(state_dict, is_model_weight=True)
+    filter_sync_parameters(
+        state_dict,
+        is_model_weight=True,
+        use_expert_parallel=args.use_expert_parallel,
+        expert_parallel_degree=args.expert_parallel_degree,
+    )

     if config_to_save.tensor_parallel_degree > 1:
         if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM):

@@ -639,20 +653,24 @@ def unified_checkpoint_into_shards(

     shard_file = get_sharded_file_name(args, weights_name)
     # renumerize shard_file name for expert_parallel.
-    if args.use_expert_parallel:
+    if args.use_expert_parallel and args.expert_parallel_degree <= 1:
         shard_file = rename_shard_file(args, shard_file, weights_name)

     for key, weight in state_dict.items():
         index_weight_file[key] = shard_file
         total_size += weight.numel().item() * dtype_byte_size(weight.dtype)

     index_file_list, total_size_list = gather_sharded_object(
-        index_weight_file, total_size, use_expert_parallel=args.use_expert_parallel
+        index_weight_file,
+        total_size,
+        use_expert_parallel=args.use_expert_parallel,
+        expert_parallel_degree=args.expert_parallel_degree,
     )
     sharded_index = get_sharded_index(
         index_file_list,
         total_size_list,
     )
+
     if sharded_index is not None:
         if isinstance(model_to_save, LoRAModel):
             sharded_index["type"] = "lora"

@@ -724,8 +742,13 @@ def unified_optimizer_into_shards(
     tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
     tp_size = tp_group.nranks

-    if args.use_expert_parallel:
-        filter_sync_parameters(model_state_dict, state_dict, is_model_weight=False)
+    filter_sync_parameters(
+        model_state_dict,
+        state_dict,
+        is_model_weight=False,
+        use_expert_parallel=args.use_expert_parallel,
+        expert_parallel_degree=args.expert_parallel_degree,
+    )

     if tp_size > 1:
         # get tp_actions
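Note that the if args.use_expert_parallel guards disappear from the call sites: filter_sync_parameters now receives the flags and decides for itself, leaving the dicts untouched when expert parallelism is off. A minimal standalone mock of that contract, where plain dicts and an explicit dp_rank stand in for the real hybrid-communicate groups and only the model-weight path with expert_parallel_degree == 1 is sketched:

# Illustration only: not the library function, just the call contract.
from types import SimpleNamespace

def filter_sync_parameters_mock(model_state_dict, use_expert_parallel=False,
                                expert_parallel_degree=1, dp_rank=0):
    if not use_expert_parallel:
        return  # no-op: callers no longer need their own guard
    if expert_parallel_degree == 1:
        for key in list(model_state_dict.keys()):
            if dp_rank > 0 and not getattr(model_state_dict[key], "no_sync", False):
                model_state_dict.pop(key)

weights = {
    "experts.0.w1": SimpleNamespace(no_sync=True),  # expert weight, kept on every rank
    "embed_tokens.weight": SimpleNamespace(),       # replicated weight, saved once
}
filter_sync_parameters_mock(dict(weights), use_expert_parallel=False)   # leaves dict untouched
filter_sync_parameters_mock(weights, use_expert_parallel=True, dp_rank=1)
print(sorted(weights))  # ['experts.0.w1']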

paddleformers/trainer/unified_checkpoint/utils.py

Lines changed: 55 additions & 24 deletions
@@ -493,7 +493,8 @@ def filter_params(model_to_save, state_dict, args, is_optimizer=False):
             weight_key = k.split("/")[0]
             model_v = model_state_dict[weight_key] if is_optimizer else v
             mp_moe = getattr(model_v, "mp_moe", False)
-            if not mp_moe:
+            no_sync = getattr(model_v, "no_sync", False)
+            if not mp_moe or no_sync:
                 if not quant or not is_optimizer:
                     if hasattr(model_v, "is_distributed") and model_v.is_distributed:
                         tensor_bytes_dict[k] = v.numel().item() * tp_size * dtype_byte_size(v.dtype)

@@ -555,6 +556,9 @@ def filter_params(model_to_save, state_dict, args, is_optimizer=False):
             mp_moe = getattr(model_v, "mp_moe", False)
             if mp_moe:
                 filter_tensor_list[tp_rank].append(k)
+            no_sync = getattr(model_v, "no_sync", False)
+            if no_sync and k not in filter_tensor_list[tp_rank]:
+                filter_tensor_list[tp_rank].append(k)

         final_filter_tensor_list = []
         dist.all_gather_object(final_filter_tensor_list, filter_tensor_list[tp_rank], group=tp_group)
@@ -568,14 +572,20 @@ def get_sharded_file_name(args, file_name, is_optimizer=False):
     """
     if not is_optimizer:
         sd_degree = args.sharding_parallel_degree if args.sharding_parallel_degree > 1 else 1
-        size = sd_degree if args.use_expert_parallel else args.dataset_world_size
+        if args.use_expert_parallel:
+            if args.expert_parallel_degree > 1:
+                size = dist.get_world_size() // args.moe_sharding_parallel_degree
+            else:
+                size = args.world_size // sd_degree
+        else:
+            size = args.world_size // args.dataset_world_size
         shard_file = file_name.replace(
             ".pdparams",
-            f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams",
+            f"-{args.logical_process_index + 1:05d}-of-{size:05d}.pdparams",
         )
         shard_file = shard_file.replace(
             ".safetensors",
-            f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.safetensors",
+            f"-{args.logical_process_index + 1:05d}-of-{size:05d}.safetensors",
         )
     else:
         hcg = fleet.get_hybrid_communicate_group()
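A hypothetical worked example of the new shard-count rule in the hunk above. All degrees below are illustrative, not taken from the commit, and the sketch assumes dist.get_world_size() equals args.world_size:

# Shard-count arithmetic only; no Paddle imports needed for the illustration.
world_size = 64
sharding_parallel_degree = 8
moe_sharding_parallel_degree = 4
dataset_world_size = 8

# deepep path: use_expert_parallel and expert_parallel_degree > 1
size_deepep = world_size // moe_sharding_parallel_degree   # 16 weight shard files
# legacy expert-parallel path: expert_parallel_degree <= 1
size_legacy_ep = world_size // sharding_parallel_degree    # 8 weight shard files
# no expert parallelism
size_dense = world_size // dataset_world_size              # 8 weight shard files

# Suffix appended to the weights file name by get_sharded_file_name:
print(f"-{1:05d}-of-{size_deepep:05d}.safetensors")  # -00001-of-00016.safetensors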
@@ -617,7 +627,9 @@ def get_sharded_index(
     return None


-def gather_sharded_object(index_file, total_size, is_optimizer=False, use_expert_parallel=False):
+def gather_sharded_object(
+    index_file, total_size, is_optimizer=False, use_expert_parallel=False, expert_parallel_degree=1
+):
     """
     All gather sharded files list across different groups.
     """

@@ -654,7 +666,7 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False, use_expert_parallel=False):
     index_file_list = [index_file]
     total_size_list = [total_size]

-    if use_expert_parallel:
+    if use_expert_parallel and expert_parallel_degree <= 1:
         data_group = hcg.get_data_parallel_group()
         if data_group.nranks > 1:
             data_index_file_list = []

@@ -664,7 +676,7 @@
             index_file_list = flatten_list(data_index_file_list)
             total_size_list = flatten_list(data_total_size_list)

-    if is_optimizer:
+    if is_optimizer or expert_parallel_degree > 1:
         sharding_group = hcg.get_sharding_parallel_group()
         if sharding_group.nranks > 1:
             sharding_index_file_list = []
@@ -781,29 +793,48 @@ def save_config(model_to_save):
     model_to_save.generation_config.save_pretrained(save_directory)


-def filter_sync_parameters(model_state_dict, optim_state_dict=None, master_weights=None, is_model_weight=True):
+def filter_sync_parameters(
+    model_state_dict,
+    optim_state_dict=None,
+    master_weights=None,
+    is_model_weight=True,
+    use_expert_parallel=False,
+    expert_parallel_degree=1,
+):
     """Filter sync parameters under expert parallel mode."""

     hcg = fleet.get_hybrid_communicate_group()
     dp_group = hcg.get_data_parallel_group()
+    sharding_group = hcg.get_sharding_parallel_group()
     dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
+    sharding_rank = sharding_group.rank if sharding_group.nranks > 1 else 0
+    if expert_parallel_degree > 1:
+        ep_group = hcg.get_expert_parallel_group()
+        ep_rank = ep_group.rank if ep_group.nranks > 1 else 0
+        logger.info("Filter sync parameters under expert parallel mode.")

     if is_model_weight:
         for key in list(model_state_dict.keys()):
-            if dp_rank > 0 and not getattr(model_state_dict[key], "no_sync", False):
-                model_state_dict.pop(key)
+            if use_expert_parallel:
+                if expert_parallel_degree > 1:
+                    if ep_rank > 0 and sharding_rank > 0 and not getattr(model_state_dict[key], "no_sync", False):
+                        model_state_dict.pop(key)
+                else:
+                    if dp_rank > 0 and not getattr(model_state_dict[key], "no_sync", False):
+                        model_state_dict.pop(key)
     else:
-        no_sync_kname = []
-        for k, v in model_state_dict.items():
-            if getattr(v, "no_sync", False):
-                no_sync_kname.append(k)
-
-        for key in list(optim_state_dict.keys()):
-            model_key = key.split("/")[0]
-            if dp_rank > 0 and model_key not in no_sync_kname:
-                optim_state_dict.pop(key)
-
-        if master_weights is not None:
-            for key in list(master_weights.keys()):
-                if dp_rank > 0 and key not in no_sync_kname:
-                    master_weights.pop(key)
+        if use_expert_parallel and expert_parallel_degree == 1:
+            no_sync_kname = []
+            for k, v in model_state_dict.items():
+                if getattr(v, "no_sync", False):
+                    no_sync_kname.append(k)
+
+            for key in list(optim_state_dict.keys()):
+                model_key = key.split("/")[0]
+                if dp_rank > 0 and model_key not in no_sync_kname:
+                    optim_state_dict.pop(key)
+
+            if master_weights is not None:
+                for key in list(master_weights.keys()):
+                    if dp_rank > 0 and key not in no_sync_kname:
+                        master_weights.pop(key)
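The heart of the change is the keep/drop rule for model weights. A standalone sketch of that rule, with ranks passed in as plain integers instead of being read from the dp, sharding, and expert-parallel groups:

# Sketch only: mirrors the model-weight branch of filter_sync_parameters.
def keep_model_weight(no_sync, use_expert_parallel, expert_parallel_degree,
                      dp_rank=0, sharding_rank=0, ep_rank=0):
    if not use_expert_parallel:
        return True
    if expert_parallel_degree > 1:
        # deepep layout: drop synced weights only on ranks that are neither
        # ep rank 0 nor sharding rank 0.
        return not (ep_rank > 0 and sharding_rank > 0 and not no_sync)
    # legacy layout: dp rank 0 keeps everything, other dp ranks keep only
    # their no_sync (expert) weights.
    return not (dp_rank > 0 and not no_sync)

# An expert weight (no_sync=True) survives on every rank; a replicated weight
# is written only by the designated ranks.
print(keep_model_weight(True,  True, 8, ep_rank=3, sharding_rank=1))  # True
print(keep_model_weight(False, True, 8, ep_rank=3, sharding_rank=1))  # False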

paddleformers/transformers/moe_layer.py

Lines changed: 9 additions & 0 deletions
@@ -353,6 +353,7 @@ def __init__(self, config, moe_num_experts, expert_class, expert_kwargs, gate, m
             self.num_local_experts, self.moe_router_topk, self.moe_num_experts, moe_group
         )
         self.expert_parallel_degree = 1 if self.ep_size < 0 else self.ep_size
+        self.is_dummy_moe = False if self.expert_parallel_degree > 1 else True
         self.moe_num_experts_per_device = self._parse_moe_expert_parallel(
             self.moe_num_experts, self.expert_parallel_degree
         )

@@ -363,6 +364,14 @@
             else:
                 self.experts.append(None)
         self.gate = gate
+        self._post_init()
+
+    def _post_init(self):
+        for k in self.experts:
+            if k is not None:
+                for p in k.parameters():
+                    p.expert = not self.is_dummy_moe
+                    p.no_sync = not self.is_dummy_moe

     def expert_forward(self, dispatched_input, tokens_per_expert):
         outputs = []
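_post_init tags every local expert parameter so that the checkpoint-side filtering above can recognize expert weights via their no_sync attribute. A standalone sketch with toy objects in place of the real MoELayer and Paddle parameters:

# Toy illustration of the tagging performed by _post_init.
class _Param:
    pass

class _ToyExpert:
    def __init__(self, n):
        self._params = [_Param() for _ in range(n)]
    def parameters(self):
        return self._params

experts = [_ToyExpert(2), None, _ToyExpert(2)]  # None = expert not local to this rank
expert_parallel_degree = 8
is_dummy_moe = expert_parallel_degree <= 1

for expert in experts:
    if expert is not None:
        for p in expert.parameters():
            p.expert = not is_dummy_moe    # mark as an expert-parallel parameter
            p.no_sync = not is_dummy_moe   # excluded from dp/sharding sync, always saved

print(getattr(experts[0].parameters()[0], "no_sync", False))  # True when EP is active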
