File tree Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Original file line number Diff line number Diff line change 22
22
FP8_DTYPE = current_platform .fp8_dtype ()
23
23
24
24
25
+ def ensure_divisibility (numerator , denominator ):
26
+ """Ensure that numerator is divisible by the denominator."""
27
+ assert numerator % denominator == 0 , (
28
+ "intermediate_size {} is not divisible by tp {}." .format (numerator , denominator )
29
+ )
30
+
31
+
25
32
class BenchmarkConfig (TypedDict ):
26
33
BLOCK_SIZE_M : int
27
34
BLOCK_SIZE_N : int
@@ -603,7 +610,7 @@ def main(args: argparse.Namespace):
603
610
topk = config .num_experts_per_tok
604
611
intermediate_size = config .intermediate_size
605
612
shard_intermediate_size = 2 * intermediate_size // args .tp_size
606
-
613
+ ensure_divisibility ( intermediate_size , args . tp_size )
607
614
hidden_size = config .hidden_size
608
615
dtype = torch .float16 if current_platform .is_rocm () else config .torch_dtype
609
616
use_fp8_w8a8 = args .dtype == "fp8_w8a8"
You can’t perform that action at this time.
0 commit comments