Skip to content

Commit eb9e31e

Browse files
authored
[5593873] [ONNX] Fix ResAdd logic to support 'Conv-BN-Sigmoid-Mul-Add' as fusible patterns (#450)
Signed-off-by: gcunhase <[email protected]>
1 parent d0e83ed commit eb9e31e

File tree

3 files changed

+150
-0
lines changed

3 files changed

+150
-0
lines changed

modelopt/onnx/quantization/graph_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def _get_backbone(root: Node):
201201
["BatchNormalization", "BiasAdd", conv_type],
202202
["Relu", "BatchNormalization", "BiasAdd", conv_type],
203203
["MaxPool", "Relu", "BatchNormalization", "BiasAdd", conv_type],
204+
["Mul", "Sigmoid", "BatchNormalization", conv_type],
204205
]
205206
for idx, path_type in enumerate(fusible_linear_path_types):
206207
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=[]):

tests/_test_utils/onnx_quantization/lib_test_models.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,3 +555,121 @@ def build_convtranspose_conv_residual_model():
555555
onnx.checker.check_model(model_inferred)
556556

557557
return model_inferred
558+
559+
560+
def build_conv_batchnorm_sig_mul_model():
    """Build an ONNX model with a Conv-BN-Sigmoid-Mul branch feeding a residual Add.

    Topology (tensor flow):
        input_0 -> Relu -> Conv -> BatchNormalization -> Sigmoid -> Mul -> Add -> Relu -> output_0

    The Mul multiplies the Sigmoid output by the BatchNormalization output, and
    the Add sums the first Relu output with the Mul output (the skip connection).

    Returns:
        The shape-inferred ``onnx.ModelProto`` (opset 13, IR version 10) after it
        has passed ``onnx.checker.check_model``.
    """
    float_type = onnx.TensorProto.FLOAT
    io_shape = (6, 48, 64, 176)

    # Graph-level input and output value infos share the same shape.
    graph_inputs = [helper.make_tensor_value_info("input_0", float_type, io_shape)]
    graph_outputs = [helper.make_tensor_value_info("output_0", float_type, io_shape)]

    # Intermediate tensor names, declared once so producers and consumers stay in sync.
    relu0_out = "relu0_relu/Relu:0"
    conv1_out = "conv1_conv/Conv2D:0"
    bn1_out = "bn1_batchnorm/BatchNormalization:0"
    sig1_out = "sig1_sigmoid/Sigmoid:0"
    mul1_out = "mul1_mul/Mul:0"
    add1_out = "add1_add/Add:0"

    relu0 = helper.make_node(
        op_type="Relu",
        inputs=["input_0"],
        outputs=[relu0_out],
        name="relu0_relu/Relu",
    )
    conv1 = helper.make_node(
        op_type="Conv",
        inputs=[relu0_out, "weights_1"],
        outputs=[conv1_out],
        name="conv1_conv/Conv2D",
        dilations=[1, 1],
        group=1,
        kernel_shape=[3, 3],
        pads=[1, 1, 1, 1],
        strides=[1, 1],
    )
    bn1 = helper.make_node(
        op_type="BatchNormalization",
        inputs=[conv1_out, "bn1_scale", "bn1_bias", "bn1_mean", "bn1_var"],
        outputs=[bn1_out],
        name="bn1_batchnorm/BatchNormalization",
    )
    sig1 = helper.make_node(
        op_type="Sigmoid",
        inputs=[bn1_out],
        outputs=[sig1_out],
        name="sig1_sigmoid/Sigmoid",
    )
    mul1 = helper.make_node(
        op_type="Mul",
        inputs=[sig1_out, bn1_out],
        outputs=[mul1_out],
        name="mul1_mul/Mul",
    )
    add1 = helper.make_node(
        op_type="Add",
        inputs=[relu0_out, mul1_out],
        outputs=[add1_out],
        name="add1_add/Add",
    )
    relu2 = helper.make_node(
        op_type="Relu",
        inputs=[add1_out],
        outputs=["output_0"],
        name="relu2_relu/Relu",
    )
    graph_nodes = [relu0, conv1, bn1, sig1, mul1, add1, relu2]

    def _uniform_initializer(tensor_name, dims):
        # Float initializer filled with uniform random values drawn from [0.5, 1.0).
        return helper.make_tensor(
            name=tensor_name,
            data_type=float_type,
            dims=dims,
            vals=np.random.uniform(low=0.5, high=1.0, size=int(np.prod(dims))),
        )

    # Conv weights first, then the four per-channel BatchNormalization parameters.
    graph_initializers = [_uniform_initializer("weights_1", (48, 48, 3, 3))]
    graph_initializers.extend(
        _uniform_initializer(bn_param, (48,))
        for bn_param in ("bn1_scale", "bn1_bias", "bn1_mean", "bn1_var")
    )

    # Assemble the graph and wrap it in a model with fixed opset / IR versions.
    graph = helper.make_graph(
        graph_nodes,
        "conv_batchnorm_sig_mul",
        graph_inputs,
        graph_outputs,
        initializer=graph_initializers,
    )
    model = helper.make_model(graph)
    model.opset_import[0].version = 13
    model.ir_version = 10

    # Run shape inference, then validate the resulting model before returning it.
    model_inferred = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(model_inferred)

    return model_inferred

