@@ -633,6 +633,11 @@ class TrainingArguments:
         },
     )
 
+    load_sharded_model_remap_parameter_name: bool = field(
+        default=False,
+        metadata={"help": "Whether to remap parameter name when load_sharded_model = true."},
+    )
+
     tensor_parallel_degree: int = field(
         default=-1,
         metadata={
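The new flag follows the same dataclass `field(...)` pattern as the surrounding TrainingArguments options: a boolean that defaults to False and documents itself through metadata["help"]. A minimal, self-contained sketch of that pattern using only the standard library (ExampleArgs and its usage are illustrative, not part of this patch):

from dataclasses import dataclass, field

@dataclass
class ExampleArgs:
    # Off by default; the help text lives in metadata, as with the real TrainingArguments fields.
    load_sharded_model: bool = field(default=False, metadata={"help": "Load a sharded checkpoint."})
    load_sharded_model_remap_parameter_name: bool = field(
        default=False,
        metadata={"help": "Whether to remap parameter name when load_sharded_model = true."},
    )

args = ExampleArgs(load_sharded_model=True, load_sharded_model_remap_parameter_name=True)
# The remap flag only matters when sharded loading is enabled.
assert not args.load_sharded_model_remap_parameter_name or args.load_sharded_model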
@@ -2039,6 +2044,11 @@ def _post_init_parallel_degree(self):
             sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
         )
 
+        if expert_parallel_degree > 1:
+            assert (
+                self.expert_tensor_parallel_degree <= 1
+            ), "expert_tensor_parallel_degree > 1 is not supported when expert_parallel_degree > 1"
+
         assert not (
             self.data_parallel_degree > 1 and expert_parallel_degree > 1
         ), f"Currently only support use expert_data_parallel strategy together with sharding_parallel strategy, but not with data_parallel strategy. Currently data_parallel_degree is {self.data_parallel_degree}."
@@ -2227,6 +2237,17 @@ def pipeline_parallel_rank(self):
         else:
             return 0
 
+    @property
+    def expert_parallel_rank(self):
+        if self.use_hybrid_parallel:
+            hcg = fleet.get_hybrid_communicate_group()
+            if hasattr(hcg, "get_expert_parallel_rank"):
+                return max(hcg.get_expert_parallel_rank(), 0)
+            else:
+                return 0
+        else:
+            return 0
+
     @property
     def context_parallel_rank(self):
         if self.use_hybrid_parallel:
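The added expert_parallel_rank property uses a defensive pattern: query the hybrid communication group for the expert-parallel rank only when that API is present, and fall back to rank 0 otherwise. A minimal sketch of the same fallback logic with a stand-in object, so it runs without Paddle (FakeHCG and the free function are illustrative only):

class FakeHCG:
    """Stand-in for the object returned by fleet.get_hybrid_communicate_group()."""

    def get_expert_parallel_rank(self) -> int:
        return 2

def expert_parallel_rank(hcg, use_hybrid_parallel: bool) -> int:
    # Same fallback logic as the new property: report rank 0 whenever hybrid
    # parallelism is off or the group does not expose an expert-parallel rank.
    if use_hybrid_parallel and hasattr(hcg, "get_expert_parallel_rank"):
        return max(hcg.get_expert_parallel_rank(), 0)
    return 0

print(expert_parallel_rank(FakeHCG(), use_hybrid_parallel=True))  # 2
print(expert_parallel_rank(object(), use_hybrid_parallel=True))   # 0 (API missing)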
@@ -2252,7 +2273,7 @@ def optimizer_name_suffix(self):
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.sharding_parallel_degree > 1:
                 name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
-            if self.use_expert_parallel:
+            if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)
         else:
@@ -2268,7 +2289,7 @@ def weight_name_suffix(self):
                 name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
             if self.pipeline_parallel_degree > 1:
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
-            if self.use_expert_parallel:
+            if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)
 
@@ -2277,7 +2298,9 @@ def weight_name_suffix(self):
                 return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
             return None
 
-    def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
+    def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None, sharding_parallel_degree=None):
+        if sharding_parallel_degree is None:
+            sharding_parallel_degree = self.sharding_parallel_degree
         if self.use_hybrid_parallel:
             name = []
             if self.tensor_parallel_degree > 1:
@@ -2287,12 +2310,12 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
                 pp_id = self.pipeline_parallel_rank
             assert isinstance(pp_id, int)
             name.append(self._format_name("pp", pp_id, self.pipeline_parallel_degree))
-            if self.sharding_parallel_degree > 1:
+            if sharding_parallel_degree > 1:
                 if shard_id is None:
                     shard_id = self.sharding_parallel_rank
                 assert isinstance(shard_id, int)
-                name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
-            if self.use_expert_parallel:
+                name.append(self._format_name("shard", shard_id, sharding_parallel_degree))
+            if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 if moe_id is None:
                     moe_id = self.data_parallel_rank
                 assert isinstance(moe_id, int)
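sharded_name_suffix assembles checkpoint-file suffixes such as tp01_shard03 from the parallel ranks and degrees; the new keyword lets callers rebuild the suffix a file was saved with rather than the one the current run would use, and the moe component is now skipped once expert_parallel_degree exceeds 1. A simplified sketch of the composition, assuming two-digit zero padding (format_name mimics _format_name but is illustrative, not the real helper):

def format_name(prefix: str, rank: int, degree: int) -> str:
    # Assumes a two-digit, zero-padded rank such as "tp01" or "shard03";
    # the degree argument mirrors the real signature but is unused here.
    return f"{prefix}{rank:0>2d}"

def sharded_suffix(tp_rank=0, tp_degree=1, pp_rank=0, pp_degree=1,
                   shard_rank=0, sharding_parallel_degree=1):
    name = []
    if tp_degree > 1:
        name.append(format_name("tp", tp_rank, tp_degree))
    if pp_degree > 1:
        name.append(format_name("pp", pp_rank, pp_degree))
    if sharding_parallel_degree > 1:
        name.append(format_name("shard", shard_rank, sharding_parallel_degree))
    return "_".join(name)

# Overriding the sharding degree reproduces the suffix a checkpoint was saved
# under, even if the current run shards differently.
print(sharded_suffix(tp_rank=1, tp_degree=2, shard_rank=3, sharding_parallel_degree=8))  # tp01_shard03
print(sharded_suffix(tp_rank=1, tp_degree=2, shard_rank=0, sharding_parallel_degree=1))  # tp01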
@@ -2418,9 +2441,7 @@ def should_save_sharding_stage1_model(self):
     def should_load_sharding_stage1_model(self):
         if self.enable_auto_parallel:
             return False
-        return (
-            ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.load_sharded_model
-        )
+        return self.load_sharded_model
 
     @property
     def should_load_dataset(self):