@@ -1082,6 +1082,8 @@ def forward(
         local_expert_offset: int,
         routed_scaling_factor: Optional[float],
         routing_method_type: int,
+        enable_pdl: bool,
+        do_finalize: bool,
         tactic: int = -1,
         do_preparation: bool = False,
     ):
@@ -1154,7 +1156,8 @@ def forward(
             routed_scaling_factor,
             tile_tokens_dim,
             routing_method_type,
-            True,  # do_finalize
+            enable_pdl,
+            do_finalize,
             output,
             tactic,
         )
@@ -1476,6 +1479,8 @@ def trtllm_fp4_block_scale_moe_op(
         local_expert_offset=local_expert_offset,
         routed_scaling_factor=routed_scaling_factor,
         routing_method_type=routing_method_type,
+        enable_pdl=enable_pdl,
+        do_finalize=do_finalize,
     )
 
     # Call the C++ function for block scale MoE
@@ -1547,8 +1552,9 @@ def _fake_trtllm_fp4_block_scale_moe(
     tile_tokens_dim: int,
     routing_method_type: int,
     do_finalize: bool,
-    enable_pdl: Optional[bool] = None,
-    output: Optional[torch.Tensor] = None,
+    enable_pdl: bool,
+    tune_max_num_tokens: int,
+    output: Optional[torch.Tensor],
 ):
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -1789,6 +1795,7 @@ def trtllm_fp4_block_scale_moe(
         List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output.
             Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing.
     """
+    print(f"in trtllm_fp4_block_scale_moe, tune_max_num_tokens={tune_max_num_tokens}")
     return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe(
         routing_logits,
         None,
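
For readers following the API change: below is a minimal call-site sketch of how the two new flags behave, based only on the docstring in the last hunk. The import path and the elided tensor arguments are assumptions, not part of this diff.

# Hypothetical call-site sketch. The import path and the elided
# arguments ("...") are assumptions; only enable_pdl, do_finalize,
# and the return behavior are taken from this diff's docstring.
from flashinfer.fused_moe import trtllm_fp4_block_scale_moe  # assumed path

outputs = trtllm_fp4_block_scale_moe(
    routing_logits,    # plus the remaining quantized inputs,
    ...,               # expert weights, scales, and routing params (elided)
    enable_pdl=True,   # new: forwarded to the C++ kernel launch
    do_finalize=False, # new: request intermediate results instead
)

if do_finalize:
    # Per the docstring, the list holds the final MoE output.
    final_output = outputs[0]
else:
    # Intermediate results that the caller must finalize itself.
    gemm2_output, expert_weights, expanded_idx_to_permuted_idx = outputs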