tests/unit/onnx/test_qdq_rules_int8.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
import onnx
2020
import onnx_graphsurgeon as gs
2121
from _test_utils.onnx_quantization.lib_test_models import (
22+
build_conv_batchnorm_sig_mul_model,
2223
build_r1a_model,
2324
build_resnet_block,
2425
build_resnet_block_with_downsample,
2526
export_as_onnx,
2627
)
2728

2829
from modelopt.onnx.quantization.quantize import quantize
30+
from modelopt.onnx.utils import save_onnx
2931

3032

3133
def _assert_nodes_are_quantized(nodes):
@@ -119,3 +121,32 @@ def test_resnet_residual_connection_with_downsample(tmp_path):
119121
onnx_path = os.path.join(tmp_path, "model.onnx")
120122
export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path)
121123
_check_resnet_residual_connection(onnx_path)
124+
125+
126+
def test_conv_batchnorm_sig_mul_int8(tmp_path):
    """Check INT8 Q/DQ placement for the Conv-BN-Sigmoid-Mul-Add residual pattern.

    Builds the test model, saves it under ``tmp_path``, quantizes it in INT8 mode,
    and then verifies that:
      * the quantized model file exists,
      * every Conv-like node passes ``_assert_nodes_are_quantized``,
      * each Add node has exactly one input produced by a DequantizeLinear node.

    Args:
        tmp_path: Temporary directory (pytest fixture) used for the model files.
    """
    onnx_model = build_conv_batchnorm_sig_mul_model()
    onnx_path = os.path.join(tmp_path, "conv_batchnorm_sig_mul_model.onnx")
    save_onnx(onnx_model, onnx_path)

    quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")

    # Output model should be produced in the same tmp_path
    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")

    # Check that quantized explicit model is generated
    assert os.path.isfile(output_onnx_path)

    # Load the output model and check QDQ node placements
    graph = gs.import_onnx(onnx.load(output_onnx_path))

    # Check that all Conv-like nodes (op name contains "Conv") are quantized
    conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
    assert _assert_nodes_are_quantized(conv_nodes)

    # Check that exactly 1 input of each Add is quantized
    add_nodes = [n for n in graph.nodes if n.op == "Add"]
    # Guard against a vacuous pass: the loop below checks nothing if no Add is found.
    assert add_nodes, "No Add node found in the quantized model!"
    for node in add_nodes:
        # Tensors without a producer (graph inputs, constants) have an empty
        # `inputs` list, so check for a producer before indexing into it.
        quantized_inputs = [
            inp for inp in node.inputs if inp.inputs and inp.inputs[0].op == "DequantizeLinear"
        ]
        assert len(quantized_inputs) == 1, (
            f"Expected exactly one quantized input for {node.name}, "
            f"but found {len(quantized_inputs)}!"
        )

0 commit comments

Comments
 (0)