diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index e718ed5c9d9..6af943d9b46 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -126,6 +126,9 @@ transforms: stage: post_load_fusion enabled: true backend: trtllm + fuse_nvfp4_moe: + stage: post_load_fusion + enabled: false fuse_allreduce_residual_rmsnorm: stage: post_load_fusion # TODO (lucaslie): add backend selection as part of configurable inference optimizers diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index 17b5a0afa63..3a0ab6b4a95 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -14,8 +14,15 @@ # limitations under the License. +import math + import torch +from tensorrt_llm._torch.auto_deploy.custom_ops.quant import ( + TRTLLM_NVFP4_COLUMN_SIZE, + TRTLLM_NVFP4_ROW_SIZE, + TRTLLM_NVFP4_SCALING_VECTOR_SIZE, +) from tensorrt_llm._torch.utils import ActivationType @@ -212,17 +219,17 @@ def trtllm_quant_fp8_moe_fused_fake( @torch.library.custom_op("auto_deploy::trtllm_quant_nvfp4_moe_fused", mutates_args=()) def trtllm_quant_nvfp4_moe_fused( - x: torch.Tensor, # [B, S, H] or [B*S, H], 16-bit float + x: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor, - fc1_expert_weights_fp4: torch.Tensor, # [E, 2*I, H] or [E, I, H]; uint8 - fc2_expert_weights_fp4: torch.Tensor, # [E, H, I]; uint8 - fc1_weight_blockscale_fp8: torch.Tensor, # Global scale for fc1 (scalar) - fc2_weight_blockscale_fp8: torch.Tensor, # Global scale for w2 (scalar) - fc1_act_global_scale: torch.Tensor, # Global scale for FC1 activations - fc2_act_global_scale: torch.Tensor, # Global scale for FC2 activations - fc1_alpha: torch.Tensor, # Precomputed FC1 alpha (1.0 / (fc1_act_global_scale * fc1_weight_blockscale_fp8)) - fc2_alpha: torch.Tensor, # Precomputed FC2 alpha (1.0 / (fc2_act_global_scale * fc2_weight_blockscale_fp8)) + fc1_expert_weights_fp4: torch.Tensor, + fc2_expert_weights_fp4: torch.Tensor, + fc1_weight_blockscale_fp8: torch.Tensor, + fc2_weight_blockscale_fp8: torch.Tensor, + fc1_act_global_scale: torch.Tensor, + fc2_act_global_scale: torch.Tensor, + fc1_alpha: torch.Tensor, + fc2_alpha: torch.Tensor, is_gated_mlp: bool = True, act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: @@ -234,28 +241,100 @@ def trtllm_quant_nvfp4_moe_fused( For mlp: y = act(x @ w1.T) @ w2.T # act := ReLU^2 + Notes: + - FC1 implements: fc1_output = (act(x @ w1.T) * (x @ w3.T)) or fc1_output = act(x @ w1.T) + - FC2 implements: fc2_output = fc1_output @ w2.T + - FC1 weights are concatenated w3 and w1 if gated_mlp, otherwise w1 + - FP4 elements pairs are packed as a single uint8 element - FC1 implements: fc1_output = (act(x @ w1.T) * (x @ w3.T)) or fc1_output = act(x @ w1.T) - FC2 implements: fc2_output = fc1_output @ w2.T - + Parameters: + x: BF16/FP16 input tensor of shape (B, H) or (B, S, H) + selected_experts: Expert indices (B*S, TOP_K) + routing_weights: Routing weights (B*S, TOP_K) + fc1_expert_weights_fp4: FP4 FC1 weights [E, 2*I, H/2] or [E, I, H/2]; packed uint8 + fc2_expert_weights_fp4: FP4 FC2 weights [E, H, I/2]; packed uint8 + fc1_weight_blockscale_fp8: Block scales for FC1 weights (w1 or cat(w3, w1)) + fc2_weight_blockscale_fp8: Block scales for FC2 weights (w2) + fc1_act_global_scale: Global scale for FC1 activations (scalar) + fc2_act_global_scale: Global scale for FC2 activations (scalar) + fc1_alpha: FC1 dequant scales = 1.0 / (fc1_act_global_scale * fc1_weight_global_scale) + fc2_alpha: FC2 dequant scales = 1.0 / (fc2_act_global_scale * fc2_weight_global_scale) + mlp_style: "gated_mlp" or "mlp" + act_fn: "silu" for gated_mlp, "relu2" for mlp """ - NVFP4_BLOCK_SIZE = 16 + NVFP4_BLOCK_SIZE = TRTLLM_NVFP4_SCALING_VECTOR_SIZE + FP4_PER_UINT8 = 2 - activation_type = ActivationType.Swiglu - if is_gated_mlp: - if act_fn in [ActivationType.Silu, ActivationType.Swiglu]: - activation_type = ActivationType.Swiglu - else: - raise ValueError( - f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'." - ) + _, fc1_inter_size, _ = fc1_expert_weights_fp4.shape + n_experts, hidden_size, inter_size = fc2_expert_weights_fp4.shape + + # Convert the inter_size from number of uint8 elements to number of FP4 elements. + inter_size *= FP4_PER_UINT8 + + # Validate shapes and padding requirements as defined by the cutlass kernel. + assert fc1_weight_blockscale_fp8.ndim == 3, "fc1_weight_blockscale_fp8 must be 3D" + assert fc2_weight_blockscale_fp8.ndim == 3, "fc2_weight_blockscale_fp8 must be 3D" + assert fc1_weight_blockscale_fp8.size(1) % TRTLLM_NVFP4_ROW_SIZE == 0 + assert fc2_weight_blockscale_fp8.size(1) % TRTLLM_NVFP4_ROW_SIZE == 0 + assert fc1_weight_blockscale_fp8.size(2) % TRTLLM_NVFP4_COLUMN_SIZE == 0 + assert fc2_weight_blockscale_fp8.size(2) % TRTLLM_NVFP4_COLUMN_SIZE == 0 + + _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn) + act_fn = ActivationType.Swiglu if act_fn == ActivationType.Silu else act_fn + + if x.dtype in (torch.float16, torch.bfloat16): + x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize( + x, fc1_act_global_scale, NVFP4_BLOCK_SIZE + ) + output_dtype = x.dtype else: - if act_fn == ActivationType.Relu2: - activation_type = ActivationType.Relu2 - else: - raise ValueError( - f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'." - ) + x_q_fp4 = x + input_blockscale = None + output_dtype = x.dtype + + # Pad inter_size to be divisible by 128 + inter_size_padded = math.ceil(inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE + fc1_inter_size_padded = ( + math.ceil(fc1_inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE + ) + hidden_size_padded = ( + math.ceil(hidden_size / TRTLLM_NVFP4_COLUMN_SIZE) * TRTLLM_NVFP4_COLUMN_SIZE + ) + + inter_size_needs_padding = (is_gated_mlp and fc1_inter_size_padded != fc1_inter_size) or ( + not is_gated_mlp and inter_size_padded != inter_size + ) + hidden_size_needs_padding = hidden_size % TRTLLM_NVFP4_COLUMN_SIZE != 0 + if inter_size_needs_padding or hidden_size_needs_padding: + assert False, "See https://github.com/NVIDIA/TensorRT-LLM/issues/10331" + # fc1_expert_weights_fp4: [E, I, H] or [E, 2*I, H] + fc1_padded = fc1_expert_weights_fp4.new_zeros( + fc1_expert_weights_fp4.size(0), + fc1_inter_size_padded, + hidden_size_padded // FP4_PER_UINT8, + ) + fc1_padded[:, :fc1_inter_size, :] = fc1_expert_weights_fp4 + fc1_expert_weights_fp4 = fc1_padded + + # fc2_expert_weights_fp4: [E, H, I] + fc2_padded = fc2_expert_weights_fp4.new_zeros( + n_experts, hidden_size_padded, inter_size_padded // FP4_PER_UINT8 + ) + + assert inter_size % NVFP4_BLOCK_SIZE == 0, ( + f"inter_size {inter_size} must be divisible by {NVFP4_BLOCK_SIZE}" + ) + + fc2_padded[:, :, : inter_size // FP4_PER_UINT8] = fc2_expert_weights_fp4 + fc2_expert_weights_fp4 = fc2_padded + + fc2_blockscale_fp8_padded = fc2_weight_blockscale_fp8.new_zeros( + n_experts, hidden_size_padded, inter_size_padded // NVFP4_BLOCK_SIZE + ) + fc2_blockscale_fp8_padded[:, :, : inter_size // NVFP4_BLOCK_SIZE] = ( + fc2_weight_blockscale_fp8 + ) + fc2_weight_blockscale_fp8 = fc2_blockscale_fp8_padded # quant_scales is described by this code: # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015 @@ -270,26 +349,19 @@ def trtllm_quant_nvfp4_moe_fused( fc2_alpha, # torch.float32; [E] ] - if x.dtype in (torch.float16, torch.bfloat16): - x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize( - x, fc1_act_global_scale, NVFP4_BLOCK_SIZE - ) - output_dtype = x.dtype - else: - x_q_fp4 = x - trtllm_output = torch.ops.trtllm.fused_moe( - x_q_fp4, - selected_experts.to(torch.int), - routing_weights, - fc1_expert_weights=fc1_expert_weights_fp4, + x_q_fp4.view(torch.long), + selected_experts.to(torch.int32), + routing_weights.to(torch.float32), + # Groups of 16 FP4 weight elements are packed as a single int64 element (see isNvfp4Quant in moeOp.cpp) + fc1_expert_weights=fc1_expert_weights_fp4.view(torch.long), fc1_expert_biases=None, fc2_expert_weights=fc2_expert_weights_fp4.view(torch.long), fc2_expert_biases=None, output_dtype=output_dtype, quant_scales=quant_scales, input_sf=input_blockscale, - activation_type=activation_type, + activation_type=act_fn, )[0].view(x.shape) return trtllm_output diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py index 4d9c7949644..5c5bcf6e3c0 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py @@ -12,6 +12,8 @@ TRTLLM_FP4_OP_AVAILABLE = True TRTLLM_NVFP4_SCALING_VECTOR_SIZE = 16 +TRTLLM_NVFP4_ROW_SIZE = 128 +TRTLLM_NVFP4_COLUMN_SIZE = 4 @torch.library.custom_op("auto_deploy::torch_quant_fn", mutates_args=()) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index 7ac6b20f988..e1b4095dd99 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -1572,3 +1572,212 @@ def _apply( has_valid_shapes=fused_key_counter == 0, ) return gm, info + + +def _stack_nvfp4_moe_weights(gm: GraphModule) -> int: + def _register_parameter(gm: GraphModule, target, value): + gm.register_parameter(target, torch.nn.Parameter(value, requires_grad=False)) + + # Helper to get parameter or buffer + def get_param_or_buffer(target): + """Get parameter or buffer by target name.""" + try: + return gm.get_parameter(target) + except AttributeError: + # It's a buffer, not a parameter + parts = target.rsplit(".", 1) + if len(parts) == 2: + mod = gm.get_submodule(parts[0]) + return getattr(mod, parts[1]) + else: + return getattr(gm, target) + + def _extract_op_args(node): + return extract_op_args( + node, + "x", + "selected_experts", + "routing_weights", + "w1_weight", + "w2_weight", + "w3_weight", + "w1_input_scale", + "w2_input_scale", + "w3_input_scale", + "w1_weight_scale", + "w2_weight_scale", + "w3_weight_scale", + "w1_alpha", + "w2_alpha", + "w3_alpha", + "is_gated_mlp", + ) + + def _stack(param_list, dim=0, device=None, dtype=None): + if param_list: + return torch.stack( + [get_param_or_buffer(element.target) for element in param_list], dim=dim + ).contiguous() + else: + return torch.empty(0, device=device, dtype=dtype) + + def _prepare_args_cutlass_format_nvfp4(): + if is_gated_mlp: + # For gated MLP, concatenate w1 and w3 as [w3, w1] + fc1_expert_weights = torch.cat( + [w3_stacked, w1_stacked], dim=1 + ).contiguous() # [E, 2*I, H] + fc1_act_scale = torch.cat( + [w3_input_scale_stacked, w1_input_scale_stacked], dim=1 + ).contiguous() + fc1_alpha_stacked = torch.cat([w3_alpha_stacked, w1_alpha_stacked], dim=1).contiguous() + fc1_weight_blockscale_fp8_stacked = torch.cat( + [w3_weight_blockscale_fp8_stacked, w1_weight_blockscale_fp8_stacked], dim=1 + ).contiguous() + else: + fc1_expert_weights = w1_stacked + fc1_act_scale = w1_input_scale_stacked + fc1_alpha_stacked = w1_alpha_stacked + fc1_weight_blockscale_fp8_stacked = w1_weight_blockscale_fp8_stacked + + fc2_expert_weights = w2_stacked + fc2_act_scale = w2_input_scale_stacked + + new_key_fc1_expert_weights = f"nvfp4_moe_w3_w1_stacked_{fused_key_counter}" + new_key_fc2_expert_weights = f"nvfp4_moe_w2_stacked_{fused_key_counter}" + + new_key_fc1_weight_blockscale_fp8 = ( + f"nvfp4_moe_fc1_weight_blockscale_fp8_stacked_{fused_key_counter}" + ) + new_key_fc2_weight_blockscale_fp8 = ( + f"nvfp4_moe_fc2_weight_blockscale_fp8_stacked_{fused_key_counter}" + ) + new_key_fc1_act_scale = f"nvfp4_moe_w3_w1_input_scale_stacked_{fused_key_counter}" + new_key_fc2_act_scale = f"nvfp4_moe_w2_input_scale_stacked_{fused_key_counter}" + new_key_fc1_alpha = f"nvfp4_moe_w1_alpha_stacked_{fused_key_counter}" + new_key_fc2_alpha = f"nvfp4_moe_w2_alpha_stacked_{fused_key_counter}" + + weight_dtype = torch.float8_e4m3fn + _register_parameter(gm, new_key_fc1_expert_weights, fc1_expert_weights.to(weight_dtype)) + _register_parameter(gm, new_key_fc2_expert_weights, fc2_expert_weights.to(weight_dtype)) + _register_parameter( + gm, new_key_fc1_weight_blockscale_fp8, fc1_weight_blockscale_fp8_stacked + ) + _register_parameter(gm, new_key_fc2_weight_blockscale_fp8, w2_weight_blockscale_fp8_stacked) + _register_parameter(gm, new_key_fc1_act_scale, fc1_act_scale) + _register_parameter(gm, new_key_fc2_act_scale, fc2_act_scale) + _register_parameter(gm, new_key_fc1_alpha, fc1_alpha_stacked) + _register_parameter(gm, new_key_fc2_alpha, w2_alpha_stacked) + + with graph.inserting_before(node): + args = ( + hidden_states, + selected_experts, + routing_weights, + graph.get_attr(new_key_fc1_expert_weights), + graph.get_attr(new_key_fc2_expert_weights), + graph.get_attr(new_key_fc1_weight_blockscale_fp8), + graph.get_attr(new_key_fc2_weight_blockscale_fp8), + graph.get_attr(new_key_fc1_act_scale), + graph.get_attr(new_key_fc2_act_scale), + graph.get_attr(new_key_fc1_alpha), + graph.get_attr(new_key_fc2_alpha), + ) + return args + + fused_key_counter = 0 + graph = gm.graph + + replacement_op = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused + replaced_op = torch.ops.auto_deploy.torch_quant_nvfp4_moe + + matched_nodes = [node for node in graph.nodes if is_op(node, replaced_op)] + for node in matched_nodes: + # Extract weight and scale lists from args + ( + hidden_states, + selected_experts, + routing_weights, + w1_list, + w2_list, + w3_list, + w1_input_scale, + w2_input_scale, + w3_input_scale, + w1_weight_scale, + w2_weight_scale, + w3_weight_scale, + w1_alpha, + w2_alpha, + w3_alpha, + is_gated_mlp, + ) = _extract_op_args(node) + + # Stack the actual tensor values (fast, like in quantize_moe.py) + w1_stacked = _stack(w1_list, dim=0) + w2_stacked = _stack(w2_list, dim=0) + device, dtype = (w1_stacked.device, w1_stacked.dtype) + w3_stacked = _stack(w3_list, dim=0, device=device, dtype=dtype) + + # Scales are buffers, not parameters + w1_input_scale_stacked = _stack(w1_input_scale, dim=0) + w2_input_scale_stacked = _stack(w2_input_scale, dim=0) + w3_input_scale_stacked = _stack(w3_input_scale, dim=0, device=device, dtype=dtype) + + w1_weight_blockscale_fp8_stacked = _stack(w1_weight_scale, dim=0).to(torch.float8_e4m3fn) + w2_weight_blockscale_fp8_stacked = _stack(w2_weight_scale, dim=0).to(torch.float8_e4m3fn) + w3_weight_blockscale_fp8_stacked = _stack( + w3_weight_scale, dim=0, device=device, dtype=dtype + ).to(torch.float8_e4m3fn) + + w1_alpha_stacked = _stack(w1_alpha, dim=0) + w2_alpha_stacked = _stack(w2_alpha, dim=0) + w3_alpha_stacked = _stack(w3_alpha, dim=0, device=device, dtype=dtype) + + args = _prepare_args_cutlass_format_nvfp4() + + fused_key_counter += 1 + + # Create new node with get_attr for stacked parameters + with graph.inserting_before(node): + new_node = graph.call_function( + replacement_op, + args, + kwargs=node.kwargs, + ) + + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + # Clean up after processing all nodes + # eliminate_dead_code will remove unused get_attr nodes, then delete_all_unused_submodules + # will remove the parameters/buffers that are no longer referenced + gm.graph.eliminate_dead_code() + gm.delete_all_unused_submodules() + return fused_key_counter + + +@TransformRegistry.register("fuse_nvfp4_moe") +class FuseNVFP4Moe(BaseTransform): + """ + Stack per-expert NVFP4 MoE weights and scales to avoid runtime stacking overhead. + This runs after weights are loaded, similar to FuseFP8Moe. + """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + with cuda_memory_tracker(): + fused_key_counter = _stack_nvfp4_moe_weights(gm) + + info = TransformInfo( + skipped=(fused_key_counter == 0), + num_matches=fused_key_counter, + is_clean=fused_key_counter == 0, + has_valid_shapes=fused_key_counter == 0, + ) + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py index 2fdaaf55067..21d1ccd2348 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py @@ -1,3 +1,4 @@ +import math from functools import partial from typing import Dict, List, Tuple @@ -8,6 +9,8 @@ from ...custom_ops.quant import ( FP4_GLOBAL_SCALE_MAX, FP8_MAX, + TRTLLM_NVFP4_COLUMN_SIZE, + TRTLLM_NVFP4_ROW_SIZE, TRTLLM_NVFP4_SCALING_VECTOR_SIZE, is_column_major, ) @@ -317,20 +320,28 @@ def quantize_weight(self, w: torch.Tensor) -> torch.Tensor: def scale_names(self) -> List[str]: return ["input_scale", "weight_scale", "alpha"] + def _pad_m_n(self, m: int, n: int) -> Tuple[int, int]: + """Pad m and n to be divisible by 128 and 4 respectively. + Check cpp/tensorrt_llm/plugins/fp4GemmPlugin/fp4GemmPlugin.cpp for more details. + """ + padded_m = math.ceil(m / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE + padded_n = math.ceil(n / TRTLLM_NVFP4_COLUMN_SIZE) * TRTLLM_NVFP4_COLUMN_SIZE + return padded_m, padded_n + def default_scales(self, original_weight_shape: Tuple) -> Dict[str, torch.Tensor]: m, n = original_weight_shape - # scaling factors m is padded along 128 and n is padded along 4. - # check cpp/tensorrt_llm/plugins/fp4GemmPlugin/fp4GemmPlugin.cpp for more details. n = n // TRTLLM_NVFP4_SCALING_VECTOR_SIZE - padded_m = (m + 127) // 128 * 128 - padded_n = (n + 3) // 4 * 4 + padded_m, padded_n = self._pad_m_n(m, n) # definition of scales # input_scale: FP4_GLOBAL_SCALE_MAX / input_amax # weight_scale_2: FP4_GLOBAL_SCALE_MAX / weight_amax # alpha: 1 / (input_scale * weight_scale_2) return { "input_scale": torch.tensor(1.0 / 6.0), - "weight_scale": torch.empty((padded_m * padded_n), dtype=torch.uint8), + "weight_scale": torch.empty((padded_m, padded_n), dtype=torch.uint8), + # "weight_scale": torch.empty((m, n), dtype=torch.uint8), + # "weight_scale": torch.empty(padded_m * padded_n, dtype=torch.float8_e4m3fn), + # "weight_scale": torch.empty(padded_m * padded_n, dtype=torch.uint8), "alpha": torch.tensor(1.0 / 6.0), } @@ -375,12 +386,19 @@ def load_hook(self, state_dict, prefix, *args, weight_name): ) state_dict[input_scale_name] = 1 / state_dict[input_scale_name] weight_scale = state_dict[weight_name + "_scale"].view(float4_sf_dtype) - state_dict[weight_name + "_scale"] = ( - torch.ops.trtllm.block_scale_interleave( - weight_scale.view(torch.uint8).cpu().contiguous() - ) - .view(float4_sf_dtype) - .reshape(-1) + # Round the weight block scale factors to 128x4 and then swizzle. + weight_scale_swizzled = torch.ops.trtllm.block_scale_interleave( + weight_scale.view(torch.uint8).cpu().contiguous() + ).view(float4_sf_dtype) + + m, n = weight_scale.shape + # scaling factors m is padded along 128 and n is padded along 4. + # check cpp/tensorrt_llm/plugins/fp4GemmPlugin/fp4GemmPlugin.cpp for more details. + padded_m, padded_n = self._pad_m_n(m, n) + swizzled_shape = (padded_m, padded_n) + + state_dict[weight_name + "_scale"] = weight_scale_swizzled.reshape( + swizzled_shape ) def convert_amax_hook(self, state_dict, prefix, *args, scale_name: str, amax_name: str): diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index ae3123200b9..ad8d549ec33 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -3,6 +3,7 @@ https://github.com/flashinfer-ai/flashinfer/blob/main/tests/moe/test_trtllm_cutlass_fused_moe.py """ +import math from typing import Callable import pytest @@ -12,12 +13,17 @@ from utils.util import skip_pre_hopper import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 +from tensorrt_llm._torch.auto_deploy.custom_ops.quant import ( + TRTLLM_NVFP4_COLUMN_SIZE, + TRTLLM_NVFP4_ROW_SIZE, + TRTLLM_NVFP4_SCALING_VECTOR_SIZE, +) from tensorrt_llm._torch.utils import ActivationType FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max FLOAT4_E2M1_MAX = 6.0 FP8_DTYPE = torch.float8_e4m3fn -NVFP4_BLOCK_SIZE = 16 +NVFP4_BLOCK_SIZE = TRTLLM_NVFP4_SCALING_VECTOR_SIZE def dynamic_per_tensor_fp8_quant(x: torch.tensor) -> tuple[torch.tensor, torch.tensor]: @@ -152,21 +158,14 @@ def _print_diff_if( # Test configurations -BATCH_SIZES = [ - 1, -] -HIDDEN_SIZES = [ - 128, -] +BATCH_SIZES = [1] +HIDDEN_SIZES = [128] NUM_EXPERTS = [2] TOP_K_VALUES = [2] -INTERMEDIATE_SIZES = [ - 128, -] +INTERMEDIATE_SIZES = [128] EP_NUM_EXPERTS = [8] EP_TOP_K = [2] - F16_TEST_DTYPES = [ (torch.float16, torch.float16, torch.float16), (torch.bfloat16, torch.bfloat16, torch.bfloat16), @@ -504,12 +503,12 @@ def dequantize_nvfp4_to_dtype(tensor_fp4, tensor_sf, global_scale, dtype, device """Dequantize the fp4 tensor back to high precision.""" def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): - m_tiles = (m + 128 - 1) // 128 + m_tiles = (m + TRTLLM_NVFP4_ROW_SIZE - 1) // TRTLLM_NVFP4_ROW_SIZE f = block_size * 4 k_tiles = (k + f - 1) // f tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) - out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) + out = tmp.reshape(m_tiles * TRTLLM_NVFP4_ROW_SIZE, k_tiles * f // block_size) return out[0:m, 0:k] # Originally from https://github.com/flashinfer-ai/flashinfer/blob/main/tests/moe/test_trtllm_cutlass_fused_moe.py @@ -552,18 +551,23 @@ def break_fp4_bytes(a, dtype): return out.to(dtype=dtype) -NVFP4_TEST_DTYPES = [ - (torch.float16, torch.float8_e4m3fn), - (torch.bfloat16, torch.float8_e4m3fn), +NVFP4_TEST_DTYPES = [torch.float16, torch.bfloat16] + +FP4_TEST_SHAPES = [ + (128, 128), # Trivial test case (no padding required) + (2688, 1856), # Nemotron-Nano-3-30B-A3 sizes (padding required) ] +# Scale the input and weights to prevent large absolute values. +FP4_X_GEN_SCALE = 0.5 +FP4_W_GEN_SCALE = 0.1 + @pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("hidden_size, intermediate_size", FP4_TEST_SHAPES) @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.parametrize("top_k", TOP_K_VALUES) -@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) -@pytest.mark.parametrize("otype, wtype", NVFP4_TEST_DTYPES) +@pytest.mark.parametrize("otype", NVFP4_TEST_DTYPES) @pytest.mark.parametrize("activation_func", [ActivationType.Silu, ActivationType.Relu2]) @pytest.mark.skipif( not fp4_compatible() or not trtllm_ops_available(), @@ -576,12 +580,17 @@ def test_trtllm_fused_moe_nvfp4( top_k, intermediate_size, otype, - wtype, activation_func, ): + # Skip known failing configuration + if activation_func == ActivationType.Relu2 and intermediate_size == 1856: + pytest.skip( + "test fails for Relu2 with intermediate_size=1856; see https://github.com/NVIDIA/TensorRT-LLM/issues/10331" + ) + # In the code below: # sf := block scale factors for NVFP4 - # blockscale := block scale factor for NVFP4 + # blockscale := block scale factors for NVFP4 # gs := global scale for NVFP4 # Skip invalid configurations @@ -596,137 +605,94 @@ def _get_test_data( num_experts, intermediate_size, ): - x = gen_tensor((batch_size, hidden_size), otype) - w1_shape = (num_experts, intermediate_size, hidden_size) - w3_shape = w1_shape - w1 = gen_tensor(w1_shape, otype, scale=0.1) - w2 = gen_tensor((num_experts, hidden_size, intermediate_size), otype, scale=0.1) - w3 = gen_tensor(w3_shape, otype, scale=0.1) + x = gen_tensor((batch_size, hidden_size), otype) * FP4_X_GEN_SCALE router_logits = torch.randn(batch_size, num_experts, dtype=otype).cuda() - return x, w1, w2, w3, router_logits - def _quantize_weights(w1, w2, w3): + if is_gated_mlp: + # For gated MLP, concatenate w1 and w3 as [w3, w1] + fc1_weights_shape = (num_experts, intermediate_size * 2, hidden_size) + else: + fc1_weights_shape = (num_experts, intermediate_size, hidden_size) + + fc1_weights = gen_tensor(fc1_weights_shape, otype, scale=FP4_W_GEN_SCALE) + fc2_weights = gen_tensor( + (num_experts, hidden_size, intermediate_size), otype, scale=FP4_W_GEN_SCALE + ) + + return x, fc1_weights, fc2_weights, router_logits + + def _quantize_weights(fc1_weights, fc2_weights, is_gated_mlp): def round_up(x, y): - return (x + y - 1) // y * y - - w1_n = w1.shape[1] - w3_n = w3.shape[1] - sf_w1_n = round_up(w1_n, 128) - sf_w3_n = round_up(w3_n, 128) - sf_w1_k = round_up(hidden_size // NVFP4_BLOCK_SIZE, 4) - w1_blockscale = torch.empty( - (num_experts, sf_w1_n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn + return math.ceil(x / y) * y + + fc1_weights_n = fc1_weights.shape[1] + sf_fc1_weights_n = round_up(fc1_weights_n, TRTLLM_NVFP4_ROW_SIZE) + sf_fc1_weights_k = round_up(hidden_size // NVFP4_BLOCK_SIZE, TRTLLM_NVFP4_COLUMN_SIZE) + fc1_weights_blockscale = torch.empty( + (num_experts, sf_fc1_weights_n, sf_fc1_weights_k), + device="cuda", + dtype=torch.float8_e4m3fn, ) - sf_w2_k = round_up(hidden_size, 128) - sf_w2_n = round_up(intermediate_size // NVFP4_BLOCK_SIZE, 4) - w2_blockscale = torch.empty( - (num_experts, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn + sf_fc2_weights_k = round_up(hidden_size, TRTLLM_NVFP4_ROW_SIZE) + sf_fc2_weights_n = round_up(intermediate_size // NVFP4_BLOCK_SIZE, TRTLLM_NVFP4_COLUMN_SIZE) + fc2_weights_blockscale = torch.empty( + (num_experts, sf_fc2_weights_k, sf_fc2_weights_n), + device="cuda", + dtype=torch.float8_e4m3fn, ) - w3_blockscale = torch.empty( - (num_experts, sf_w3_n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn + + fc1_weights_q = torch.empty( + (num_experts, fc1_weights_n, hidden_size // 2), device="cuda", dtype=torch.uint8 ) - w1_q = torch.empty((num_experts, w1_n, hidden_size // 2), device="cuda", dtype=torch.uint8) - w2_q = torch.empty( + fc2_weights_q = torch.empty( (num_experts, hidden_size, intermediate_size // 2), device="cuda", dtype=torch.uint8 ) - w3_q = torch.empty((num_experts, w3_n, hidden_size // 2), device="cuda", dtype=torch.uint8) - w1_gs = torch.empty((num_experts,), device="cuda", dtype=torch.float32) - w2_gs = torch.empty((num_experts,), device="cuda", dtype=torch.float32) - w3_gs = torch.empty((num_experts,), device="cuda", dtype=torch.float32) + fc1_weights_gs = torch.empty((num_experts,), device="cuda", dtype=torch.float32) + fc2_weights_gs = torch.empty((num_experts,), device="cuda", dtype=torch.float32) for expert in range(num_experts): - w1_amax = torch.abs(w1[expert]).max().to(torch.float32) - w2_amax = torch.abs(w2[expert]).max().to(torch.float32) - w3_amax = torch.abs(w3[expert]).max().to(torch.float32) - w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax - w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax - w3_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w3_amax - + fc1_weights_amax = torch.abs(fc1_weights[expert]).max().to(torch.float32) + fc2_weights_amax = torch.abs(fc2_weights[expert]).max().to(torch.float32) + fc1_weights_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / fc1_weights_amax + fc2_weights_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / fc2_weights_amax + + # Quantize the weights to NVFP4 and ask for swizzled block scale factors because the + # MoE operator expects pre-swizzled block scale factors. Swizzling also flattens the + # block scale factors to a 1D tensor. + # Note that swizzling might create padded block scales + # because the block-scales are required to be padded to the nearest multiple of 128x4. nvfp4_vals, fp8_block_scales = torch.ops.trtllm.fp4_quantize( - w1[expert], w1_gs[expert], NVFP4_BLOCK_SIZE, isSfSwizzledLayout=True + fc1_weights[expert], + fc1_weights_gs[expert], + NVFP4_BLOCK_SIZE, + isSfSwizzledLayout=True, ) - w1_q[expert] = nvfp4_vals - w1_blockscale[expert] = fp8_block_scales.reshape(w1_blockscale[expert].shape) - - nvfp4_vals, fp8_block_scales = torch.ops.trtllm.fp4_quantize( - w2[expert], w2_gs[expert], NVFP4_BLOCK_SIZE, isSfSwizzledLayout=True + fc1_weights_q[expert] = nvfp4_vals + fc1_weights_blockscale[expert] = fp8_block_scales.reshape( + fc1_weights_blockscale[expert].shape ) - w2_q[expert] = nvfp4_vals - w2_blockscale[expert] = fp8_block_scales.reshape(w2_blockscale[expert].shape) nvfp4_vals, fp8_block_scales = torch.ops.trtllm.fp4_quantize( - w3[expert], w3_gs[expert], NVFP4_BLOCK_SIZE, isSfSwizzledLayout=True + fc2_weights[expert], + fc2_weights_gs[expert], + NVFP4_BLOCK_SIZE, + isSfSwizzledLayout=True, ) - w3_q[expert] = nvfp4_vals - w3_blockscale[expert] = fp8_block_scales.reshape(w3_blockscale[expert].shape) - - return w1_q, w2_q, w3_q, w1_blockscale, w2_blockscale, w3_blockscale, w1_gs, w2_gs, w3_gs - - x, w1, w2, w3, router_logits = _get_test_data( - otype, batch_size, hidden_size, num_experts, intermediate_size - ) - - ( - w1_q_fp4, - w2_q_fp4, - w3_q_fp4, - w1_blockscale, - w2_blockscale, - w3_blockscale, - w1_gs, - w2_gs, - w3_gs, - ) = _quantize_weights(w1, w2, w3) - - fc1_activation_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) - fc2_activation_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) - - routing_weights, selected_experts = compute_routing(router_logits, top_k) - - fc1_weight_gs = torch.max(w3_gs, w1_gs) - fc1_alpha = 1.0 / (fc1_activation_gs * fc1_weight_gs) - fc2_alpha = 1.0 / (fc2_activation_gs * w2_gs) - - is_gated_mlp = False if activation_func == ActivationType.Relu2 else True - if is_gated_mlp: - # For gated MLP, concatenate w1 and w3 as [w3, w1] - fc1_expert_weights_fp4 = torch.cat([w3_q_fp4, w1_q_fp4], dim=1).contiguous() - fc1_weight_blockscale_fp8 = torch.cat([w3_blockscale, w1_blockscale], dim=1) - fc1_weight_gs = torch.max(w3_gs, w1_gs) - if activation_func != ActivationType.Silu: - raise ValueError( - f"Unsupported activation '{activation_func}' for gated_mlp. Use 'silu'." + fc2_weights_q[expert] = nvfp4_vals + fc2_weights_blockscale[expert] = fp8_block_scales.reshape( + fc2_weights_blockscale[expert].shape ) - else: - # For non-gated MLP with ReLU^2 - fc1_expert_weights_fp4 = w1_q_fp4 - fc1_weight_blockscale_fp8 = w1_blockscale.view(torch.long) - fc1_weight_gs = w1_gs - if activation_func != ActivationType.Relu2: - raise ValueError(f"Unsupported activation '{activation_func}' for mlp. Use 'relu2'.") - - fc2_expert_weights_fp4 = w2_q_fp4.view(torch.long) - fc2_weight_blockscale_fp8 = w2_blockscale.view(torch.long) - fc1_expert_weights_fp4 = fc1_expert_weights_fp4.view(torch.long) - - trtllm_output = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused( - x, - selected_experts.to(torch.int), - routing_weights, - fc1_expert_weights_fp4, - fc2_expert_weights_fp4, - fc1_weight_blockscale_fp8, - fc2_weight_blockscale_fp8, - fc1_activation_gs, - fc2_activation_gs, - fc1_alpha, - fc2_alpha, - is_gated_mlp=is_gated_mlp, - act_fn=activation_func, - ) + return ( + fc1_weights_q, + fc2_weights_q, + fc1_weights_blockscale, + fc2_weights_blockscale, + fc1_weights_gs, + fc2_weights_gs, + ) - def compute_ref_output(w1_gs, w3_gs): + def compute_ref_output(fc1_weights_gs, fc2_weights_gs): # Quantize then dequantize the input to emulate the precision loss. a_fp4, a_scale_interleaved = torch.ops.trtllm.fp4_quantize( x, fc1_activation_gs, NVFP4_BLOCK_SIZE @@ -739,38 +705,25 @@ def compute_ref_output(w1_gs, w3_gs): device=x.device, block_size=NVFP4_BLOCK_SIZE, ) - - if is_gated_mlp: - w1_gs = w3_gs = torch.max(w1_gs, w3_gs) - - w1_dq = torch.empty(w1.shape, device="cuda", dtype=otype) - w3_dq = torch.empty(w3.shape, device="cuda", dtype=otype) - w2_dq = torch.empty(w2.shape, device="cuda", dtype=otype) + fc1_weights_dq = torch.empty(fc1_expert_weights.shape, device="cuda", dtype=otype) + fc2_weights_dq = torch.empty(fc2_expert_weights.shape, device="cuda", dtype=otype) # Dequantize the weights to emulate the precision loss. for idx in range(0, num_experts): - w1_dq[idx] = dequantize_nvfp4_to_dtype( - w1_q_fp4[idx], - w1_blockscale[idx], - w1_gs[idx], - dtype=w1.dtype, - device=w1.device, + fc1_weights_dq[idx] = dequantize_nvfp4_to_dtype( + fc1_expert_weights_fp4[idx], + fc1_weight_blockscale_fp8[idx], + fc1_weights_gs[idx], + dtype=fc1_expert_weights.dtype, + device=fc1_expert_weights.device, block_size=NVFP4_BLOCK_SIZE, ) - w2_dq[idx] = dequantize_nvfp4_to_dtype( - w2_q_fp4[idx], - w2_blockscale[idx], - w2_gs[idx], - dtype=w2.dtype, - device=w2.device, - block_size=NVFP4_BLOCK_SIZE, - ) - w3_dq[idx] = dequantize_nvfp4_to_dtype( - w3_q_fp4[idx], - w3_blockscale[idx], - w3_gs[idx], - dtype=w3.dtype, - device=w3.device, + fc2_weights_dq[idx] = dequantize_nvfp4_to_dtype( + fc2_expert_weights_fp4[idx], + fc2_weight_blockscale_fp8[idx], + fc2_weights_gs[idx], + dtype=fc2_expert_weights.dtype, + device=fc2_expert_weights.device, block_size=NVFP4_BLOCK_SIZE, ) @@ -780,8 +733,8 @@ def compute_ref_output(w1_gs, w3_gs): ) ref_output = torch_moe_nvfp4( x_dq, - torch.cat([w3_dq, w1_dq], dim=1) if is_gated_mlp else w1_dq, - w2_dq, + fc1_weights_dq, + fc2_weights_dq, top_k, routing_weights, selected_experts, @@ -789,9 +742,53 @@ def compute_ref_output(w1_gs, w3_gs): ) return ref_output - ref_output = compute_ref_output(w1_gs, w3_gs) - print(f"max diff: {(ref_output - trtllm_output).abs().max()}") - print(f"diff = {ref_output - trtllm_output}") - print(f"ref_output = {ref_output}") - print(f"flash_output = {trtllm_output}") + # Begin test + is_gated_mlp = False if activation_func == ActivationType.Relu2 else True + + x, fc1_expert_weights, fc2_expert_weights, router_logits = _get_test_data( + otype, batch_size, hidden_size, num_experts, intermediate_size + ) + + ( + fc1_expert_weights_fp4, + fc2_expert_weights_fp4, + fc1_weight_blockscale_fp8, + fc2_weight_blockscale_fp8, + fc1_weights_gs, + fc2_weights_gs, + ) = _quantize_weights(fc1_expert_weights, fc2_expert_weights, is_gated_mlp) + + # Simplify by assuming a scale of 1.0 for the activations + fc1_activation_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) + fc2_activation_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) + + routing_weights, selected_experts = compute_routing(router_logits, top_k) + + fc1_alpha = 1.0 / (fc1_activation_gs * fc1_weights_gs) + fc2_alpha = 1.0 / (fc2_activation_gs * fc2_weights_gs) + + trtllm_output = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused( + x, + selected_experts.to(torch.int), + routing_weights, + fc1_expert_weights_fp4, + fc2_expert_weights_fp4, + fc1_weight_blockscale_fp8, + fc2_weight_blockscale_fp8, + fc1_activation_gs, + fc2_activation_gs, + fc1_alpha, + fc2_alpha, + is_gated_mlp=is_gated_mlp, + act_fn=activation_func, + ) + + ref_output = compute_ref_output(fc1_weights_gs, fc2_weights_gs) + diff = ref_output - trtllm_output + print(f"max diff: {diff.abs().max()}") + # torch.set_printoptions(profile="full") + print(f"{diff=}") + print(f"{ref_output=}") + print(f"{trtllm_output=}") + # print(f"{diff.abs()>=2e-1=}") torch.testing.assert_close(ref_output, trtllm_output, rtol=2e-1, atol=2e-1)