@@ -1424,10 +1424,13 @@ def is_context_parallel_supported():
             else:
                 order = ["dp", "sharding", "pp", "mp"]
             if self.use_expert_parallel:
-                if is_context_parallel_supported():
-                    order = order[1:-1] + ["cp", "dp", "mp"]
+                if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1:
+                    if is_context_parallel_supported():
+                        order = ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"]
+                    else:
+                        order = ["sharding", "moe_sharding", "pp", "sep", "dp", "ep", "mp"]
                 else:
-                    order = order[1:-1] + ["dp", "mp"]
+                    order = ["sharding", "pp", "sep", "dp", "mp"]
 
             if is_context_parallel_supported():
                 hybrid_configs = {
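For readers following the topology change above, here is a minimal standalone sketch (not part of the PR) of how the new branch selects the device-mesh order. The parameter names mirror the TrainingArguments fields, `cp_supported` stands in for `is_context_parallel_supported()`, and the default order is assumed to be the `["dp", "sharding", "pp", "mp"]` fallback shown in the context lines.

```python
# Illustration only: a standalone sketch of the order selection added in this hunk.
# Parameter names mirror the TrainingArguments fields; `cp_supported` stands in
# for is_context_parallel_supported().
def pick_order(use_expert_parallel, moe_sharding_parallel_degree, expert_parallel_degree, cp_supported):
    order = ["dp", "sharding", "pp", "mp"]  # fallback order from the context lines above
    if use_expert_parallel:
        if moe_sharding_parallel_degree >= 1 and expert_parallel_degree > 1:
            # MoE path: dedicated moe_sharding / ep axes join the mesh order.
            if cp_supported:
                order = ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"]
            else:
                order = ["sharding", "moe_sharding", "pp", "sep", "dp", "ep", "mp"]
        else:
            # Expert parallel requested but ep degree <= 1: no ep / moe_sharding axes.
            order = ["sharding", "pp", "sep", "dp", "mp"]
    return order


print(pick_order(True, 4, 8, cp_supported=False))
# ['sharding', 'moe_sharding', 'pp', 'sep', 'dp', 'ep', 'mp']
```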
@@ -1459,6 +1462,13 @@ def is_context_parallel_supported():
                     "order": order,
                 }
 
+            if self.expert_parallel_degree > 1:
+                assert (
+                    self.use_expert_parallel is True and self.moe_sharding_parallel_degree >= 0
+                ), f"invalid expert_parallel_degree {self.expert_parallel_degree} and use_expert_parallel:{self.use_expert_parallel}."
+                hybrid_configs["ep_degree"] = self.expert_parallel_degree
+                hybrid_configs["moe_sharding_degree"] = self.moe_sharding_parallel_degree
+
             try:
                 if self.split_norm_comm:
                     hybrid_configs["split_norm_comm"] = True
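To make the new `hybrid_configs` keys concrete, a small sketch with placeholder degrees (none of these numbers come from the PR; only `ep_degree` and `moe_sharding_degree` are the keys this change introduces):

```python
# Placeholder degrees for illustration; only the last two keys are new in this change.
expert_parallel_degree = 4
moe_sharding_parallel_degree = 4
use_expert_parallel = True

hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 2,
    "pp_degree": 2,
    "sharding_degree": 8,
    "order": ["sharding", "moe_sharding", "pp", "sep", "dp", "ep", "mp"],
}

if expert_parallel_degree > 1:
    # Same guard as the added code: expert parallel must be enabled and the
    # derived moe_sharding degree must be non-negative.
    assert use_expert_parallel is True and moe_sharding_parallel_degree >= 0
    hybrid_configs["ep_degree"] = expert_parallel_degree
    hybrid_configs["moe_sharding_degree"] = moe_sharding_parallel_degree
```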
@@ -1587,9 +1597,6 @@ def is_context_parallel_supported():
             fleet.init(is_collective=True, strategy=strategy)
             logger.info(strategy)
 
-            if self.expert_parallel_degree > 1:
-                self.add_moe_comm_group()
-
         elif self.enable_auto_parallel:
             self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
             self.sep_parallel_degree = max(self.sep_parallel_degree, 1)
@@ -2035,11 +2042,6 @@ def _post_init_parallel_degree(self):
             logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!")
             self.sharding = []
 
-        if sharding_parallel_degree > 1:
-            assert (
-                sharding_parallel_degree % expert_parallel_degree == 0
-            ), f"sharding_parallel_degree should be divided by expert_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}."
-
         self.data_parallel_degree = world_size // (
             sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
         )
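As a reminder of how the surrounding context derives `data_parallel_degree` from the leftover world size, a quick worked example with hypothetical degrees (not taken from the PR):

```python
# Hypothetical 32-rank job: sharding=8, tensor=2, sep=1, pipeline=2.
world_size = 32
sharding_parallel_degree = 8
tensor_parallel_degree = 2
sep_parallel_degree = 1
pipeline_parallel_degree = 2

data_parallel_degree = world_size // (
    sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
)
print(data_parallel_degree)  # 32 // (8 * 2 * 1 * 2) = 1
```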
@@ -2048,6 +2050,19 @@ def _post_init_parallel_degree(self):
             assert (
                 self.expert_tensor_parallel_degree <= 1
             ), "expert_tensor_parallel_degree > 1 is not supported when expert_parallel_degree > 1"
+            moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree)
+        else:
+            moe_sharding_parallel_degree = 1
+        moe_sharding_parallel_degree = max(moe_sharding_parallel_degree, 1)
+        if moe_sharding_parallel_degree > 1 and self.data_parallel_degree > 1:
+            raise NotImplementedError(
+                f"Currently only support use expert_data_parallel strategy together with sharding_parallel strategy, but not with data_parallel strategy. But got data_parallel_degree: {self.data_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}."
+            )
+
+        if sharding_parallel_degree > 1 and moe_sharding_parallel_degree > 1:
+            assert (
+                sharding_parallel_degree % moe_sharding_parallel_degree == 0
+            ), f"sharding_parallel_degree should be divided by moe_sharding_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}."
 
         assert not (
             self.data_parallel_degree > 1 and expert_parallel_degree > 1
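Continuing the hypothetical 32-rank example, a sketch of the new `moe_sharding_parallel_degree` derivation and the checks this hunk adds (numbers are illustrative only):

```python
# Hypothetical degrees: pp=2, sharding=8, dp=1 (as derived above), expert=4.
world_size = 32
pipeline_parallel_degree = 2
sharding_parallel_degree = 8
data_parallel_degree = 1
expert_parallel_degree = 4

if expert_parallel_degree > 1:
    moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree)  # 32 // 8 = 4
else:
    moe_sharding_parallel_degree = 1
moe_sharding_parallel_degree = max(moe_sharding_parallel_degree, 1)

# dp == 1, so the NotImplementedError branch (expert_data_parallel together with
# data_parallel) is skipped, and the new divisibility check holds: 8 % 4 == 0.
assert not (moe_sharding_parallel_degree > 1 and data_parallel_degree > 1)
assert sharding_parallel_degree % moe_sharding_parallel_degree == 0
```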
@@ -2070,6 +2085,7 @@ def _post_init_parallel_degree(self):
         self.context_parallel_degree = context_parallel_degree
         self.expert_parallel_degree = expert_parallel_degree
         self.expert_tensor_parallel_degree = expert_tensor_parallel_degree
+        self.moe_sharding_parallel_degree = moe_sharding_parallel_degree
 
         if not self.use_hybrid_parallel:
             self.sharding = []