12 changes: 10 additions & 2 deletions flashinfer/fused_moe/core.py
@@ -593,6 +593,7 @@ def _fake_cutlass_fused_moe_sm100(
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     tune_max_num_tokens: int = 8192,
+    enable_pdl: Optional[bool] = None,
 ):
     seq_len = input.shape[0]
     hidden_size = fc2_expert_weights.shape[1]
@@ -947,6 +948,7 @@ def _fake_trtllm_fp8_per_tensor_scale_moe(
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
+    enable_pdl: Optional[bool] = None,
 ):
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -1031,6 +1033,7 @@ def _fake_trtllm_fp8_block_scale_moe(
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
     weight_layout: int = 0,
+    enable_pdl: Optional[bool] = None,
 ):
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -1174,6 +1177,8 @@ def _fake_trtllm_fp4_block_scale_moe(
     tile_tokens_dim: int,
     routing_method_type: int,
     do_finalize: bool,
+    enable_pdl: Optional[bool] = None,
+    output: Optional[torch.Tensor] = None,
 ):
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -1231,6 +1236,7 @@ def trtllm_fp8_per_tensor_scale_moe(
         use_routing_scales_on_input: Whether to use routing scales on input
         tile_tokens_dim: Tile dimension for tokens (default: 8)
         routing_method_type: Type of routing method to use (default: 0)
+        enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.

     Returns:
         torch.Tensor: Output tensor of shape [seq_len, hidden_size]
@@ -1303,7 +1309,7 @@ def trtllm_fp8_block_scale_moe(
         routed_scaling_factor: Scaling factor for routing
         tile_tokens_dim: Tile dimension for tokens (default: 8)
         routing_method_type: Type of routing method to use (default: 0)
-
+        enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
     Returns:
         torch.Tensor: Output tensor of shape [seq_len, hidden_size]
     """
@@ -1360,6 +1366,7 @@ def trtllm_fp4_block_scale_moe(
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     do_finalize: bool = True,
+    enable_pdl: Optional[bool] = None,
     output: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:
     """FP4 block scale MoE operation.
@@ -1405,7 +1412,7 @@ def trtllm_fp4_block_scale_moe(
         do_finalize (bool): Whether to finalize the output (default: False)
         output (Optional[torch.Tensor]): shape [seq_len, hidden_size]
             Optional inplace output tensor.
-
+        enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
     Returns:
         List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output.
             Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing.
@@ -1440,6 +1447,7 @@ def trtllm_fp4_block_scale_moe(
         tile_tokens_dim,
         routing_method_type,
         do_finalize,
+        enable_pdl,
         output,
     )

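Note on the _fake_* changes above: these are the fake (meta) implementations registered alongside the real ops, so their signatures have to stay in sync whenever a parameter such as enable_pdl (or output) is added, even though the fake path never uses it. A minimal, hypothetical sketch of that pattern with torch.library; the op name demo::moe_like_op is illustrative and not flashinfer's actual registration:

from typing import Optional

import torch


# Stand-in for the real op that would launch the CUDA kernel.
@torch.library.custom_op("demo::moe_like_op", mutates_args=())
def moe_like_op(hidden_states: torch.Tensor, enable_pdl: Optional[bool] = None) -> torch.Tensor:
    return hidden_states.clone()


# The fake (meta) implementation must accept the same signature, so a newly
# added keyword like enable_pdl is mirrored here even though it is unused.
@moe_like_op.register_fake
def _(hidden_states: torch.Tensor, enable_pdl: Optional[bool] = None) -> torch.Tensor:
    return torch.empty_like(hidden_states)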
3 changes: 3 additions & 0 deletions tests/test_trtllm_gen_fused_moe.py
@@ -709,6 +709,7 @@ def call_moe(
         routed_scaling = kwargs["routed_scaling"]
         routing_method_type = kwargs["routing_method_type"]
         tile_tokens_dim = kwargs["tile_tokens_dim"]
+        enable_pdl = kwargs.get("enable_pdl")

         # Generate block scales and quantize hidden states at runtime
         hidden_states_fp8 = hidden_states_orig.to(torch.float8_e4m3fn)
@@ -738,6 +739,7 @@ def call_moe(
             routing_method_type,
             use_shuffled_weight=static_data["use_shuffled_weight"],
             weight_layout=static_data["weight_layout"],
+            enable_pdl=enable_pdl,
         )

         return output.to(torch.float)
@@ -2040,6 +2042,7 @@ def test_moe_quantization_classes(
         routing_method_type=routing_method_type,
         tile_tokens_dim=tile_tokens_dim,
         weight_processing=weight_processing,
+        enable_pdl=True,
     )

     # Compare outputs using moe_impl-specific tolerances
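In both files, enable_pdl defaults to None, which the updated docstrings describe as auto-enabling PDL on sm90 and newer; the test passes enable_pdl=True explicitly and otherwise falls back to None via kwargs.get("enable_pdl"). A rough sketch of what that resolution could look like; resolve_enable_pdl is an illustrative helper, not flashinfer's actual code:

from typing import Optional

import torch


def resolve_enable_pdl(enable_pdl: Optional[bool], device: torch.device) -> bool:
    """Use the explicit setting if given, else auto-enable PDL on sm90 (Hopper) or newer."""
    if enable_pdl is not None:
        return enable_pdl
    major, _minor = torch.cuda.get_device_capability(device)
    return major >= 9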