@@ -1405,7 +1405,7 @@ def is_context_parallel_supported():
             else:
                 order = ["dp", "sharding", "pp", "mp"]
             if self.use_expert_parallel:
-                if self.moe_sharding_parallel_degree > 1 and self.expert_parallel_degree > 1:
+                if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1:
                     if is_context_parallel_supported():
                         order = ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"]
                     else:
@@ -1429,8 +1429,6 @@ def is_context_parallel_supported():
14291429 "mp_degree" : self .tensor_parallel_degree ,
14301430 "pp_degree" : self .pipeline_parallel_degree ,
14311431 "sharding_degree" : self .sharding_parallel_degree ,
1432- "moe_sharding_degree" : self .moe_sharding_parallel_degree ,
1433- "ep_degree" : self .expert_parallel_degree ,
14341432 "sep_degree" : self .sep_parallel_degree
14351433 if self .sep_parallel_degree > 1
14361434 else self .context_parallel_degree ,
@@ -1442,11 +1440,16 @@ def is_context_parallel_supported():
14421440 "mp_degree" : self .tensor_parallel_degree ,
14431441 "pp_degree" : self .pipeline_parallel_degree ,
14441442 "sharding_degree" : self .sharding_parallel_degree ,
1445- "moe_sharding_degree" : self .moe_sharding_parallel_degree ,
1446- "ep_degree" : self .expert_parallel_degree ,
14471443 "order" : order ,
14481444 }
14491445
1446+ if self .expert_parallel_degree > 1 :
1447+ assert (
1448+ self .use_expert_parallel is True and self .moe_sharding_parallel_degree >= 0
1449+ ), f"invalid expert_parallel_degree { self .expert_parallel_degree } and use_expert_paralle:{ self .use_expert_parallel } ."
1450+ hybrid_configs ["ep_degree" ] = self .expert_parallel_degree
1451+ hybrid_configs ["moe_sharding_degree" ] = self .moe_sharding_parallel_degree
1452+
14501453 try :
14511454 if self .split_norm_comm :
14521455 hybrid_configs ["split_norm_comm" ] = True
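
For reference, below is a minimal sketch of how a hybrid_configs dict shaped like the one built above is handed off to Paddle's fleet API. The degrees are illustrative placeholders, not values from this commit, and whether the ep_degree / moe_sharding_degree keys are accepted depends on the Paddle build in use.

import paddle.distributed.fleet as fleet

# Illustrative degrees for an 8-GPU job; dp * mp * pp * sharding == 8.
expert_parallel_degree = 2
moe_sharding_parallel_degree = 1

hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 2,
    "pp_degree": 2,
    "sharding_degree": 2,
}
# Mirror the commit: attach the expert-parallel keys only when expert
# parallelism is actually enabled, so non-MoE configs stay unchanged.
if expert_parallel_degree > 1:
    hybrid_configs["ep_degree"] = expert_parallel_degree
    hybrid_configs["moe_sharding_degree"] = moe_sharding_parallel_degree

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = hybrid_configs
fleet.init(is_collective=True, strategy=strategy)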