diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py
index ec5d10586..326d30e14 100644
--- a/modelopt/onnx/quantization/qdq_utils.py
+++ b/modelopt/onnx/quantization/qdq_utils.py
@@ -1111,17 +1111,20 @@ def quantize_weights_to_int4(
         scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
         scale = scale.reshape(scale_shape)
         reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
-        assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"
+        assert len(reshape_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
 
-        # Remove unnecessary Cast node
-        cast_node = reshape_child_nodes[0]
-        assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
-        nodes_to_remove.append(cast_node.name)
-        cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+        # Check if there's an optional Cast node between Reshape and Transpose/MatMul/Gemm
+        next_node = reshape_child_nodes[0]
+        if next_node.op_type == "Cast":
+            # Remove unnecessary Cast node
+            cast_node = next_node
+            nodes_to_remove.append(cast_node.name)
+            cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+            next_node = cast_child_nodes[0]
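+        # At this point `next_node` is the consumer of the Reshape output, skipping
+        # at most one optional Cast; both patterns are handled:
+        #   Reshape -> [Cast] -> Transpose -> MatMul/Gemm
+        #   Reshape -> [Cast] -> MatMul/Gemm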
 
         # Transpose weights and scales if present
-        if cast_child_nodes[0].op_type == "Transpose":
-            transpose_node = cast_child_nodes[0]
+        if next_node.op_type == "Transpose":
+            transpose_node = next_node
             nodes_to_remove.append(transpose_node.name)
             assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
             perm = None
@@ -1138,7 +1141,7 @@ def quantize_weights_to_int4(
             )
             matmul_node = transpose_child_nodes[0]
         else:
-            matmul_node = cast_child_nodes[0]
+            matmul_node = next_node
         assert matmul_node.op_type in ["MatMul", "Gemm"], (
             f"Expected MatMul or Gemm node for {node.name}"
         )
@@ -1189,21 +1192,6 @@ def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
     del graph.node[:]
     graph.node.extend(new_nodes)
 
-    def is_fp32_cast(node: onnx.NodeProto) -> bool:
-        return any(
-            attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute
-        )
-
-    # Change all Cast nodes that cast to float32 (TensorProto.FLOAT) to cast to float16 (TensorProto.FLOAT16)
-    for node in graph.node:
-        if node.op_type == "Cast":
-            # Skip Cast nodes that are part of normalization layers and outputs
-            if "norm/Cast" in node.name and is_fp32_cast(node):
-                continue
-            for attr in node.attribute:
-                if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
-                    attr.i = onnx.TensorProto.FLOAT16
-
     # Cast bias to float16
     for node in graph.node:
         if node.op_type == "Add" and "proj/Add" in node.name:
@@ -1310,13 +1298,6 @@ def quantize_weights_to_mxfp8(
                 if attr.name == "output_dtype":
                     attr.i = onnx_dtype_map["Half"]
 
-    # set Cast to FP16
-    for node in graph.node:
-        if node.op_type == "Cast":
-            for attr in node.attribute:
-                if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
-                    attr.i = onnx_dtype_map["Half"]
-
     # Currently only tanh approximation is supported for Gelu
     for node in gelu_nodes:
         for attr in node.attribute:
diff --git a/tests/unit/onnx/test_qdq_utils.py b/tests/unit/onnx/test_qdq_utils.py
index ca7d15189..e661b7c78 100644
--- a/tests/unit/onnx/test_qdq_utils.py
+++ b/tests/unit/onnx/test_qdq_utils.py
@@ -17,11 +17,17 @@
 import pytest
 from onnx import TensorProto, helper, numpy_helper
 
-from modelopt.onnx.quantization.qdq_utils import _cast_fp4, _cast_fp8, quantize_weights_to_int4
+from modelopt.onnx.quantization.qdq_utils import (
+    _cast_fp4,
+    _cast_fp8,
+    fp4qdq_to_2dq,
+    quantize_weights_to_int4,
+    quantize_weights_to_mxfp8,
+)
 
 
-def create_test_model_with_dq_reshape_transpose_matmul(constant_scale: bool = False):
-    """Create a test ONNX model with DequantizeLinear -> Reshape -> Transpose -> MatMul pattern.
+def create_test_model_with_int4_dq_reshape_transpose_matmul(constant_scale: bool = False):
+    """Create a test ONNX model with DequantizeLinear -> Reshape -> Transpose -> MatMul pattern for INT4.
 
     If constant_scale is True, the scale is a Constant node instead of an initializer."""
     # Create weight tensor (4x8 matrix scaled by 32 blocks)
     weight_data = np.random.randint(-8, 8, size=(32, 8), dtype=np.int8)
@@ -186,12 +192,149 @@ def create_test_model_with_proj_nodes():
     return model
 
 
+def create_test_model_with_mxfp8_dq():
+    """Create a test ONNX model with TRT_MXFP8DequantizeLinear nodes for testing MXFP8."""
+    # Create weight tensor
+    weight_data = np.random.uniform(-1.0, 1.0, size=(64, 32)).astype(np.float32)
+    weight_tensor = numpy_helper.from_array(weight_data, "linear.weight")
+
+    # Create scale tensor (constant node) - MXFP8 uses block_size=32
+    scale_data = np.random.uniform(0.1, 1.0, size=(2, 1)).astype(np.float32)
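+    # NOTE: the scale payload above is only a placeholder; quantize_weights_to_mxfp8
+    # is expected to recompute per-block scales from the weights, so these tests only
+    # rely on the node wiring. For a (64, 32) weight with block_size=32 along the last
+    # axis, a real MXFP8 scale tensor would have shape (64, 1), not (2, 1).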
+ """ + if with_transpose: + # For transpose case, weight shape is (32, 64) to match transpose output + weight_data = np.random.uniform(-1.0, 1.0, size=(32, 64)).astype(np.float32) + fp4qdq_output_shape = [32, 64] + transpose_output_shape = [64, 32] + else: + # For non-transpose case, weight shape is (64, 32) (FP16 for testing BFloat16 detection) + weight_data = np.random.uniform(-1.0, 1.0, size=(64, 32)).astype(np.float16) + fp4qdq_output_shape = [64, 32] + transpose_output_shape = None + + weight_tensor = numpy_helper.from_array(weight_data, "linear.weight") + + # Create input tensor + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [4, 32]) + + # Create TRT_FP4QDQ node with correct block_size=16 for NVFP4 + fp4qdq_node = helper.make_node( + "TRT_FP4QDQ", + inputs=["linear.weight"], + outputs=["fp4qdq_output"], + name="weight_fp4qdq", + block_size=16, + ) + + nodes = [fp4qdq_node] + value_info = [] + + # Create value info for fp4qdq output + fp4qdq_output_dtype = TensorProto.FLOAT16 if not with_transpose else TensorProto.FLOAT + fp4qdq_output_info = helper.make_tensor_value_info( + "fp4qdq_output", fp4qdq_output_dtype, fp4qdq_output_shape + ) + value_info.append(fp4qdq_output_info) + + if with_transpose: + # Create Transpose node + transpose_node = helper.make_node( + "Transpose", + inputs=["fp4qdq_output"], + outputs=["transpose_output"], + name="transpose", + perm=[1, 0], + ) + nodes.append(transpose_node) + + # Create value info for transpose output + transpose_output_info = helper.make_tensor_value_info( + "transpose_output", TensorProto.FLOAT, transpose_output_shape + ) + value_info.append(transpose_output_info) + + # MatMul uses transpose output + matmul_input = "transpose_output" + else: + # MatMul uses fp4qdq output directly + matmul_input = "fp4qdq_output" + + # Create MatMul node + matmul_node = helper.make_node( + "MatMul", inputs=["input", matmul_input], outputs=["output"], name="matmul" + ) + nodes.append(matmul_node) + + graph = helper.make_graph( + nodes=nodes, + name="test_graph", + inputs=[input_tensor], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [4, 64])], + initializer=[weight_tensor], + value_info=value_info, + ) + + model = helper.make_model(graph) + return model + + class TestQuantizeWeightsToInt4: """Test suite for quantize_weights_to_int4 function.""" def test_basic_quantization_with_reshape_transpose(self): """Test basic INT4 quantization with Reshape and Transpose removal.""" - model = create_test_model_with_dq_reshape_transpose_matmul() + model = create_test_model_with_int4_dq_reshape_transpose_matmul() # Run quantization quantized_model = quantize_weights_to_int4(model) @@ -216,7 +359,7 @@ def test_basic_quantization_with_reshape_transpose(self): def test_quantization_with_constant_scale(self): """Test quantization when scale comes from a Constant node.""" - model = create_test_model_with_dq_reshape_transpose_matmul(constant_scale=True) + model = create_test_model_with_int4_dq_reshape_transpose_matmul(constant_scale=True) # Run quantization quantized_model = quantize_weights_to_int4(model) @@ -237,25 +380,6 @@ def test_quantization_with_constant_scale(self): ) assert any("scale" in input_name for input_name in dq_node.input) - def test_cast_node_conversion(self): - """Test that Cast nodes are properly converted from float32 to float16.""" - model = create_test_model_with_cast_nodes() - - # Run quantization - quantized_model = quantize_weights_to_int4(model) - - # Check Cast node conversions - for node in 
+
+        # Verify Constant node is removed
+        constant_nodes = [node for node in quantized_model.graph.node if node.op_type == "Constant"]
+        assert len(constant_nodes) == 0
+
+        # Verify DQ node references the new scale
+        dq_node = next(
+            node
+            for node in quantized_model.graph.node
+            if node.op_type == "TRT_MXFP8DequantizeLinear"
+        )
+        assert any("scale" in input_name for input_name in dq_node.input)
+
+    def test_mxfp8_output_dtype_update(self):
+        """Test that output_dtype attribute is updated to FP16."""
+        model = create_test_model_with_mxfp8_dq()
+
+        # Run MXFP8 quantization
+        quantized_model = quantize_weights_to_mxfp8(model)
+
+        # Verify output_dtype is set to FP16
+        dq_node = next(
+            node
+            for node in quantized_model.graph.node
+            if node.op_type == "TRT_MXFP8DequantizeLinear"
+        )
+        output_dtype_attr = next(attr for attr in dq_node.attribute if attr.name == "output_dtype")
+        assert output_dtype_attr.i == TensorProto.FLOAT16
+
+    def test_mxfp8_gelu_approximation_update(self):
+        """Test that Gelu nodes are updated to use tanh approximation."""
+        model = create_test_model_with_mxfp8_dq()
+
+        # Run MXFP8 quantization
+        quantized_model = quantize_weights_to_mxfp8(model)
+
+        # Verify Gelu approximation is set to tanh
+        gelu_node = next(node for node in quantized_model.graph.node if node.op_type == "Gelu")
+        approximate_attr = next(attr for attr in gelu_node.attribute if attr.name == "approximate")
+        assert approximate_attr.s == b"tanh"
+
+    def test_mxfp8_with_missing_attributes(self):
+        """Test MXFP8 quantization with missing axis and block_size attributes."""
+        # Create a model without axis and block_size attributes
+        weight_data = np.random.uniform(-1.0, 1.0, size=(64, 32)).astype(np.float32)
+        weight_tensor = numpy_helper.from_array(weight_data, "linear.weight")
+
+        input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [4, 32])
+
+        scale_data = np.random.uniform(0.1, 1.0, size=(2, 1)).astype(np.float32)
+        scale_constant = helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=["Constant_output_0"],
+            value=numpy_helper.from_array(scale_data),
+            name="scale_constant",
+        )
+
+        # Create TRT_MXFP8DequantizeLinear node without axis and block_size
+        dq_node = helper.make_node(
+            "TRT_MXFP8DequantizeLinear",
+            inputs=["linear.weight", "Constant_output_0"],
+            outputs=["dq_output"],
+            name="weight_dq",
+            output_dtype=TensorProto.FLOAT,
+        )
+
+        matmul_node = helper.make_node(
+            "MatMul", inputs=["input", "dq_output"], outputs=["output"], name="matmul"
+        )
+
+        graph = helper.make_graph(
+            nodes=[scale_constant, dq_node, matmul_node],
+            name="test_graph",
+            inputs=[input_tensor],
+            outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [4, 64])],
+            initializer=[weight_tensor],
+        )
+
+        model = helper.make_model(graph)
+
+        # Run MXFP8 quantization (should use default values)
+        quantized_model = quantize_weights_to_mxfp8(model)
+
+        # Verify the model is still processed correctly
+        weight_tensor = next(
+            init for init in quantized_model.graph.initializer if init.name == "linear.weight"
+        )
+        assert weight_tensor.data_type == TensorProto.FLOAT8E4M3FN
+
+
+class TestFP4QDQTo2DQ:
+    """Test suite for fp4qdq_to_2dq function."""
+
+    @pytest.mark.parametrize("with_transpose", [False, True])
+    def test_fp4qdq_conversion(self, with_transpose):
+        """Test FP4QDQ to 2DQ conversion with and without Transpose node."""
+        model = create_test_model_with_nvfp4_qdq(with_transpose=with_transpose)
+
+        # Run FP4QDQ to 2DQ conversion
+        converted_model = fp4qdq_to_2dq(model)
+
+        # Verify TRT_FP4QDQ node is removed
+        fp4qdq_nodes = [node for node in converted_model.graph.node if node.op_type == "TRT_FP4QDQ"]
+        assert len(fp4qdq_nodes) == 0
+
+        # Verify two DequantizeLinear nodes are created
+        dq_nodes = [
+            node for node in converted_model.graph.node if node.op_type == "DequantizeLinear"
+        ]
+        assert len(dq_nodes) == 2
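+        # The "2DQ" layout: one DequantizeLinear dequantizes the FP4 weight with FP8
+        # block scales, and a second dequantizes those FP8 scales with a global FP32
+        # scale - hence two DQ nodes and the *_f8_scale / *_f32_scale initializers below.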
+
+        # Verify new initializers are created
+        initializer_names = {init.name for init in converted_model.graph.initializer}
+        assert "linear.weight_f4" in initializer_names
+        assert "linear.weight_f8_scale" in initializer_names
+        assert "linear.weight_f8_scale_f32_scale" in initializer_names
+
+        # Verify original weight initializer is removed
+        assert "linear.weight" not in initializer_names
+
+        # Verify FP4 weight tensor has correct data type
+        fp4_weight = next(
+            init for init in converted_model.graph.initializer if init.name == "linear.weight_f4"
+        )
+        assert fp4_weight.data_type == TensorProto.FLOAT4E2M1
+
+        # Verify FP8 scale tensor has correct data type
+        fp8_scale = next(
+            init
+            for init in converted_model.graph.initializer
+            if init.name == "linear.weight_f8_scale"
+        )
+        assert fp8_scale.data_type == TensorProto.FLOAT8E4M3FN
+
+        # Additional verification for transpose case
+        if with_transpose:
+            # Verify Cast nodes are added for input type conversion
+            cast_nodes = [node for node in converted_model.graph.node if node.op_type == "Cast"]
+            assert len(cast_nodes) >= 1  # At least one cast node should be added