@@ -1424,10 +1424,13 @@ def is_context_parallel_supported():
             else:
                 order = ["dp", "sharding", "pp", "mp"]
             if self.use_expert_parallel:
-                if is_context_parallel_supported():
-                    order = order[1:-1] + ["cp", "dp", "mp"]
+                if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1:
+                    if is_context_parallel_supported():
+                        order = ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"]
+                    else:
+                        order = ["sharding", "moe_sharding", "pp", "sep", "dp", "ep", "mp"]
                 else:
-                    order = order[1:-1] + ["dp", "mp"]
+                    order = ["sharding", "pp", "sep", "dp", "mp"]

             if is_context_parallel_supported():
                 hybrid_configs = {
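For readability, a minimal standalone sketch of the new order selection, with all values assumed for illustration: use_expert_parallel=True, expert_parallel_degree=2, moe_sharding_parallel_degree=2, and a Paddle build where is_context_parallel_supported() returns False.

# Hypothetical standalone illustration of the order selection above; this is not
# the actual TrainingArguments code path, and the flag values are assumed.
use_expert_parallel = True
expert_parallel_degree = 2
moe_sharding_parallel_degree = 2
context_parallel_supported = False  # assume an older Paddle without "cp" support

order = ["dp", "sharding", "pp", "mp"]
if use_expert_parallel:
    if moe_sharding_parallel_degree >= 1 and expert_parallel_degree > 1:
        if context_parallel_supported:
            order = ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"]
        else:
            order = ["sharding", "moe_sharding", "pp", "sep", "dp", "ep", "mp"]
    else:
        order = ["sharding", "pp", "sep", "dp", "mp"]

print(order)  # ['sharding', 'moe_sharding', 'pp', 'sep', 'dp', 'ep', 'mp']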
@@ -1459,6 +1462,13 @@ def is_context_parallel_supported():
                     "order": order,
                 }

+            if self.expert_parallel_degree > 1:
+                assert (
+                    self.use_expert_parallel is True and self.moe_sharding_parallel_degree >= 0
+                ), f"invalid expert_parallel_degree {self.expert_parallel_degree} and use_expert_parallel: {self.use_expert_parallel}."
+                hybrid_configs["ep_degree"] = self.expert_parallel_degree
+                hybrid_configs["moe_sharding_degree"] = self.moe_sharding_parallel_degree
+
             try:
                 if self.split_norm_comm:
                     hybrid_configs["split_norm_comm"] = True
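As an illustration of where the two extra fields end up, here is a hedged sketch of the hybrid_configs dict being extended; the degree values are assumed, and the real code fills in many more fields from TrainingArguments before handing the strategy to fleet.

# Illustrative only: degrees are assumed, and the dict mirrors the shape built above.
expert_parallel_degree = 2
moe_sharding_parallel_degree = 2
use_expert_parallel = True

hybrid_configs = {
    "order": ["sharding", "moe_sharding", "pp", "sep", "dp", "ep", "mp"],
}

if expert_parallel_degree > 1:
    assert use_expert_parallel and moe_sharding_parallel_degree >= 0
    hybrid_configs["ep_degree"] = expert_parallel_degree
    hybrid_configs["moe_sharding_degree"] = moe_sharding_parallel_degree

# hybrid_configs now carries the two extra MoE fields; in training_args.py it is
# assigned to the DistributedStrategy before fleet.init(is_collective=True, strategy=strategy).
print(hybrid_configs)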
@@ -1587,9 +1597,6 @@ def is_context_parallel_supported():
             fleet.init(is_collective=True, strategy=strategy)
             logger.info(strategy)

-            if self.expert_parallel_degree > 1:
-                self.add_moe_comm_group()
-
         elif self.enable_auto_parallel:
             self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
             self.sep_parallel_degree = max(self.sep_parallel_degree, 1)
@@ -2035,11 +2042,6 @@ def _post_init_parallel_degree(self):
             logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!")
             self.sharding = []

-        if sharding_parallel_degree > 1:
-            assert (
-                sharding_parallel_degree % expert_parallel_degree == 0
-            ), f"sharding_parallel_degree should be divided by expert_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}."
-
         self.data_parallel_degree = world_size // (
             sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
         )
@@ -2048,6 +2050,19 @@ def _post_init_parallel_degree(self):
             assert (
                 self.expert_tensor_parallel_degree <= 1
             ), "expert_tensor_parallel_degree > 1 is not supported when expert_parallel_degree > 1"
+            moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree)
+        else:
+            moe_sharding_parallel_degree = 1
+        moe_sharding_parallel_degree = max(moe_sharding_parallel_degree, 1)
+        if moe_sharding_parallel_degree > 1 and self.data_parallel_degree > 1:
+            raise NotImplementedError(
+                f"Currently, the expert_data_parallel strategy is only supported together with the sharding_parallel strategy, not with the data_parallel strategy. Got data_parallel_degree: {self.data_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}."
+            )
+
+        if sharding_parallel_degree > 1 and moe_sharding_parallel_degree > 1:
+            assert (
+                sharding_parallel_degree % moe_sharding_parallel_degree == 0
+            ), f"sharding_parallel_degree should be divisible by moe_sharding_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}."

         assert not (
             self.data_parallel_degree > 1 and expert_parallel_degree > 1
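A worked example of the new degree arithmetic, with all values assumed for illustration: 16 ranks, pipeline_parallel_degree=2, expert_parallel_degree=4, sharding_parallel_degree=8, and tensor/sep degrees of 1, so data_parallel_degree stays 1 and both new checks pass.

# Hypothetical configuration, chosen so all of the new checks succeed.
world_size = 16
pipeline_parallel_degree = 2
tensor_parallel_degree = 1
sep_parallel_degree = 1
sharding_parallel_degree = 8
expert_parallel_degree = 4

data_parallel_degree = world_size // (
    sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
)  # 16 // 16 = 1

if expert_parallel_degree > 1:
    moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree)  # 16 // 8 = 2
else:
    moe_sharding_parallel_degree = 1
moe_sharding_parallel_degree = max(moe_sharding_parallel_degree, 1)

# data_parallel_degree == 1, so the NotImplementedError branch is not hit,
# and 8 % 2 == 0 satisfies the sharding/moe_sharding divisibility assert.
assert data_parallel_degree == 1
assert sharding_parallel_degree % moe_sharding_parallel_degree == 0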
@@ -2070,6 +2085,7 @@ def _post_init_parallel_degree(self):
         self.context_parallel_degree = context_parallel_degree
         self.expert_parallel_degree = expert_parallel_degree
         self.expert_tensor_parallel_degree = expert_tensor_parallel_degree
+        self.moe_sharding_parallel_degree = moe_sharding_parallel_degree

         if not self.use_hybrid_parallel:
             self.sharding = []