@@ -22,10 +22,10 @@
 FP8_DTYPE = current_platform.fp8_dtype()
 
 
-def ensure_divisibility(numerator, denominator):
+def ensure_divisibility(numerator, denominator, text):
     """Ensure that numerator is divisible by the denominator."""
-    assert numerator % denominator == 0, (
-        "intermediate_size {} is not divisible by tp {}.".format(numerator, denominator)
+    assert numerator % denominator == 0, "{} {} is not divisible by tp {}.".format(
+        text, numerator, denominator
     )
 
 
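For context on why the helper gains a text parameter: later in this diff it is called for two different quantities, and the message now names whichever one failed. A quick illustrative sketch with made-up numbers (not from the diff):

# Hypothetical values, only to show the new error message format.
ensure_divisibility(14336, 2, "intermediate_size")  # passes silently: 14336 % 2 == 0
ensure_divisibility(61, 8, "Number of experts")
# AssertionError: Number of experts 61 is not divisible by tp 8.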
@@ -577,12 +577,10 @@ def main(args: argparse.Namespace):
         E = config.ffn_config.moe_num_experts
         topk = config.ffn_config.moe_top_k
         intermediate_size = config.ffn_config.ffn_hidden_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] == "JambaForCausalLM":
         E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in (
         "DeepseekV3ForCausalLM",
         "DeepseekV2ForCausalLM",
@@ -591,26 +589,29 @@ def main(args: argparse.Namespace):
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
         E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
         E = config.num_experts
         topk = config.moe_topk[0]
         intermediate_size = config.moe_intermediate_size[0]
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
         # Support for llama4
         config = config.get_text_config()
         # Default: Mixtral.
         E = config.num_local_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
+    enable_ep = bool(args.enable_expert_parallel)
+    if enable_ep:
+        ensure_divisibility(E, args.tp_size, "Number of experts")
+        E = E // args.tp_size
+        shard_intermediate_size = 2 * intermediate_size
+    else:
+        ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
-    ensure_divisibility(intermediate_size, args.tp_size)
     hidden_size = config.hidden_size
     dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
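For orientation, a minimal sketch of what the two new branches compute, using hypothetical Mixtral-8x7B-style sizes that are not taken from this diff: expert parallelism divides the experts across ranks and keeps each expert's full width, while tensor parallelism keeps all experts and shards each expert's width.

E, intermediate_size, tp_size = 8, 14336, 2          # hypothetical example values
# Expert parallel: split experts across ranks, keep the full gate+up width per expert.
E_ep = E // tp_size                                   # 4 experts per rank
shard_ep = 2 * intermediate_size                      # 28672, unsharded width
# Tensor parallel: keep all experts, shard each expert's gate+up width.
E_tp = E                                              # 8 experts per rank
shard_tp = 2 * intermediate_size // tp_size           # 14336, sharded width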
@@ -742,6 +743,7 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:
     parser.add_argument(
         "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
     )
+    parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
     parser.add_argument(
         "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
     )
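Assuming this diff targets vLLM's benchmarks/kernels/benchmark_moe.py and that the script's existing --model argument (not shown in these hunks) selects the checkpoint, a tuning run with the new flag might be launched like this:

# Sketch of an invocation; script path and --model are assumptions, not part of this diff.
python benchmarks/kernels/benchmark_moe.py \
    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
    -tp 2 --enable-expert-parallel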