diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py
index b9ece319e..94453fa5b 100644
--- a/flashinfer/fused_moe/core.py
+++ b/flashinfer/fused_moe/core.py
@@ -593,6 +593,7 @@ def _fake_cutlass_fused_moe_sm100(
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     tune_max_num_tokens: int = 8192,
+    enable_pdl: Optional[bool] = None,
 ):
     seq_len = input.shape[0]
     hidden_size = fc2_expert_weights.shape[1]
@@ -947,6 +948,7 @@ def _fake_trtllm_fp8_per_tensor_scale_moe(
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
+    enable_pdl: Optional[bool] = None,
 ):
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -1031,6 +1033,7 @@ def _fake_trtllm_fp8_block_scale_moe(
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
     weight_layout: int = 0,
+    enable_pdl: Optional[bool] = None,
 ):
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -1174,6 +1177,8 @@ def _fake_trtllm_fp4_block_scale_moe(
     tile_tokens_dim: int,
     routing_method_type: int,
     do_finalize: bool,
+    enable_pdl: Optional[bool] = None,
+    output: Optional[torch.Tensor] = None,
 ):
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -1231,6 +1236,7 @@ def trtllm_fp8_per_tensor_scale_moe(
         use_routing_scales_on_input: Whether to use routing scales on input
         tile_tokens_dim: Tile dimension for tokens (default: 8)
         routing_method_type: Type of routing method to use (default: 0)
+        enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
 
     Returns:
         torch.Tensor: Output tensor of shape [seq_len, hidden_size]
@@ -1303,7 +1309,7 @@ def trtllm_fp8_block_scale_moe(
         routed_scaling_factor: Scaling factor for routing
         tile_tokens_dim: Tile dimension for tokens (default: 8)
         routing_method_type: Type of routing method to use (default: 0)
-
+        enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
     Returns:
         torch.Tensor: Output tensor of shape [seq_len, hidden_size]
     """
@@ -1360,6 +1366,7 @@ def trtllm_fp4_block_scale_moe(
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     do_finalize: bool = True,
+    enable_pdl: Optional[bool] = None,
     output: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:
     """FP4 block scale MoE operation.
@@ -1405,7 +1412,7 @@ def trtllm_fp4_block_scale_moe(
         do_finalize (bool): Whether to finalize the output (default: False)
         output (Optional[torch.Tensor]): shape [seq_len, hidden_size]
             Optional inplace output tensor.
-
+        enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
     Returns:
         List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output.
             Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing.
@@ -1440,6 +1447,7 @@ def trtllm_fp4_block_scale_moe(
         tile_tokens_dim,
         routing_method_type,
         do_finalize,
+        enable_pdl,
         output,
     )
 
diff --git a/tests/test_trtllm_gen_fused_moe.py b/tests/test_trtllm_gen_fused_moe.py
index 877180f72..f7e7ec09b 100644
--- a/tests/test_trtllm_gen_fused_moe.py
+++ b/tests/test_trtllm_gen_fused_moe.py
@@ -709,6 +709,7 @@ def call_moe(
         routed_scaling = kwargs["routed_scaling"]
         routing_method_type = kwargs["routing_method_type"]
         tile_tokens_dim = kwargs["tile_tokens_dim"]
+        enable_pdl = kwargs.get("enable_pdl")
 
         # Generate block scales and quantize hidden states at runtime
         hidden_states_fp8 = hidden_states_orig.to(torch.float8_e4m3fn)
@@ -738,6 +739,7 @@ def call_moe(
             routing_method_type,
             use_shuffled_weight=static_data["use_shuffled_weight"],
             weight_layout=static_data["weight_layout"],
+            enable_pdl=enable_pdl,
         )
 
         return output.to(torch.float)
@@ -2040,6 +2042,7 @@ def test_moe_quantization_classes(
         routing_method_type=routing_method_type,
         tile_tokens_dim=tile_tokens_dim,
         weight_processing=weight_processing,
+        enable_pdl=True,
     )
 
     # Compare outputs using moe_impl-specific tolerances
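
For reference, the new `enable_pdl` argument defaults to `None`, which the updated docstrings describe as auto-enabling Programmatic Dependent Launch on sm90 and newer GPUs. The sketch below only illustrates that documented resolution logic from the caller's side; `resolve_enable_pdl` is a hypothetical helper, not part of this patch, and it assumes a CUDA device is available.

```python
# Illustrative sketch only (not part of this patch).
from typing import Optional

import torch


def resolve_enable_pdl(enable_pdl: Optional[bool], device: torch.device) -> bool:
    """Return the effective PDL setting under the documented default."""
    if enable_pdl is not None:
        # Caller forced PDL on or off explicitly.
        return enable_pdl
    # Auto mode: enable PDL only on compute capability >= 9.0 (sm90+).
    major, _ = torch.cuda.get_device_capability(device)
    return major >= 9
```

The test change takes the explicit path: `test_moe_quantization_classes` passes `enable_pdl=True` so the kernel is exercised with PDL forced on rather than relying on the auto-detection default.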