Commit 1c14682

fix hang issue

Signed-off-by: jiahanc <[email protected]>
Parent: 639df1e

File tree: 4 files changed, +33 -60 lines changed

benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
Lines changed: 4 additions & 25 deletions

```diff
@@ -7,34 +7,11 @@
     GatedActType,
     fp4_quantize,
     mxfp8_quantize,
-    next_positive_power_of_2,
 )
 from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
 from flashinfer.autotuner import autotune
 from flashinfer.testing.utils import bench_gpu_time
-from flashinfer.utils import device_support_pdl
-
-
-def get_tile_tokens_dim(num_tokens, num_experts, top_k):
-    # Factor to account for the imbalance of the experts.
-    # factor equals to the
-    # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
-    # - 1.0 means perfect expert distribution.
-    # - > 1.0 means some experts have more
-    #   tokens than the perfect distribution.
-    # - < 1.0 does not make sense.
-    imbalance_factor = 1.3
-    # Calculate the number of tokens per expert
-    # assuming perfect distribution.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # Apply the imbalance factor.
-    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile
-    # as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
+from flashinfer.utils import device_support_pdl, calculate_tile_tokens_dim
 
 
 def bench_trtllm_gen_fused_moe_autotuner(
@@ -122,7 +99,9 @@ def bench_trtllm_gen_fused_moe_autotuner(
     bias13 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10
     bias2 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10
 
-    tile_tokens_dim = get_tile_tokens_dim(num_tokens, num_experts, top_k)
+    tile_tokens_dim = calculate_tile_tokens_dim(
+        num_tokens, num_experts, top_k, 64 if quant_mode == "MxFP4xBf16" else 128
+    )
     output1_scale_scalar = torch.tensor(
         [hidden_states_global_scale * w13_global_scale] * num_experts, device=device
     )
```
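For a sense of the values this produces at the benchmark's call site, here is a small spot-check (not part of the commit; it assumes flashinfer is importable and that `calculate_tile_tokens_dim` behaves as shown in the flashinfer/utils.py hunk below, with the arithmetic worked out in the comments):

```python
# Hypothetical spot-check of the new call site above.
from flashinfer.utils import calculate_tile_tokens_dim

num_tokens, num_experts, top_k = 1024, 8, 2
# (1024 * 2) // 8 = 256 perfect tokens/expert -> 332 after the 1.3x
# imbalance factor -> padded to 512 -> clamped to the cap, so for a batch
# this large the cap argument fully determines the result.
print(calculate_tile_tokens_dim(num_tokens, num_experts, top_k, 64))   # -> 64
print(calculate_tile_tokens_dim(num_tokens, num_experts, top_k, 128))  # -> 128
```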

flashinfer/fused_moe/core.py
Lines changed: 13 additions & 28 deletions

```diff
@@ -44,13 +44,13 @@
     device_support_pdl,
     get_shuffle_matrix_a_row_indices,
     get_shuffle_matrix_sf_a_row_indices,
+    calculate_tile_tokens_dim,
     register_custom_op,
     register_fake_op,
 )
 from .utils import (
     get_last_power_of_2_num_tokens_buckets,
     last_positive_power_of_2,
-    next_positive_power_of_2,
 )
 
 
@@ -894,30 +894,6 @@ def __init__(
         self.gated_act_type = gated_act_type
         self.tile_tokens_dim = tile_tokens_dim
 
-    def get_tile_tokens_dim(
-        self, num_tokens: int, top_k: int, max_tile_tokens_dim: int = 128
-    ):
-        # Factor to account for the imbalance of the experts.
-        # factor equals to the
-        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
-        # - 1.0 means perfect expert distribution.
-        # - > 1.0 means some experts have more
-        #   tokens than the perfect distribution.
-        # - < 1.0 does not make sense.
-        imbalance_factor = 1.3
-        # Calculate the number of tokens per expert
-        # assuming perfect distribution.
-        num_tokens_per_expert = (num_tokens * top_k) // self.num_local_experts
-        # Apply the imbalance factor.
-        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-        # And pad the number to the next power of 2.
-        tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-        if num_tokens_per_expert > 128 and num_tokens_per_expert < 256:
-            tile_tokens_dim = 192
-        # Cap to 8-max_tile_tokens_dim tokens per CTA tile as it's the range supported by the kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim)
-        return tile_tokens_dim
-
     def get_valid_tactics(
         self,
         inputs: List[torch.Tensor],
@@ -933,7 +909,12 @@ def get_valid_tactics(
         ) = inputs
         num_tokens = routing_logits.shape[0]
         tile_tokens_dim = (
-            self.get_tile_tokens_dim(num_tokens, self.top_k, 128)
+            calculate_tile_tokens_dim(
+                num_tokens,
+                self.num_local_experts,
+                self.top_k,
+                64 if self.dtype_act == DtypeTrtllmGen.Bfloat16 else 128,
+            )
             if self.tile_tokens_dim is None
             else self.tile_tokens_dim
         )
@@ -977,7 +958,12 @@ def forward(
         ) = inputs
         num_tokens = routing_logits.shape[0]
         tile_tokens_dim = (
-            self.get_tile_tokens_dim(num_tokens, self.top_k, 128)
+            calculate_tile_tokens_dim(
+                num_tokens,
+                self.num_local_experts,
+                self.top_k,
+                64 if self.dtype_act == DtypeTrtllmGen.Bfloat16 else 128,
+            )
             if self.tile_tokens_dim is None
             else self.tile_tokens_dim
         )
@@ -1005,7 +991,6 @@ def forward(
             hidden_states_scale.dim() == 2
             and hidden_states_scale.shape[0] == num_tokens
         ), "hidden_states_scale's first dimension must be batch size"
-
         # TODO(siyuan): support fp8
         moe_op.trtllm_fp4_block_scale_moe(
             routing_logits,
```
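Both call sites (`get_valid_tactics` and `forward`) now derive the cap from the activation dtype in the same way, so the tile size the autotuner profiles matches the one used at launch. A minimal sketch of that selection, using a stub enum in place of flashinfer's `DtypeTrtllmGen` (the non-bf16 member is purely illustrative):

```python
from enum import Enum, auto

class DtypeTrtllmGen(Enum):
    # Stub standing in for flashinfer's enum; only Bfloat16 matters here.
    Bfloat16 = auto()
    OtherDtype = auto()  # hypothetical placeholder for any non-bf16 activation

def max_tile_tokens_dim_for(dtype_act: DtypeTrtllmGen) -> int:
    # Bf16 activations cap the CTA tile at 64 tokens; other dtypes keep 128.
    return 64 if dtype_act == DtypeTrtllmGen.Bfloat16 else 128

assert max_tile_tokens_dim_for(DtypeTrtllmGen.Bfloat16) == 64
assert max_tile_tokens_dim_for(DtypeTrtllmGen.OtherDtype) == 128
```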

flashinfer/utils.py
Lines changed: 15 additions & 6 deletions

```diff
@@ -116,14 +116,23 @@ def next_positive_power_of_2(x: int) -> int:
 def calculate_tile_tokens_dim(
     num_tokens: int, num_experts: int, top_k: int, max_tile_tokens_dim: int = 128
 ) -> int:
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = num_tokens * top_k // num_experts
-
+    # Factor to account for the imbalance of the experts.
+    # factor equals to the
+    # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
+    # - 1.0 means perfect expert distribution.
+    # - > 1.0 means some experts have more
+    #   tokens than the perfect distribution.
+    # - < 1.0 does not make sense.
+    imbalance_factor = 1.3
+    # Calculate the number of tokens per expert
+    # assuming perfect distribution.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # Apply the imbalance factor.
+    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
     # And pad the number to the next power of 2.
     tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-    if num_tokens_per_expert > 128 and num_tokens_per_expert < 256:
-        tile_tokens_dim = 192
-    # Cap to 8-max_tile_tokens_dim tokens per CTA tile as it's the range supported by the kernel.
+    # Cap to 8-max_tile_tokens_dim tokens per CTA tile
+    # as it's the range supported by the kernel.
     tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim)
     return tile_tokens_dim
 
```
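To check the helper's behavior by hand, here is a self-contained sketch; `next_positive_power_of_2` is re-derived from its name and usage (assumed to round up to the nearest power of two), so treat it as an approximation of the real flashinfer implementation rather than a copy:

```python
def next_positive_power_of_2(x: int) -> int:
    # Assumed semantics: smallest power of two >= x, with 1 for x <= 0.
    return 1 if x < 1 else 1 << (x - 1).bit_length()

def calculate_tile_tokens_dim(num_tokens, num_experts, top_k, max_tile_tokens_dim=128):
    # Mirrors the updated flashinfer/utils.py hunk above.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts  # perfect split
    num_tokens_per_expert = int(num_tokens_per_expert * 1.3)     # imbalance factor
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    return min(max(tile_tokens_dim, 8), max_tile_tokens_dim)

# 300 tokens, 4 experts, top_k=2: 150 tokens/expert -> 195 after the
# imbalance factor -> padded to 256 -> capped to 128. The deleted branch
# instead mapped any count in (128, 256) to a fixed 192-token tile, the only
# non-power-of-2 value the helper could return when the cap exceeded 192.
assert calculate_tile_tokens_dim(300, 4, 2) == 128
# Tiny batches bottom out at the 8-token floor.
assert calculate_tile_tokens_dim(1, 8, 2) == 8
```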

tests/moe/test_trtllm_gen_fused_moe.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -1838,7 +1838,7 @@ def cache_permute_indices():
 
 @pytest.mark.parametrize("num_tokens", [1, 8, 1024])
 @pytest.mark.parametrize("hidden_size", [1024, 8192])
-@pytest.mark.parametrize("intermediate_size", [384, 768, 1024, 2048])
+@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 384])
 @pytest.mark.parametrize(
     "moe_impl",
     [
```
