Skip to content

Commit e4fb808

Browse files
committed
upd
1 parent c4270bb commit e4fb808

File tree

3 files changed

+16
-14
lines changed

3 files changed

+16
-14
lines changed

flashinfer/fused_moe/core.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -944,8 +944,8 @@ class MoERunner(TunableRunner):
944944
DynamicTensorSpec(
945945
(0, 1, 2, 3, 4, 5),
946946
(0, 0, 0, 0, 0, 0),
947-
get_last_power_of_2_num_tokens_buckets(8192),
948-
lambda x: min(last_positive_power_of_2(x), 8192),
947+
get_last_power_of_2_num_tokens_buckets(1024, 8),
948+
lambda x: min(last_positive_power_of_2(x), 1024),
949949
dynamic_tensor_initializers,
950950
),
951951
)
@@ -955,8 +955,8 @@ class MoERunner(TunableRunner):
955955
DynamicTensorSpec(
956956
(0, 1, 2, 3, 4),
957957
(0, 0, 0, 0, 0),
958-
get_last_power_of_2_num_tokens_buckets(8192),
959-
lambda x: min(last_positive_power_of_2(x), 8192),
958+
get_last_power_of_2_num_tokens_buckets(1024, 8),
959+
lambda x: min(last_positive_power_of_2(x), 1024),
960960
dynamic_tensor_initializers[:5],
961961
),
962962
),
@@ -975,7 +975,6 @@ def __init__(
975975
hidden_size: int,
976976
intermediate_size: int,
977977
tile_tokens_dim: Optional[int] = None,
978-
tune_max_num_tokens: int = 8192,
979978
):
980979
self.num_experts = num_experts
981980
self.top_k = top_k
@@ -1360,7 +1359,7 @@ def trtllm_fp4_block_scale_moe_op(
13601359
tile_tokens_dim: int,
13611360
routing_method_type: int,
13621361
do_finalize: bool,
1363-
tune_max_num_tokens: int = 8192,
1362+
tune_max_num_tokens: int = 1024,
13641363
output: Optional[torch.Tensor] = None,
13651364
) -> List[torch.Tensor]:
13661365
if routing_logits is None:
@@ -1408,7 +1407,6 @@ def trtllm_fp4_block_scale_moe_op(
14081407
hidden_size=hidden_size,
14091408
intermediate_size=intermediate_size,
14101409
tile_tokens_dim=tile_tokens_dim,
1411-
tune_max_num_tokens=tune_max_num_tokens,
14121410
)
14131411
tunning_config = (
14141412
MoERunner.tuning_config_no_hidden_states_scales
@@ -1701,7 +1699,7 @@ def trtllm_fp4_block_scale_moe(
17011699
tile_tokens_dim: int = 8,
17021700
routing_method_type: int = 0,
17031701
do_finalize: bool = True,
1704-
tune_max_num_tokens: int = 8192,
1702+
tune_max_num_tokens: int = 1024,
17051703
output: Optional[torch.Tensor] = None,
17061704
) -> List[torch.Tensor]:
17071705
"""FP4 block scale MoE operation.
@@ -1745,6 +1743,7 @@ def trtllm_fp4_block_scale_moe(
17451743
- 3: Llama4 (Top1 -> Sigmoid)
17461744
- 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
17471745
do_finalize (bool): Whether to finalize the output (default: True)
1746+
tune_max_num_tokens (int): Maximum number of tokens for tuning (default: 1024)
17481747
output (Optional[torch.Tensor]): shape [seq_len, hidden_size]
17491748
Optional inplace output tensor.
17501749
@@ -1815,7 +1814,7 @@ def trtllm_fp4_block_scale_routed_moe(
18151814
tile_tokens_dim: int = 8,
18161815
routing_method_type: int = 0,
18171816
do_finalize: bool = True,
1818-
tune_max_num_tokens: int = 8192,
1817+
tune_max_num_tokens: int = 1024,
18191818
output: Optional[torch.Tensor] = None,
18201819
) -> List[torch.Tensor]:
18211820
"""FP4 block scale MoE operation.
@@ -1861,6 +1860,7 @@ def trtllm_fp4_block_scale_routed_moe(
18611860
- 3: Llama4 (Top1 -> Sigmoid)
18621861
- 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
18631862
do_finalize (bool): Whether to finalize the output (default: True)
1863+
tune_max_num_tokens (int): Maximum number of tokens for tuning (default: 1024)
18641864
output (Optional[torch.Tensor]): shape [seq_len, hidden_size]
18651865
Optional inplace output tensor.
18661866

flashinfer/fused_moe/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,13 @@ def get_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]:
203203
return tuple(num_token_buckets)
204204

205205

206-
def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]:
206+
def get_last_power_of_2_num_tokens_buckets(
207+
max_num_tokens, min_num_tokens=1
208+
) -> Tuple[int]:
207209
max_num_tokens = last_positive_power_of_2(max_num_tokens)
208210
num_token_buckets = []
209211
m = max_num_tokens
210-
while m >= 1:
212+
while m >= min_num_tokens:
211213
num_token_buckets.append(m)
212214
m //= 2
213215
return tuple(num_token_buckets)

tests/test_trtllm_gen_fused_moe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1733,9 +1733,9 @@ def cache_permute_indices():
17331733
@pytest.mark.parametrize(
17341734
"moe_impl",
17351735
[
1736-
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4 x NvFP4"),
1737-
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4 x MxFP8"),
1738-
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4 x Bf16"),
1736+
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"),
1737+
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"),
1738+
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"),
17391739
pytest.param(FP8BlockScaleMoe(), id="FP8_Block"),
17401740
pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"),
17411741
],

0 commit comments

Comments
 (0)