Skip to content

Commit 35a55d9

Browse files
committed
minor
1 parent e533cac commit 35a55d9

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

flashinfer/fused_moe/core.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -631,6 +631,7 @@ def _fake_cutlass_fused_moe_sm100(
631631
use_mxfp8_act_scaling: bool = False,
632632
min_latency_mode: bool = False,
633633
tune_max_num_tokens: int = 8192,
634+
enable_pdl: Optional[bool] = None,
634635
):
635636
seq_len = input.shape[0]
636637
hidden_size = fc2_expert_weights.shape[1]
@@ -1081,6 +1082,8 @@ def forward(
10811082
local_expert_offset: int,
10821083
routed_scaling_factor: Optional[float],
10831084
routing_method_type: int,
1085+
enable_pdl: bool,
1086+
do_finalize: bool,
10841087
tactic: int = -1,
10851088
do_preparation: bool = False,
10861089
):
@@ -1153,7 +1156,8 @@ def forward(
11531156
routed_scaling_factor,
11541157
tile_tokens_dim,
11551158
routing_method_type,
1156-
True, # do_finalize
1159+
enable_pdl,
1160+
do_finalize,
11571161
output,
11581162
tactic,
11591163
)
@@ -1473,6 +1477,8 @@ def trtllm_fp4_block_scale_moe_op(
14731477
local_expert_offset=local_expert_offset,
14741478
routed_scaling_factor=routed_scaling_factor,
14751479
routing_method_type=routing_method_type,
1480+
enable_pdl=enable_pdl,
1481+
do_finalize=do_finalize,
14761482
)
14771483

14781484
# Call the C++ function for block scale MoE
@@ -1544,6 +1550,9 @@ def _fake_trtllm_fp4_block_scale_moe(
15441550
tile_tokens_dim: int,
15451551
routing_method_type: int,
15461552
do_finalize: bool,
1553+
enable_pdl: bool,
1554+
tune_max_num_tokens: int,
1555+
output: Optional[torch.Tensor],
15471556
):
15481557
seq_len = hidden_states.shape[0]
15491558
hidden_size = hidden_states.shape[1]
@@ -1730,6 +1739,7 @@ def trtllm_fp4_block_scale_moe(
17301739
tile_tokens_dim: int = 8,
17311740
routing_method_type: int = 0,
17321741
do_finalize: bool = True,
1742+
enable_pdl: Optional[bool] = None,
17331743
tune_max_num_tokens: int = 1024,
17341744
output: Optional[torch.Tensor] = None,
17351745
) -> List[torch.Tensor]:
@@ -1782,6 +1792,7 @@ def trtllm_fp4_block_scale_moe(
17821792
List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output.
17831793
Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing.
17841794
"""
1795+
print(f"in trtllm_fp4_block_scale_moe, tune_max_num_tokens={tune_max_num_tokens}")
17851796
return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe(
17861797
routing_logits,
17871798
None,
@@ -1812,6 +1823,7 @@ def trtllm_fp4_block_scale_moe(
18121823
tile_tokens_dim,
18131824
routing_method_type,
18141825
do_finalize,
1826+
enable_pdl,
18151827
tune_max_num_tokens,
18161828
output,
18171829
)

0 commit comments

Comments
 (0)