@@ -944,8 +944,8 @@ class MoERunner(TunableRunner):
             DynamicTensorSpec(
                 (0, 1, 2, 3, 4, 5),
                 (0, 0, 0, 0, 0, 0),
-                get_last_power_of_2_num_tokens_buckets(8192),
-                lambda x: min(last_positive_power_of_2(x), 8192),
+                get_last_power_of_2_num_tokens_buckets(1024, 8),
+                lambda x: min(last_positive_power_of_2(x), 1024),
                 dynamic_tensor_initializers,
             ),
         )
@@ -955,8 +955,8 @@ class MoERunner(TunableRunner):
             DynamicTensorSpec(
                 (0, 1, 2, 3, 4),
                 (0, 0, 0, 0, 0),
-                get_last_power_of_2_num_tokens_buckets(8192),
-                lambda x: min(last_positive_power_of_2(x), 8192),
+                get_last_power_of_2_num_tokens_buckets(1024, 8),
+                lambda x: min(last_positive_power_of_2(x), 1024),
                 dynamic_tensor_initializers[:5],
             ),
         ),
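For context on the retuned range, the sketch below shows the power-of-2 bucketing these specs drive. The helper bodies are assumptions written for illustration, not the repository's implementations of `get_last_power_of_2_num_tokens_buckets` and `last_positive_power_of_2`; in particular, the two-argument form is assumed to mean `(max_num_tokens, min_num_tokens)`.

```python
# Illustrative sketch only: assumed behavior of the bucketing helpers
# referenced in the DynamicTensorSpec entries above.


def last_positive_power_of_2(x: int) -> int:
    """Largest power of two that is <= x (assumes x >= 1)."""
    return 1 << (x.bit_length() - 1)


def get_last_power_of_2_num_tokens_buckets(
    max_num_tokens: int, min_num_tokens: int = 1
) -> tuple:
    """Assumed semantics: descending power-of-2 buckets from max down to min."""
    buckets = []
    m = last_positive_power_of_2(max_num_tokens)
    while m >= min_num_tokens:
        buckets.append(m)
        m //= 2
    return tuple(buckets)


# Old specs swept 14 buckets: 8192, 4096, ..., 2, 1.
print(get_last_power_of_2_num_tokens_buckets(8192))
# New specs sweep 8 buckets: 1024, 512, ..., 16, 8.
print(get_last_power_of_2_num_tokens_buckets(1024, 8))
```

At run time the companion `lambda x: min(last_positive_power_of_2(x), 1024)` maps an observed token count onto this grid, so batches larger than 1024 tokens reuse the 1024-token tactic rather than extending the tuning sweep.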
@@ -975,7 +975,6 @@ def __init__(
         hidden_size: int,
         intermediate_size: int,
         tile_tokens_dim: Optional[int] = None,
-        tune_max_num_tokens: int = 8192,
     ):
         self.num_experts = num_experts
         self.top_k = top_k
@@ -1360,7 +1359,7 @@ def trtllm_fp4_block_scale_moe_op(
     tile_tokens_dim: int,
     routing_method_type: int,
     do_finalize: bool,
-    tune_max_num_tokens: int = 8192,
+    tune_max_num_tokens: int = 1024,
     output: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:
     if routing_logits is None:
@@ -1408,7 +1407,6 @@ def trtllm_fp4_block_scale_moe_op(
         hidden_size=hidden_size,
         intermediate_size=intermediate_size,
         tile_tokens_dim=tile_tokens_dim,
-        tune_max_num_tokens=tune_max_num_tokens,
     )
     tunning_config = (
         MoERunner.tuning_config_no_hidden_states_scales
@@ -1701,7 +1699,7 @@ def trtllm_fp4_block_scale_moe(
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     do_finalize: bool = True,
-    tune_max_num_tokens: int = 8192,
+    tune_max_num_tokens: int = 1024,
     output: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:
     """FP4 block scale MoE operation.
@@ -1745,6 +1743,7 @@ def trtllm_fp4_block_scale_moe(
             - 3: Llama4 (Top1 -> Sigmoid)
             - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
         do_finalize (bool): Whether to finalize the output (default: False)
+        tune_max_num_tokens (int): Maximum number of tokens for tuning (default: 1024)
         output (Optional[torch.Tensor]): shape [seq_len, hidden_size]
             Optional inplace output tensor.
@@ -1815,7 +1814,7 @@ def trtllm_fp4_block_scale_routed_moe(
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     do_finalize: bool = True,
-    tune_max_num_tokens: int = 8192,
+    tune_max_num_tokens: int = 1024,
     output: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:
     """FP4 block scale MoE operation.
@@ -1861,6 +1860,7 @@ def trtllm_fp4_block_scale_routed_moe(
             - 3: Llama4 (Top1 -> Sigmoid)
             - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
         do_finalize (bool): Whether to finalize the output (default: False)
+        tune_max_num_tokens (int): Maximum number of tokens for tuning (default: 1024)
         output (Optional[torch.Tensor]): shape [seq_len, hidden_size]
             Optional inplace output tensor.
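As a small self-contained check of the clamping behavior behind the new default (same assumption about `last_positive_power_of_2` as in the sketch above):

```python
# Hypothetical demonstration of the run-time bucket mapping implied by
# tune_max_num_tokens=1024; last_pow2 mirrors last_positive_power_of_2.
last_pow2 = lambda x: 1 << (x.bit_length() - 1)
clamp = lambda x: min(last_pow2(x), 1024)

assert clamp(6000) == 1024  # batches above 1024 all map to the top bucket
assert clamp(300) == 256    # otherwise round down to a power of two
assert clamp(8) == 8
```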