Commit eaa4fea

fix docstrings

1 parent 1b77ccc commit eaa4fea

File tree

1 file changed
+8 -8 lines changed

flashinfer/fused_moe/core.py

Lines changed: 8 additions & 8 deletions
@@ -1282,10 +1282,10 @@ def trtllm_fp4_block_scale_moe(
         Input tensor of routing logits. Supports float32, bfloat16.
     routing_bias (Optional[torch.Tensor]): shape [num_experts]
         Tensor of routing bias. Can be None for some routing methods. Must be the same type as routing logits.
-    hidden_states (torch.Tensor): shape [seq_len, hidden_size]
-        Tensor of input hidden states. Supports bfloat16, mxfp8
-    hidden_states_scale (Optional[torch.Tensor]): shape [seq_len, hidden_size // 32]
-        Scale tensor of mxfp8 hidden states. Dtype must be float8.
+    hidden_states (torch.Tensor): shape [seq_len, hidden_size // 2 if nvfp4 else hidden_size]
+        Tensor of input hidden states. Supports bfloat16, mxfp8, and nvfp4 (packed into uint8)
+    hidden_states_scale (Optional[torch.Tensor]): shape [seq_len, hidden_size // (32 if mxfp8, 16 if mxfp4)]
+        Scale tensor of mxfp8 / nvfp4 hidden states. Dtype must be float8.
     gemm1_weights (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // 2]
         Tensor of FC1 weights. Dtype must be uint8 (packed fp4)
     gemm1_weights_scale (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // (32 if mxfp4 else 16)]
@@ -1396,10 +1396,10 @@ def trtllm_fp4_block_scale_routed_moe(
         the least significant 16 bits represent the index of the chosen expert (unsigned).
     routing_bias (Optional[torch.Tensor]): shape [num_experts]
         Tensor of routing bias. Can be None for some routing methods. Must be the same type as routing logits.
-    hidden_states (torch.Tensor): shape [seq_len, hidden_size // 32]
-        Tensor of input hidden states. Supports bfloat16, mxfp8
-    hidden_states_scale (Optional[torch.Tensor]): shape [seq_len, hidden_size // 32]
-        Scale tensor of mxfp8 hidden states. Dtype must be float8.
+    hidden_states (torch.Tensor): shape [seq_len, hidden_size // 2 if nvfp4 else hidden_size]
+        Tensor of input hidden states. Supports bfloat16, mxfp8, and nvfp4 (packed into uint8)
+    hidden_states_scale (Optional[torch.Tensor]): shape [seq_len, hidden_size // (32 if mxfp8, 16 if mxfp4)]
+        Scale tensor of mxfp8 / nvfp4 hidden states. Dtype must be float8.
     gemm1_weights (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // 2]
         Tensor of FC1 weights. Dtype must be uint8 (packed fp4)
     gemm1_weights_scale (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // (32 if mxfp4 else 16)]
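
As a quick illustration of the shape conventions the updated docstrings describe, the sketch below allocates hidden_states and hidden_states_scale tensors for the different input formats. It is not part of the commit: the variable names are illustrative, the float8 scale dtype is an assumption (the docstring only requires "float8"), and the full argument list of trtllm_fp4_block_scale_moe is omitted.

# Shape sketch only; assumed names, not from the commit.
import torch

seq_len, hidden_size = 128, 4096

# bfloat16 (or mxfp8) activations keep the full hidden dimension.
hidden_states_bf16 = torch.empty(seq_len, hidden_size, dtype=torch.bfloat16)

# nvfp4 packs two 4-bit values per byte, so the last dimension is halved.
hidden_states_nvfp4 = torch.empty(seq_len, hidden_size // 2, dtype=torch.uint8)

# Block scales: one float8 scale per 32 elements for mxfp8, per 16 for nvfp4.
hidden_states_scale_mxfp8 = torch.empty(seq_len, hidden_size // 32, dtype=torch.float8_e4m3fn)
hidden_states_scale_nvfp4 = torch.empty(seq_len, hidden_size // 16, dtype=torch.float8_e4m3fn)

The gemm1_weights shapes follow the same packing rule: fp4 weights are stored two per uint8, which is why their last dimension is hidden_size // 2.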
