
Commit a739c9c

jeejeelee authored and jingyu committed
[Misc] benchmark_moe supports expert parallel (vllm-project#22251)
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: jingyu <[email protected]>
1 parent f39375d commit a739c9c

File tree

1 file changed (+11 −9 lines)


benchmarks/kernels/benchmark_moe.py

Lines changed: 11 additions & 9 deletions
@@ -22,10 +22,10 @@
 FP8_DTYPE = current_platform.fp8_dtype()


-def ensure_divisibility(numerator, denominator):
+def ensure_divisibility(numerator, denominator, text):
     """Ensure that numerator is divisible by the denominator."""
-    assert numerator % denominator == 0, (
-        "intermediate_size {} is not divisible by tp {}.".format(numerator, denominator)
+    assert numerator % denominator == 0, "{} {} is not divisible by tp {}.".format(
+        text, numerator, denominator
     )

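
The reworked helper takes a caller-supplied label, so the assertion message names whichever quantity failed the divisibility check. A minimal standalone sketch of that behavior; the numeric values below are illustrative, not from the commit:

# Illustrative only: the updated helper from the hunk above, with made-up values.
def ensure_divisibility(numerator, denominator, text):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} {} is not divisible by tp {}.".format(
        text, numerator, denominator
    )

ensure_divisibility(64, 4, "Number of experts")      # passes silently
ensure_divisibility(14336, 4, "intermediate_size")   # passes silently
# ensure_divisibility(64, 6, "Number of experts")
# -> AssertionError: Number of experts 64 is not divisible by tp 6.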

@@ -577,12 +577,10 @@ def main(args: argparse.Namespace):
         E = config.ffn_config.moe_num_experts
         topk = config.ffn_config.moe_top_k
         intermediate_size = config.ffn_config.ffn_hidden_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] == "JambaForCausalLM":
         E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in (
         "DeepseekV3ForCausalLM",
         "DeepseekV2ForCausalLM",

@@ -591,26 +589,29 @@ def main(args: argparse.Namespace):
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
         E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
         E = config.num_experts
         topk = config.moe_topk[0]
         intermediate_size = config.moe_intermediate_size[0]
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
         # Support for llama4
         config = config.get_text_config()
         # Default: Mixtral.
         E = config.num_local_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
+    enable_ep = bool(args.enable_expert_parallel)
+    if enable_ep:
+        ensure_divisibility(E, args.tp_size, "Number of experts")
+        E = E // args.tp_size
+        shard_intermediate_size = 2 * intermediate_size
+    else:
+        ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
-    ensure_divisibility(intermediate_size, args.tp_size)
     hidden_size = config.hidden_size
     dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
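
This hunk is the core of the change: with expert parallelism the experts are divided across the tensor-parallel ranks and the FFN width stays whole, whereas plain tensor parallelism keeps all experts and splits the intermediate dimension. A standalone sketch of that arithmetic, using assumed Mixtral-like values (E=8, intermediate_size=14336) purely for illustration:

# Illustrative sketch of the expert-parallel vs. tensor-parallel sharding above.
def shard_moe(E, intermediate_size, tp_size, enable_ep):
    if enable_ep:
        # Expert parallel: split the experts across ranks, keep the full FFN width.
        assert E % tp_size == 0
        return E // tp_size, 2 * intermediate_size
    # Tensor parallel: keep all experts, split the FFN width across ranks.
    assert intermediate_size % tp_size == 0
    return E, 2 * intermediate_size // tp_size

print(shard_moe(8, 14336, 2, enable_ep=True))   # (4, 28672)
print(shard_moe(8, 14336, 2, enable_ep=False))  # (8, 14336)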

@@ -742,6 +743,7 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:
     parser.add_argument(
         "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
     )
+    parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
     parser.add_argument(
         "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
     )
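
For reference, a minimal argparse sketch (illustrative only, not the benchmark's full parser) showing how the new flag surfaces as args.enable_expert_parallel:

import argparse

# Illustrative subset of the benchmark's parser, showing the new flag.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
)
parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")

args = parser.parse_args(["--tp-size", "4", "--enable-expert-parallel"])
print(args.tp_size, args.enable_expert_parallel)  # 4 True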

0 commit comments
