7 changes: 7 additions & 0 deletions flashinfer/fused_moe/core.py
@@ -593,6 +593,7 @@ def _fake_cutlass_fused_moe_sm100(
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     tune_max_num_tokens: int = 8192,
+    enable_pdl: Optional[bool] = None,
 ):
     seq_len = input.shape[0]
     hidden_size = fc2_expert_weights.shape[1]
@@ -947,6 +948,7 @@ def _fake_trtllm_fp8_per_tensor_scale_moe(
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
+    enable_pdl: Optional[bool] = None,
 ):
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -1031,6 +1033,7 @@ def _fake_trtllm_fp8_block_scale_moe(
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
     weight_layout: int = 0,
+    enable_pdl: Optional[bool] = None,
 ):
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -1174,6 +1177,8 @@ def _fake_trtllm_fp4_block_scale_moe(
     tile_tokens_dim: int,
     routing_method_type: int,
     do_finalize: bool,
+    enable_pdl: Optional[bool],
+    output: Optional[torch.Tensor],
 ):
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -1360,6 +1365,7 @@ def trtllm_fp4_block_scale_moe(
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     do_finalize: bool = True,
+    enable_pdl: Optional[bool] = None,
     output: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:
     """FP4 block scale MoE operation.
@@ -1440,6 +1446,7 @@ def trtllm_fp4_block_scale_moe(
         tile_tokens_dim,
         routing_method_type,
         do_finalize,
+        enable_pdl,
         output,
     )

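The change threads a new `enable_pdl: Optional[bool]` keyword through `trtllm_fp4_block_scale_moe` and the corresponding `_fake_*` registrations, keeping the fake (meta) implementations' signatures in sync with the real ops. PDL (Programmatic Dependent Launch) is a CUDA feature on Hopper (sm_90) and newer GPUs that lets a dependent kernel start launching before the preceding kernel finishes. Below is a minimal sketch of how such a tri-state flag is commonly resolved when it is left as `None`; `resolve_enable_pdl` is a hypothetical helper for illustration only, not flashinfer's actual implementation, and the library decides the real `None` behavior internally.

import torch
from typing import Optional

def resolve_enable_pdl(enable_pdl: Optional[bool], device: torch.device) -> bool:
    """Hypothetical helper: resolve a tri-state enable_pdl flag.

    PDL requires compute capability sm_90 or newer, so None is
    resolved by probing the device's compute capability.
    """
    if enable_pdl is not None:
        # The caller made an explicit choice; respect it.
        return enable_pdl
    major, _ = torch.cuda.get_device_capability(device)
    return major >= 9

Under that assumption, calling the op with `enable_pdl=None` would pick PDL automatically on supporting GPUs, while passing `True` or `False` forces the behavior either way.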