 # limitations under the License.


+import math
+
 import torch

+from tensorrt_llm._torch.auto_deploy.custom_ops.quant import (
+    TRTLLM_NVFP4_COLUMN_SIZE,
+    TRTLLM_NVFP4_ROW_SIZE,
+    TRTLLM_NVFP4_SCALING_VECTOR_SIZE,
+)
 from tensorrt_llm._torch.utils import ActivationType


@@ -212,17 +219,17 @@ def trtllm_quant_fp8_moe_fused_fake(

 @torch.library.custom_op("auto_deploy::trtllm_quant_nvfp4_moe_fused", mutates_args=())
 def trtllm_quant_nvfp4_moe_fused(
-    x: torch.Tensor,  # [B, S, H] or [B*S, H], 16-bit float
+    x: torch.Tensor,
     selected_experts: torch.Tensor,
     routing_weights: torch.Tensor,
-    fc1_expert_weights_fp4: torch.Tensor,  # [E, 2*I, H] or [E, I, H]; uint8
-    fc2_expert_weights_fp4: torch.Tensor,  # [E, H, I]; uint8
-    fc1_weight_blockscale_fp8: torch.Tensor,  # Global scale for fc1 (scalar)
-    fc2_weight_blockscale_fp8: torch.Tensor,  # Global scale for w2 (scalar)
-    fc1_act_global_scale: torch.Tensor,  # Global scale for FC1 activations
-    fc2_act_global_scale: torch.Tensor,  # Global scale for FC2 activations
-    fc1_alpha: torch.Tensor,  # Precomputed FC1 alpha (1.0 / (fc1_act_global_scale * fc1_weight_blockscale_fp8))
-    fc2_alpha: torch.Tensor,  # Precomputed FC2 alpha (1.0 / (fc2_act_global_scale * fc2_weight_blockscale_fp8))
+    fc1_expert_weights_fp4: torch.Tensor,
+    fc2_expert_weights_fp4: torch.Tensor,
+    fc1_weight_blockscale_fp8: torch.Tensor,
+    fc2_weight_blockscale_fp8: torch.Tensor,
+    fc1_act_global_scale: torch.Tensor,
+    fc2_act_global_scale: torch.Tensor,
+    fc1_alpha: torch.Tensor,
+    fc2_alpha: torch.Tensor,
     is_gated_mlp: bool = True,
     act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
@@ -234,28 +241,100 @@ def trtllm_quant_nvfp4_moe_fused(
     For mlp:
         y = act(x @ w1.T) @ w2.T  # act := ReLU^2

+    Notes:
+        - FC1 computes fc1_output = act(x @ w1.T) * (x @ w3.T) if gated, else act(x @ w1.T)
+        - FC2 computes fc2_output = fc1_output @ w2.T
+        - FC1 weights are the concatenation of w3 and w1 when is_gated_mlp, otherwise just w1
+        - Each pair of FP4 elements is packed into a single uint8, so a logical [E, H, I]
+          FP4 tensor is stored as an [E, H, I/2] uint8 tensor

-    FC1 implements: fc1_output = (act(x @ w1.T) * (x @ w3.T)) or fc1_output = act(x @ w1.T)
-    FC2 implements: fc2_output = fc1_output @ w2.T
-
+    Parameters:
+        x: BF16/FP16 input tensor of shape (B, H) or (B, S, H)
+        selected_experts: Expert indices (B*S, TOP_K)
+        routing_weights: Routing weights (B*S, TOP_K)
+        fc1_expert_weights_fp4: FP4 FC1 weights [E, 2*I, H/2] or [E, I, H/2]; packed uint8
+        fc2_expert_weights_fp4: FP4 FC2 weights [E, H, I/2]; packed uint8
+        fc1_weight_blockscale_fp8: Block scales for FC1 weights (w1 or cat(w3, w1))
+        fc2_weight_blockscale_fp8: Block scales for FC2 weights (w2)
+        fc1_act_global_scale: Global scale for FC1 activations (scalar)
+        fc2_act_global_scale: Global scale for FC2 activations (scalar)
+        fc1_alpha: FC1 dequant scales = 1.0 / (fc1_act_global_scale * fc1_weight_global_scale)
+        fc2_alpha: FC2 dequant scales = 1.0 / (fc2_act_global_scale * fc2_weight_global_scale)
+        is_gated_mlp: True for the gated_mlp layout, False for plain mlp
+        act_fn: ActivationType.Silu for gated_mlp, ActivationType.Relu2 for mlp
     """
-    NVFP4_BLOCK_SIZE = 16
+    NVFP4_BLOCK_SIZE = TRTLLM_NVFP4_SCALING_VECTOR_SIZE
+    FP4_PER_UINT8 = 2

-    activation_type = ActivationType.Swiglu
-    if is_gated_mlp:
-        if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
-            activation_type = ActivationType.Swiglu
-        else:
-            raise ValueError(
-                f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
-            )
+    _, fc1_inter_size, _ = fc1_expert_weights_fp4.shape
+    n_experts, hidden_size, inter_size = fc2_expert_weights_fp4.shape
+
+    # Convert inter_size from a count of packed uint8 elements to a count of FP4 elements.
+    inter_size *= FP4_PER_UINT8
+
+    # Validate shapes and padding requirements as defined by the cutlass kernel.
+    assert fc1_weight_blockscale_fp8.ndim == 3, "fc1_weight_blockscale_fp8 must be 3D"
+    assert fc2_weight_blockscale_fp8.ndim == 3, "fc2_weight_blockscale_fp8 must be 3D"
+    assert fc1_weight_blockscale_fp8.size(1) % TRTLLM_NVFP4_ROW_SIZE == 0
+    assert fc2_weight_blockscale_fp8.size(1) % TRTLLM_NVFP4_ROW_SIZE == 0
+    assert fc1_weight_blockscale_fp8.size(2) % TRTLLM_NVFP4_COLUMN_SIZE == 0
+    assert fc2_weight_blockscale_fp8.size(2) % TRTLLM_NVFP4_COLUMN_SIZE == 0
+
+    _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
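+    # The fused kernel implements the gated SiLU path as Swiglu, so plain Silu is remapped.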
+    act_fn = ActivationType.Swiglu if act_fn == ActivationType.Silu else act_fn
+
+    if x.dtype in (torch.float16, torch.bfloat16):
+        x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize(
+            x, fc1_act_global_scale, NVFP4_BLOCK_SIZE
+        )
+        output_dtype = x.dtype
     else:
-        if act_fn == ActivationType.Relu2:
-            activation_type = ActivationType.Relu2
-        else:
-            raise ValueError(
-                f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
-            )
+        # x is assumed to be pre-quantized, packed FP4; no input block scales are available.
+        x_q_fp4 = x
+        input_blockscale = None
+        output_dtype = x.dtype
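+    # When quantized here, x_q_fp4 packs two FP4 values per uint8 and input_blockscale
+    # holds one FP8 scale per NVFP4_BLOCK_SIZE input elements.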
+
+    # Round inter_size and hidden_size up to multiples of TRTLLM_NVFP4_ROW_SIZE and
+    # TRTLLM_NVFP4_COLUMN_SIZE, respectively, as required by the kernel.
+    inter_size_padded = math.ceil(inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE
+    fc1_inter_size_padded = (
+        math.ceil(fc1_inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE
+    )
+    hidden_size_padded = (
+        math.ceil(hidden_size / TRTLLM_NVFP4_COLUMN_SIZE) * TRTLLM_NVFP4_COLUMN_SIZE
+    )
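+    # Illustrative arithmetic (assuming TRTLLM_NVFP4_ROW_SIZE == 128): for inter_size = 2880,
+    # inter_size_padded = ceil(2880 / 128) * 128 = 23 * 128 = 2944.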
+
+    inter_size_needs_padding = (is_gated_mlp and fc1_inter_size_padded != fc1_inter_size) or (
+        not is_gated_mlp and inter_size_padded != inter_size
+    )
+    hidden_size_needs_padding = hidden_size % TRTLLM_NVFP4_COLUMN_SIZE != 0
+    if inter_size_needs_padding or hidden_size_needs_padding:
+        # Padding is not supported yet; the code below is kept for when the kernel allows it.
+        raise NotImplementedError("See https://github.com/NVIDIA/TensorRT-LLM/issues/10331")
+        # fc1_expert_weights_fp4: [E, 2*I, H/2] or [E, I, H/2]; packed uint8
+        fc1_padded = fc1_expert_weights_fp4.new_zeros(
+            fc1_expert_weights_fp4.size(0),
+            fc1_inter_size_padded,
+            hidden_size_padded // FP4_PER_UINT8,
+        )
+        fc1_padded[:, :fc1_inter_size, : hidden_size // FP4_PER_UINT8] = fc1_expert_weights_fp4
+        fc1_expert_weights_fp4 = fc1_padded
+
+        # fc2_expert_weights_fp4: [E, H, I/2]; packed uint8
+        fc2_padded = fc2_expert_weights_fp4.new_zeros(
+            n_experts, hidden_size_padded, inter_size_padded // FP4_PER_UINT8
+        )
+
+        assert inter_size % NVFP4_BLOCK_SIZE == 0, (
+            f"inter_size {inter_size} must be divisible by {NVFP4_BLOCK_SIZE}"
+        )
+
+        fc2_padded[:, :hidden_size, : inter_size // FP4_PER_UINT8] = fc2_expert_weights_fp4
+        fc2_expert_weights_fp4 = fc2_padded
+
+        fc2_blockscale_fp8_padded = fc2_weight_blockscale_fp8.new_zeros(
+            n_experts, hidden_size_padded, inter_size_padded // NVFP4_BLOCK_SIZE
+        )
+        fc2_blockscale_fp8_padded[:, :, : inter_size // NVFP4_BLOCK_SIZE] = (
+            fc2_weight_blockscale_fp8
+        )
+        fc2_weight_blockscale_fp8 = fc2_blockscale_fp8_padded
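+        # One FP8 block scale covers NVFP4_BLOCK_SIZE consecutive FP4 elements, hence the
+        # last dim of inter_size_padded // NVFP4_BLOCK_SIZE.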

     # quant_scales is described by this code:
     # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015
@@ -270,26 +349,19 @@ def trtllm_quant_nvfp4_moe_fused(
         fc2_alpha,  # torch.float32; [E]
     ]

-    if x.dtype in (torch.float16, torch.bfloat16):
-        x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize(
-            x, fc1_act_global_scale, NVFP4_BLOCK_SIZE
-        )
-        output_dtype = x.dtype
-    else:
-        x_q_fp4 = x
-
     trtllm_output = torch.ops.trtllm.fused_moe(
-        x_q_fp4,
-        selected_experts.to(torch.int),
-        routing_weights,
-        fc1_expert_weights=fc1_expert_weights_fp4,
+        x_q_fp4.view(torch.long),
+        selected_experts.to(torch.int32),
+        routing_weights.to(torch.float32),
+        # Groups of 16 FP4 weight elements (8 uint8 bytes) are reinterpreted as a single
+        # int64 element (see isNvfp4Quant in moeOp.cpp).
+        fc1_expert_weights=fc1_expert_weights_fp4.view(torch.long),
         fc1_expert_biases=None,
         fc2_expert_weights=fc2_expert_weights_fp4.view(torch.long),
         fc2_expert_biases=None,
         output_dtype=output_dtype,
         quant_scales=quant_scales,
         input_sf=input_blockscale,
-        activation_type=activation_type,
+        activation_type=act_fn,
     )[0].view(x.shape)

     return trtllm_output