
Commit 8870384

Authored by IwakuraRein, gemini-code-assist[bot], and yyihuang
fix missing enable_pdl argument in trtllm-gen fp4 moe (#1480)
## πŸ“Œ Description

Fix the missing `enable_pdl` argument introduced in #1446.

## πŸš€ Pull Request Checklist

### βœ… Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## πŸ§ͺ Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Avery Yingyi Huang <[email protected]>
1 parent 1d29426 commit 8870384
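For context on why the `_fake_*` wrappers in the diff below also need the new argument: they appear to be fake (meta) implementations registered alongside the real custom ops, and a fake implementation is invoked with the same arguments as the real op during tracing, so a keyword missing from its signature raises a `TypeError` under `torch.compile`. The sketch below illustrates that pattern with `torch.library`; the `demo::moe_stub` op and all names in it are hypothetical, not FlashInfer's actual registration code.

```python
# Illustrative only: a toy custom op whose fake (meta) implementation must
# accept the same keyword arguments as the real op. Names are hypothetical.
from typing import Optional

import torch


@torch.library.custom_op("demo::moe_stub", mutates_args=())
def moe_stub(x: torch.Tensor, enable_pdl: Optional[bool] = None) -> torch.Tensor:
    # Real implementation (placeholder math); enable_pdl would steer the kernel launch.
    return x * 2


@moe_stub.register_fake
def _moe_stub_fake(x: torch.Tensor, enable_pdl: Optional[bool] = None) -> torch.Tensor:
    # Only output shape/dtype matter here, but the signature must stay in sync
    # with the real op: dropping enable_pdl here is the kind of mismatch this
    # commit fixes in the _fake_* MoE wrappers.
    return torch.empty_like(x)
```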

File tree: 2 files changed (+13, -2 lines)

- flashinfer/fused_moe/core.py
- tests/test_trtllm_gen_fused_moe.py

flashinfer/fused_moe/core.py

Lines changed: 10 additions & 2 deletions
@@ -593,6 +593,7 @@ def _fake_cutlass_fused_moe_sm100(
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     tune_max_num_tokens: int = 8192,
+    enable_pdl: Optional[bool] = None,
 ):
     seq_len = input.shape[0]
     hidden_size = fc2_expert_weights.shape[1]
@@ -947,6 +948,7 @@ def _fake_trtllm_fp8_per_tensor_scale_moe(
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
+    enable_pdl: Optional[bool] = None,
 ):
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -1031,6 +1033,7 @@ def _fake_trtllm_fp8_block_scale_moe(
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
     weight_layout: int = 0,
+    enable_pdl: Optional[bool] = None,
 ):
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -1174,6 +1177,8 @@ def _fake_trtllm_fp4_block_scale_moe(
     tile_tokens_dim: int,
     routing_method_type: int,
     do_finalize: bool,
+    enable_pdl: Optional[bool] = None,
+    output: Optional[torch.Tensor] = None,
 ):
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -1231,6 +1236,7 @@ def trtllm_fp8_per_tensor_scale_moe(
         use_routing_scales_on_input: Whether to use routing scales on input
         tile_tokens_dim: Tile dimension for tokens (default: 8)
         routing_method_type: Type of routing method to use (default: 0)
+        enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
 
     Returns:
         torch.Tensor: Output tensor of shape [seq_len, hidden_size]
@@ -1303,7 +1309,7 @@ def trtllm_fp8_block_scale_moe(
         routed_scaling_factor: Scaling factor for routing
         tile_tokens_dim: Tile dimension for tokens (default: 8)
         routing_method_type: Type of routing method to use (default: 0)
-
+        enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
     Returns:
         torch.Tensor: Output tensor of shape [seq_len, hidden_size]
     """
@@ -1360,6 +1366,7 @@ def trtllm_fp4_block_scale_moe(
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     do_finalize: bool = True,
+    enable_pdl: Optional[bool] = None,
     output: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:
     """FP4 block scale MoE operation.
@@ -1405,7 +1412,7 @@ def trtllm_fp4_block_scale_moe(
         do_finalize (bool): Whether to finalize the output (default: False)
         output (Optional[torch.Tensor]): shape [seq_len, hidden_size]
             Optional inplace output tensor.
-
+        enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
     Returns:
         List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output.
             Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing.
@@ -1440,6 +1447,7 @@ def trtllm_fp4_block_scale_moe(
         tile_tokens_dim,
         routing_method_type,
         do_finalize,
+        enable_pdl,
         output,
     )
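The docstring additions above state that `enable_pdl` defaults to `None` and is auto-enabled for >= sm90. Below is a minimal sketch of how such a default can be resolved; `resolve_enable_pdl` is a hypothetical helper for illustration, not the function FlashInfer actually uses.

```python
# Hypothetical helper mirroring the documented behaviour of enable_pdl=None:
# Programmatic Dependent Launch (PDL) is turned on automatically when the
# device's compute capability is sm90 (Hopper) or newer.
from typing import Optional

import torch


def resolve_enable_pdl(enable_pdl: Optional[bool], device: torch.device) -> bool:
    if enable_pdl is not None:
        # An explicit caller choice always wins.
        return enable_pdl
    major, _minor = torch.cuda.get_device_capability(device)
    return major >= 9  # sm90 and newer support PDL
```

Under this convention, passing `enable_pdl=True` (as the updated test below does) forces PDL on, while the default of `None` keeps the architecture-based behaviour.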

tests/test_trtllm_gen_fused_moe.py

Lines changed: 3 additions & 0 deletions
@@ -709,6 +709,7 @@ def call_moe(
         routed_scaling = kwargs["routed_scaling"]
         routing_method_type = kwargs["routing_method_type"]
         tile_tokens_dim = kwargs["tile_tokens_dim"]
+        enable_pdl = kwargs.get("enable_pdl")
 
         # Generate block scales and quantize hidden states at runtime
         hidden_states_fp8 = hidden_states_orig.to(torch.float8_e4m3fn)
@@ -738,6 +739,7 @@ def call_moe(
             routing_method_type,
             use_shuffled_weight=static_data["use_shuffled_weight"],
             weight_layout=static_data["weight_layout"],
+            enable_pdl=enable_pdl,
         )
 
         return output.to(torch.float)
@@ -2040,6 +2042,7 @@ def test_moe_quantization_classes(
         routing_method_type=routing_method_type,
         tile_tokens_dim=tile_tokens_dim,
         weight_processing=weight_processing,
+        enable_pdl=True,
     )
 
     # Compare outputs using moe_impl-specific tolerances
