@@ -1282,10 +1282,10 @@ def trtllm_fp4_block_scale_moe(
             Input tensor of routing logits. Supports float32, bfloat16.
         routing_bias (Optional[torch.Tensor]): shape [num_experts]
             Tensor of routing bias. Can be None for some routing methods. Must be the same type as routing logits.
-        hidden_states (torch.Tensor): shape [seq_len, hidden_size]
-            Tensor of input hidden states. Supports bfloat16, mxfp8
-        hidden_states_scale (Optional[torch.Tensor]): shape [seq_len, hidden_size // 32]
-            Scale tensor of mxfp8 hidden states. Dtype must be float8.
+        hidden_states (torch.Tensor): shape [seq_len, hidden_size // 2 if nvfp4 else hidden_size]
+            Tensor of input hidden states. Supports bfloat16, mxfp8, and nvfp4 (packed into uint8).
+        hidden_states_scale (Optional[torch.Tensor]): shape [seq_len, hidden_size // (32 if mxfp8 else 16)]
+            Scale tensor of mxfp8 / nvfp4 hidden states. Dtype must be float8.
         gemm1_weights (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // 2]
             Tensor of FC1 weights. Dtype must be uint8 (packed fp4)
         gemm1_weights_scale (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // (32 if mxfp4 else 16)]
@@ -1396,10 +1396,10 @@ def trtllm_fp4_block_scale_routed_moe(
             the least significant 16 bits represent the index of the chosen expert (unsigned).
         routing_bias (Optional[torch.Tensor]): shape [num_experts]
             Tensor of routing bias. Can be None for some routing methods. Must be the same type as routing logits.
-        hidden_states (torch.Tensor): shape [seq_len, hidden_size // 32]
-            Tensor of input hidden states. Supports bfloat16, mxfp8
-        hidden_states_scale (Optional[torch.Tensor]): shape [seq_len, hidden_size // 32]
-            Scale tensor of mxfp8 hidden states. Dtype must be float8.
+        hidden_states (torch.Tensor): shape [seq_len, hidden_size // 2 if nvfp4 else hidden_size]
+            Tensor of input hidden states. Supports bfloat16, mxfp8, and nvfp4 (packed into uint8).
+        hidden_states_scale (Optional[torch.Tensor]): shape [seq_len, hidden_size // (32 if mxfp8 else 16)]
+            Scale tensor of mxfp8 / nvfp4 hidden states. Dtype must be float8.
         gemm1_weights (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // 2]
             Tensor of FC1 weights. Dtype must be uint8 (packed fp4)
         gemm1_weights_scale (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // (32 if mxfp4 else 16)]