diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py
index 6b37e3e7e..e10c001b4 100755
--- a/modelopt/onnx/quantization/graph_utils.py
+++ b/modelopt/onnx/quantization/graph_utils.py
@@ -528,7 +528,7 @@ def build_non_residual_input_map(
 
         # Generally if both the inputs have a backbone then both backbones are of the same type
         if backbone1 and backbone2:
-            if backbone1 == backbone2 or backbone1.op != backbone2.op:
+            if backbone1 == backbone2:
                 non_residual_inputs[node.name] = None
                 continue
 
diff --git a/tests/_test_utils/onnx_quantization/lib_test_models.py b/tests/_test_utils/onnx_quantization/lib_test_models.py
index 833c27c80..e8b99e1f6 100644
--- a/tests/_test_utils/onnx_quantization/lib_test_models.py
+++ b/tests/_test_utils/onnx_quantization/lib_test_models.py
@@ -372,3 +372,191 @@ def build_conv_concat_model():
     onnx.checker.check_model(model_inferred)
 
     return model_inferred
+
+
+def build_convtranspose_conv_residual_model():
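+    """Build a residual block whose Add inputs have backbones of different op types.
+
+    Topology: input_0 -> ConvTranspose -> Relu -> Conv -> BatchNorm -> Relu -> Conv
+    -> BatchNorm -> Add -> Relu -> output_0, with a skip connection from the first
+    Relu to the Add. One Add input thus has a ConvTranspose backbone, the other a Conv one.
+    """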
+    # Define the model inputs and outputs
+    input_names = ["input_0"]
+    output_names = ["output_0"]
+    input_shapes = [(2, 39, 96, 192)]
+    output_shapes = [(2, 32, 192, 384)]
+
+    inputs = [
+        helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
+        for input_name, input_shape in zip(input_names, input_shapes)
+    ]
+    outputs = [
+        helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
+        for output_name, output_shape in zip(output_names, output_shapes)
+    ]
+
+    # Create the ONNX graph with the nodes
+    nodes = [
+        helper.make_node(
+            op_type="ConvTranspose",
+            inputs=["input_0", "weights_1", "bias_1"],
+            outputs=["convtranspose1_convtranspose/ConvTranspose:0"],
+            name="convtranspose1_convtranspose/ConvTranspose",
+            dilations=[1, 1],
+            group=1,
+            kernel_shape=[2, 2],
+            pads=[0, 0, 0, 0],
+            strides=[2, 2],
+        ),
+        helper.make_node(
+            op_type="Relu",
+            inputs=["convtranspose1_convtranspose/ConvTranspose:0"],
+            outputs=["relu1_relu/Relu:0"],
+            name="relu1_relu/Relu",
+        ),
+        helper.make_node(
+            op_type="Conv",
+            inputs=["relu1_relu/Relu:0", "weights_2"],
+            outputs=["conv2_conv/Conv2D:0"],
+            name="conv2_conv/Conv2D",
+            dilations=[1, 1],
+            group=1,
+            kernel_shape=[3, 3],
+            pads=[1, 1, 1, 1],
+            strides=[1, 1],
+        ),
+        helper.make_node(
+            op_type="BatchNormalization",
+            inputs=["conv2_conv/Conv2D:0", "bn1_scale", "bn1_bias", "bn1_mean", "bn1_var"],
+            outputs=["bn1_batchnorm/BatchNormalization:0"],
+            name="bn1_batchnorm/BatchNormalization",
+        ),
+        helper.make_node(
+            op_type="Relu",
+            inputs=["bn1_batchnorm/BatchNormalization:0"],
+            outputs=["relu2_relu/Relu:0"],
+            name="relu2_relu/Relu",
+        ),
+        helper.make_node(
+            op_type="Conv",
+            inputs=["relu2_relu/Relu:0", "weights_3"],
+            outputs=["conv3_conv/Conv2D:0"],
+            name="conv3_conv/Conv2D",
+            dilations=[1, 1],
+            group=1,
+            kernel_shape=[3, 3],
+            pads=[1, 1, 1, 1],
+            strides=[1, 1],
+        ),
+        helper.make_node(
+            op_type="BatchNormalization",
+            inputs=["conv3_conv/Conv2D:0", "bn2_scale", "bn2_bias", "bn2_mean", "bn2_var"],
+            outputs=["bn2_batchnorm/BatchNormalization:0"],
+            name="bn2_batchnorm/BatchNormalization",
+        ),
+        helper.make_node(
+            op_type="Add",
+            inputs=["relu1_relu/Relu:0", "bn2_batchnorm/BatchNormalization:0"],
+            outputs=["add1_add/Add:0"],
+            name="add1_add/Add",
+        ),
+        helper.make_node(
+            op_type="Relu",
+            inputs=["add1_add/Add:0"],
+            outputs=["output_0"],
+            name="relu3_relu/Relu",
+        ),
+    ]
+
+    # Create the ONNX initializers
+    initializers = [
+        helper.make_tensor(
+            name="weights_1",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(39, 32, 2, 2),
+            vals=np.random.uniform(low=0.5, high=1.0, size=39 * 32 * 2 * 2),
+        ),
+        helper.make_tensor(
+            name="bias_1",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(32,),
+            vals=np.random.uniform(low=0.5, high=1.0, size=32),
+        ),
+        helper.make_tensor(
+            name="weights_2",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(32, 32, 3, 3),
+            vals=np.random.uniform(low=0.5, high=1.0, size=32 * 32 * 3 * 3),
+        ),
+        helper.make_tensor(
+            name="bn1_scale",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(32,),
+            vals=np.random.uniform(low=0.5, high=1.0, size=32),
+        ),
+        helper.make_tensor(
+            name="bn1_bias",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(32,),
+            vals=np.random.uniform(low=0.5, high=1.0, size=32),
+        ),
+        helper.make_tensor(
+            name="bn1_mean",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(32,),
+            vals=np.random.uniform(low=0.5, high=1.0, size=32),
+        ),
+        helper.make_tensor(
+            name="bn1_var",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(32,),
+            vals=np.random.uniform(low=0.5, high=1.0, size=32),
+        ),
+        helper.make_tensor(
+            name="weights_3",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(32, 32, 3, 3),
+            vals=np.random.uniform(low=0.5, high=1.0, size=32 * 32 * 3 * 3),
+        ),
+        helper.make_tensor(
+            name="bn2_scale",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(32,),
+            vals=np.random.uniform(low=0.5, high=1.0, size=32),
+        ),
+        helper.make_tensor(
+            name="bn2_bias",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(32,),
+            vals=np.random.uniform(low=0.5, high=1.0, size=32),
+        ),
+        helper.make_tensor(
+            name="bn2_mean",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(32,),
+            vals=np.random.uniform(low=0.5, high=1.0, size=32),
+        ),
+        helper.make_tensor(
+            name="bn2_var",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(32,),
+            vals=np.random.uniform(low=0.5, high=1.0, size=32),
+        ),
+    ]
+
+    # Create the ONNX graph with the nodes and initializers
+    graph = helper.make_graph(
+        nodes, "convtranspose_conv_residual", inputs, outputs, initializer=initializers
+    )
+
+    # Create the ONNX model
+    model = helper.make_model(graph)
+    model.opset_import[0].version = 13
+    model.ir_version = 10
+
+    # Check the ONNX model
+    model_inferred = onnx.shape_inference.infer_shapes(model)
+    onnx.checker.check_model(model_inferred)
+
+    return model_inferred
diff --git a/tests/unit/onnx/test_quantize_int8.py b/tests/unit/onnx/test_quantize_int8.py
index cafc8beeb..b474558f8 100644
--- a/tests/unit/onnx/test_quantize_int8.py
+++ b/tests/unit/onnx/test_quantize_int8.py
@@ -19,9 +19,14 @@
 import onnx_graphsurgeon as gs
 import pytest
 import torch
-from _test_utils.onnx_quantization.lib_test_models import SimpleMLP, export_as_onnx
+from _test_utils.onnx_quantization.lib_test_models import (
+    SimpleMLP,
+    build_convtranspose_conv_residual_model,
+    export_as_onnx,
+)
 
 import modelopt.onnx.quantization as moq
+from modelopt.onnx.utils import save_onnx
 
 
 def _assert_nodes_are_quantized(nodes):
@@ -52,6 +57,37 @@ def test_int8(tmp_path, high_precision_dtype):
 
     # Load the output model and check QDQ node placements
     graph = gs.import_onnx(onnx.load(output_onnx_path))
-    # Check that all MatMul nodes are quantized
+    # Check that all MatMul nodes are quantized
     mm_nodes = [n for n in graph.nodes if n.op == "MatMul"]
     assert _assert_nodes_are_quantized(mm_nodes)
+
+
+def test_convtranspose_conv_residual_int8(tmp_path):
+    onnx_model = build_convtranspose_conv_residual_model()
+    onnx_path = os.path.join(tmp_path, "convtranspose_conv_residual_model.onnx")
+    save_onnx(onnx_model, onnx_path)
+
+    moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
+
+    # Output model should be produced in the same tmp_path
+    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
+
+    # Check that the quantized explicit model was generated
+    assert os.path.isfile(output_onnx_path)
+
+    # Load the output model and check QDQ node placements
+    graph = gs.import_onnx(onnx.load(output_onnx_path))
+
+    # Check that the Conv and ConvTranspose nodes are quantized
+    conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
+    assert _assert_nodes_are_quantized(conv_nodes)
+
+    # Check that only one input of the residual Add is quantized
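+    # The Add joins a residual path (the first Relu) and a non-residual path (the second
+    # BatchNormalization); Q/DQ should land only on the non-residual input of the Add.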
+    add_nodes = [n for n in graph.nodes if n.op == "Add"]
+    for node in add_nodes:
+        quantized_inputs = [inp for inp in node.inputs if inp.inputs[0].op == "DequantizeLinear"]
+        assert len(quantized_inputs) == 1, (
+            f"Exactly one input of {node.name} should be quantized, but found {len(quantized_inputs)}!"
+        )