@@ -631,6 +631,7 @@ def _fake_cutlass_fused_moe_sm100(
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     tune_max_num_tokens: int = 8192,
+    enable_pdl: Optional[bool] = None,
 ):
     seq_len = input.shape[0]
     hidden_size = fc2_expert_weights.shape[1]
@@ -1081,6 +1082,8 @@ def forward(
         local_expert_offset: int,
         routed_scaling_factor: Optional[float],
         routing_method_type: int,
+        enable_pdl: bool,
+        do_finalize: bool,
         tactic: int = -1,
         do_preparation: bool = False,
     ):
@@ -1153,7 +1156,8 @@ def forward(
             routed_scaling_factor,
             tile_tokens_dim,
             routing_method_type,
-            True,  # do_finalize
+            enable_pdl,
+            do_finalize,
             output,
             tactic,
         )
@@ -1473,6 +1477,8 @@ def trtllm_fp4_block_scale_moe_op(
         local_expert_offset=local_expert_offset,
         routed_scaling_factor=routed_scaling_factor,
         routing_method_type=routing_method_type,
+        enable_pdl=enable_pdl,
+        do_finalize=do_finalize,
     )

     # Call the C++ function for block scale MoE
@@ -1544,6 +1550,9 @@ def _fake_trtllm_fp4_block_scale_moe(
     tile_tokens_dim: int,
     routing_method_type: int,
     do_finalize: bool,
+    enable_pdl: bool,
+    tune_max_num_tokens: int,
+    output: Optional[torch.Tensor],
 ):
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -1730,6 +1739,7 @@ def trtllm_fp4_block_scale_moe(
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     do_finalize: bool = True,
+    enable_pdl: Optional[bool] = None,
     tune_max_num_tokens: int = 1024,
     output: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:
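This hunk adds an `enable_pdl: Optional[bool] = None` keyword to the public `trtllm_fp4_block_scale_moe` signature, but the diff does not show how a `None` value is turned into a concrete flag before it reaches the kernel. Below is a minimal, purely illustrative sketch of one way such a default could be resolved; the helper name `_resolve_enable_pdl` and the SM90+ capability check are assumptions, not taken from this codebase.

# Hypothetical sketch, not part of this diff: resolving enable_pdl=None
# to an explicit bool. The SM90+ check is an assumption for illustration.
from typing import Optional

import torch


def _resolve_enable_pdl(enable_pdl: Optional[bool], device: torch.device) -> bool:
    # An explicit True/False from the caller wins; None means "auto-detect".
    if enable_pdl is not None:
        return enable_pdl
    # Assumption: programmatic dependent launch (PDL) is only worthwhile on
    # newer architectures, so infer it from the device compute capability.
    major, _ = torch.cuda.get_device_capability(device)
    return major >= 9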
@@ -1782,6 +1792,7 @@ def trtllm_fp4_block_scale_moe(
         List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output.
             Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing.
     """
+    print(f"in trtllm_fp4_block_scale_moe, tune_max_num_tokens={tune_max_num_tokens}")
     return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe(
         routing_logits,
         None,
@@ -1812,6 +1823,7 @@ def trtllm_fp4_block_scale_moe(
         tile_tokens_dim,
         routing_method_type,
         do_finalize,
+        enable_pdl,
         tune_max_num_tokens,
         output,
     )
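The docstring in the hunk at line 1782 states that with `do_finalize=True` the op returns the final MoE output, and with `do_finalize=False` it returns the intermediate tensors `(gemm2_output, expert_weights, expanded_idx_to_permuted_idx)`. The sketch below only illustrates unpacking under that documented contract; the helper `unpack_moe_outputs` is hypothetical and does not call the real op.

# Hedged sketch based solely on the docstring quoted above.
from typing import List, Tuple, Union

import torch


def unpack_moe_outputs(
    outputs: List[torch.Tensor], do_finalize: bool
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    if do_finalize:
        # Finalized path: a single tensor holding the final MoE output.
        (final_output,) = outputs
        return final_output
    # Unfinalized path: intermediate results that still need post-processing.
    gemm2_output, expert_weights, expanded_idx_to_permuted_idx = outputs
    return gemm2_output, expert_weights, expanded_idx_to_permuted_idx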