Skip to content

Commit 5e5bea1

Browse files
jeejeeleenpanpaliya
authored andcommitted
[Misc] Minor enhancement of benchmark_moe (vllm-project#22068)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent 5181482 commit 5e5bea1

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

benchmarks/kernels/benchmark_moe.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@
2222
FP8_DTYPE = current_platform.fp8_dtype()
2323

2424

25+
def ensure_divisibility(numerator, denominator):
26+
"""Ensure that numerator is divisible by the denominator."""
27+
assert numerator % denominator == 0, (
28+
"intermediate_size {} is not divisible by tp {}.".format(numerator, denominator)
29+
)
30+
31+
2532
class BenchmarkConfig(TypedDict):
2633
BLOCK_SIZE_M: int
2734
BLOCK_SIZE_N: int
@@ -603,7 +610,7 @@ def main(args: argparse.Namespace):
603610
topk = config.num_experts_per_tok
604611
intermediate_size = config.intermediate_size
605612
shard_intermediate_size = 2 * intermediate_size // args.tp_size
606-
613+
ensure_divisibility(intermediate_size, args.tp_size)
607614
hidden_size = config.hidden_size
608615
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
609616
use_fp8_w8a8 = args.dtype == "fp8_w8a8"

0 commit comments

Comments
 (0)