@@ -493,7 +493,8 @@ def filter_params(model_to_save, state_dict, args, is_optimizer=False):
        weight_key = k.split("/")[0]
        model_v = model_state_dict[weight_key] if is_optimizer else v
        mp_moe = getattr(model_v, "mp_moe", False)
-         if not mp_moe:
+         no_sync = getattr(model_v, "no_sync", False)
+         if not mp_moe or no_sync:
            if not quant or not is_optimizer:
                if hasattr(model_v, "is_distributed") and model_v.is_distributed:
                    tensor_bytes_dict[k] = v.numel().item() * tp_size * dtype_byte_size(v.dtype)
@@ -555,6 +556,9 @@ def filter_params(model_to_save, state_dict, args, is_optimizer=False):
        mp_moe = getattr(model_v, "mp_moe", False)
        if mp_moe:
            filter_tensor_list[tp_rank].append(k)
+         no_sync = getattr(model_v, "no_sync", False)
+         if no_sync and k not in filter_tensor_list[tp_rank]:
+             filter_tensor_list[tp_rank].append(k)

    final_filter_tensor_list = []
    dist.all_gather_object(final_filter_tensor_list, filter_tensor_list[tp_rank], group=tp_group)
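Taken together, the two hunks above extend the special handling that mp_moe tensors already received to tensors flagged no_sync: such a tensor now enters the tensor_bytes_dict size accounting even when it is mp_moe, and is added to the rank-local filter list even when it is not. A minimal sketch of the two predicates, assuming tensors simply carry optional mp_moe / no_sync attributes (the helper names are illustrative, not part of the patch):

def counts_toward_bytes(tensor):
    # Old condition: not mp_moe.  New condition: not mp_moe OR no_sync.
    return not getattr(tensor, "mp_moe", False) or getattr(tensor, "no_sync", False)

def kept_in_local_filter_list(tensor):
    # Old condition: mp_moe only.  New condition: mp_moe OR no_sync (no duplicates).
    return getattr(tensor, "mp_moe", False) or getattr(tensor, "no_sync", False)

from types import SimpleNamespace

# e.g. an mp_moe expert weight flagged no_sync now satisfies both predicates
t = SimpleNamespace(mp_moe=True, no_sync=True)
print(counts_toward_bytes(t), kept_in_local_filter_list(t))  # True True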
@@ -568,14 +572,20 @@ def get_sharded_file_name(args, file_name, is_optimizer=False):
    """
    if not is_optimizer:
        sd_degree = args.sharding_parallel_degree if args.sharding_parallel_degree > 1 else 1
-         size = sd_degree if args.use_expert_parallel else args.dataset_world_size
+         if args.use_expert_parallel:
+             if args.expert_parallel_degree > 1:
+                 size = dist.get_world_size() // args.moe_sharding_parallel_degree
+             else:
+                 size = args.world_size // sd_degree
+         else:
+             size = args.world_size // args.dataset_world_size
        shard_file = file_name.replace(
            ".pdparams",
-             f"-{args.logical_process_index + 1:05d}-of-{args.world_size // size:05d}.pdparams",
+             f"-{args.logical_process_index + 1:05d}-of-{size:05d}.pdparams",
        )
        shard_file = shard_file.replace(
            ".safetensors",
-             f"-{args.logical_process_index + 1:05d}-of-{args.world_size // size:05d}.safetensors",
+             f"-{args.logical_process_index + 1:05d}-of-{size:05d}.safetensors",
        )
    else:
        hcg = fleet.get_hybrid_communicate_group()
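The replacement above makes `size` the shard-file count itself, instead of a divisor that was later divided into `world_size` inside the f-string. A minimal sketch of the three branches with illustrative numbers (the argument names mirror those used above; the concrete values are not from the patch):

def model_weight_shard_count(world_size, dataset_world_size, sharding_parallel_degree,
                             use_expert_parallel, expert_parallel_degree,
                             moe_sharding_parallel_degree):
    # Mirrors the new branching in get_sharded_file_name for the non-optimizer case.
    sd_degree = sharding_parallel_degree if sharding_parallel_degree > 1 else 1
    if use_expert_parallel:
        if expert_parallel_degree > 1:
            return world_size // moe_sharding_parallel_degree
        return world_size // sd_degree
    return world_size // dataset_world_size

# e.g. 16 ranks, expert_parallel_degree=2, moe_sharding_parallel_degree=2 -> 8 shard files
print(model_weight_shard_count(16, 2, 4, True, 2, 2))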
@@ -617,7 +627,9 @@ def get_sharded_index(
        return None


- def gather_sharded_object(index_file, total_size, is_optimizer=False, use_expert_parallel=False):
+ def gather_sharded_object(
+     index_file, total_size, is_optimizer=False, use_expert_parallel=False, expert_parallel_degree=1
+ ):
    """
    All gather sharded files list across different groups.
    """
@@ -654,7 +666,7 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False, use_expert
    index_file_list = [index_file]
    total_size_list = [total_size]

-     if use_expert_parallel:
+     if use_expert_parallel and expert_parallel_degree <= 1:
        data_group = hcg.get_data_parallel_group()
        if data_group.nranks > 1:
            data_index_file_list = []
@@ -664,7 +676,7 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False, use_expert
            index_file_list = flatten_list(data_index_file_list)
            total_size_list = flatten_list(data_total_size_list)

-     if is_optimizer:
+     if is_optimizer or expert_parallel_degree > 1:
        sharding_group = hcg.get_sharding_parallel_group()
        if sharding_group.nranks > 1:
            sharding_index_file_list = []
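With the new expert_parallel_degree argument, the two modified conditions above change which process groups the index-file lists are additionally gathered over: the data-parallel gather now runs only when expert parallelism is enabled with degree 1, and the sharding-parallel gather runs for optimizer state or whenever the expert-parallel degree exceeds 1. A minimal sketch of that dispatch (plain strings stand in for the real communicator groups, and the group-size checks in the actual code are omitted):

def extra_gather_groups(is_optimizer, use_expert_parallel, expert_parallel_degree):
    # Which additional groups gather_sharded_object would all-gather over,
    # per the two modified conditions above.
    groups = []
    if use_expert_parallel and expert_parallel_degree <= 1:
        groups.append("data_parallel")
    if is_optimizer or expert_parallel_degree > 1:
        groups.append("sharding_parallel")
    return groups

# e.g. model weights with expert_parallel_degree=2 gather over the sharding group only
print(extra_gather_groups(False, True, 2))  # ['sharding_parallel']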
@@ -781,29 +793,48 @@ def save_config(model_to_save):
    model_to_save.generation_config.save_pretrained(save_directory)


- def filter_sync_parameters(model_state_dict, optim_state_dict=None, master_weights=None, is_model_weight=True):
+ def filter_sync_parameters(
+     model_state_dict,
+     optim_state_dict=None,
+     master_weights=None,
+     is_model_weight=True,
+     use_expert_parallel=False,
+     expert_parallel_degree=1,
+ ):
    """Filter sync parameters under expert parallel mode."""

    hcg = fleet.get_hybrid_communicate_group()
    dp_group = hcg.get_data_parallel_group()
+     sharding_group = hcg.get_sharding_parallel_group()
    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
+     sharding_rank = sharding_group.rank if sharding_group.nranks > 1 else 0
+     if expert_parallel_degree > 1:
+         ep_group = hcg.get_expert_parallel_group()
+         ep_rank = ep_group.rank if ep_group.nranks > 1 else 0
+         logger.info("Filter sync parameters under expert parallel mode.")

    if is_model_weight:
        for key in list(model_state_dict.keys()):
-             if dp_rank > 0 and not getattr(model_state_dict[key], "no_sync", False):
-                 model_state_dict.pop(key)
+             if use_expert_parallel:
+                 if expert_parallel_degree > 1:
+                     if ep_rank > 0 and sharding_rank > 0 and not getattr(model_state_dict[key], "no_sync", False):
+                         model_state_dict.pop(key)
+                 else:
+                     if dp_rank > 0 and not getattr(model_state_dict[key], "no_sync", False):
+                         model_state_dict.pop(key)
    else:
-         no_sync_kname = []
-         for k, v in model_state_dict.items():
-             if getattr(v, "no_sync", False):
-                 no_sync_kname.append(k)
-
-         for key in list(optim_state_dict.keys()):
-             model_key = key.split("/")[0]
-             if dp_rank > 0 and model_key not in no_sync_kname:
-                 optim_state_dict.pop(key)
-
-         if master_weights is not None:
-             for key in list(master_weights.keys()):
-                 if dp_rank > 0 and key not in no_sync_kname:
-                     master_weights.pop(key)
+         if use_expert_parallel and expert_parallel_degree == 1:
+             no_sync_kname = []
+             for k, v in model_state_dict.items():
+                 if getattr(v, "no_sync", False):
+                     no_sync_kname.append(k)
+
+             for key in list(optim_state_dict.keys()):
+                 model_key = key.split("/")[0]
+                 if dp_rank > 0 and model_key not in no_sync_kname:
+                     optim_state_dict.pop(key)
+
+             if master_weights is not None:
+                 for key in list(master_weights.keys()):
+                     if dp_rank > 0 and key not in no_sync_kname:
+                         master_weights.pop(key)
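The rewritten filter_sync_parameters keeps a model weight on a rank only if it would not be a duplicate there: with expert_parallel_degree > 1 a weight is dropped only when both the expert-parallel and sharding ranks are non-zero (unless it is tagged no_sync), while the old data-parallel-rank rule is kept for degree 1. A minimal sketch of that keep/drop decision for a single weight (a plain boolean stands in for the tensor's no_sync attribute; this is an illustration, not the patch itself):

def keep_model_weight(no_sync, dp_rank, ep_rank, sharding_rank,
                      use_expert_parallel, expert_parallel_degree):
    # Mirrors the is_model_weight branch of filter_sync_parameters above.
    if not use_expert_parallel:
        return True  # no filtering outside expert parallel
    if expert_parallel_degree > 1:
        # dropped only where both ep_rank and sharding_rank are non-zero
        return no_sync or ep_rank == 0 or sharding_rank == 0
    # legacy behaviour: only data-parallel rank 0 keeps synced weights
    return no_sync or dp_rank == 0

# e.g. a synced weight on ep_rank=1, sharding_rank=1 is dropped
print(keep_model_weight(False, 0, 1, 1, True, 2))  # False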