Skip to content

Commit 484b95a

Browse files
authored
[5271050, 5274346][ONNX] Add support for Conv-Act-Pool fusion (#448)
Signed-off-by: gcunhase <[email protected]>
1 parent 72f23dc commit 484b95a

File tree

7 files changed

+233
-53
lines changed

7 files changed

+233
-53
lines changed

modelopt/onnx/op_types.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ def is_fusible_scaling_op(op_type: str):
9696
]
9797

9898

99-
def is_copy_op(op_type: str):
100-
"""Returns whether the given op is a copy operator or not."""
101-
return op_type in [
99+
def get_copy_ops():
100+
"""Returns list of copy operators."""
101+
return [
102102
"Flatten",
103103
"Transpose",
104104
"Concat",
@@ -118,6 +118,11 @@ def is_copy_op(op_type: str):
118118
]
119119

120120

121+
def is_copy_op(op_type: str) -> bool:
    """Return whether the given op type is a copy operator.

    Args:
        op_type: ONNX operator type name to look up.

    Returns:
        True if ``op_type`` is contained in the list returned by ``get_copy_ops()``.
    """
    return op_type in get_copy_ops()
124+
125+
121126
def is_linear_op(op_type: str) -> bool:
    """Return whether the given op type is of the Linear category.

    Args:
        op_type: ONNX operator type name to look up.

    Returns:
        True for ``Conv``, ``ConvTranspose``, ``Gemm`` and ``MatMul`` (case-sensitive match),
        False for anything else.
    """
    # Immutable literal: the set of linear ops is fixed and never mutated by callers.
    return op_type in ("Conv", "ConvTranspose", "Gemm", "MatMul")

modelopt/onnx/quantization/graph_utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from onnxruntime.quantization.calibrate import CalibrationDataReader
3030

3131
from modelopt.onnx.logging_config import logger
32-
from modelopt.onnx.op_types import is_copy_op, is_linear_op
32+
from modelopt.onnx.op_types import get_copy_ops, is_copy_op, is_linear_op
3333
from modelopt.onnx.quantization.ort_utils import create_inference_session
3434
from modelopt.onnx.utils import (
3535
find_lowest_common_ancestor,
@@ -173,7 +173,7 @@ def has_path_type(
173173
def get_fusible_backbone(node: Node, graph: Graph) -> Node | None:
174174
"""Returns the linear backbone node for a given node if it matches the pattern.
175175
176-
TensorRT fuses convolution with BN, Relu etc. when in some specific pattern.
176+
TensorRT fuses convolution with BN, Relu, MaxPool etc. when in some specific pattern.
177177
This rule tries to match some of those patterns.
178178
Note. BiasAdd and ConstMul are optional in path types.
179179
@@ -190,7 +190,7 @@ def _get_backbone(root: Node):
190190
return root
191191

192192
for tensor in root.inputs:
193-
if not isinstance(tensor, Constant):
193+
if not isinstance(tensor, Constant) and tensor.inputs:
194194
parent_node = tensor.inputs[0]
195195
bb = _get_backbone(parent_node)
196196
if bb:
@@ -207,7 +207,7 @@ def _get_backbone(root: Node):
207207
["Mul", "Sigmoid", "BatchNormalization", conv_type],
208208
]
209209
for idx, path_type in enumerate(fusible_linear_path_types):
210-
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=[]):
210+
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=get_copy_ops()):
211211
return _get_backbone(node)
212212

213213
return None
@@ -1002,7 +1002,6 @@ def find_nodes_from_matmul_to_exclude(
10021002
logger.debug("No MatMul nodes found in the model")
10031003
return []
10041004

1005-
nodes_to_exclude = []
10061005
logger.debug(f"Found {len(matmul_nodes)} MatMul nodes to analyze")
10071006

10081007
if calibration_shapes:

modelopt/onnx/quantization/partitioning.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ def _build_fusible_partition(
4444
"""Traverses the graph starting from cur_node and updates the fusible_partition list.
4545
4646
Add a node to the partition if any of these hold:
47-
1. The node is a unary or binary pointwise operation and fusible by cask
48-
2. The node is BN and/or Relu and fusible with preceding Conv op
49-
3. The node is a residual Add and fusible with current partition
47+
1. The node is a unary or binary pointwise operation or a copy op and fusible by cask
48+
2. The node is BN and/or Relu and fusible with preceding Conv op (Conv-Act fusion)
49+
3. The node is MaxPool following a Conv-Act pattern (Conv-Act-Pool fusion)
50+
4. The node is a residual Add and fusible with current partition
5051
5152
Args:
5253
cur_node: Current candidate node for the partition.
@@ -131,11 +132,15 @@ def _is_fusible_mul(mul_node: Node) -> bool:
131132

132133
if (
133134
(
135+
is_copy_op(consumer_node.op)
136+
and _is_cask_fusible(consumer_node, partition_node_outputs)
137+
)
138+
or (
134139
is_pointwise_or_elementwise_op(consumer_node.op)
135140
and _is_cask_fusible(consumer_node, partition_node_outputs)
136141
)
137142
or (
138-
consumer_node.op in ["BatchNormalization", "Relu"]
143+
consumer_node.op in ["BatchNormalization", "Relu", "MaxPool"]
139144
and get_fusible_backbone(consumer_node, graph)
140145
)
141146
or _is_on_non_residual_path(consumer_node)

modelopt/onnx/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,7 @@ def get_dynamic_graph_inputs(onnx_model: onnx.ModelProto):
165165
List of dynamic inputs.
166166
"""
167167
graph = gs.import_onnx(onnx_model)
168-
return [
169-
inp for inp in graph.inputs if -1 in inp.shape or any(isinstance(s, str) for s in inp.shape)
170-
]
168+
return [inp for inp in graph.inputs if any(isinstance(s, str) or s <= 0 for s in inp.shape)]
171169

172170

173171
def _get_all_shapes(container: Any) -> dict[str, list[int]]:

tests/_test_utils/onnx/quantization/lib_test_models.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,3 +673,152 @@ def build_conv_batchnorm_sig_mul_model():
673673
onnx.checker.check_model(model_inferred)
674674

675675
return model_inferred
676+
677+
678+
def build_conv_act_pool_model(include_reshape_node=False):
    """Build a Conv -> BatchNormalization -> Relu -> [Reshape] -> MaxPool -> Conv ONNX model.

    Args:
        include_reshape_node: If True, a Reshape node (driven by the ``shape_1`` INT64
            initializer) is inserted between the Relu and the MaxPool.

    Returns:
        The shape-inferred ONNX model, after it has been validated with
        ``onnx.checker.check_model``.
    """

    def _random_float_tensor(name, dims):
        # FLOAT initializer with values drawn uniformly from [0.5, 1.0).
        num_elements = int(np.prod(dims))
        return helper.make_tensor(
            name=name,
            data_type=onnx.TensorProto.FLOAT,
            dims=dims,
            vals=np.random.uniform(low=0.5, high=1.0, size=num_elements),
        )

    # Define the model inputs and outputs
    input_names = ["input_0"]
    output_names = ["output_0"]
    input_shapes = [(32, 64, 256, 256)]
    output_shapes = [(32, 128, 128, 128)]

    inputs = [
        helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
        for input_name, input_shape in zip(input_names, input_shapes)
    ]
    outputs = [
        helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
        for output_name, output_shape in zip(output_names, output_shapes)
    ]

    # Create the ONNX graph nodes: Conv -> BN -> Relu
    nodes = [
        helper.make_node(
            op_type="Conv",
            inputs=["input_0", "weights_1", "bias_1"],
            outputs=["conv1_conv/Conv2D:0"],
            name="conv1_conv/Conv2D",
            dilations=[1, 1],
            group=1,
            kernel_shape=[3, 3],
            pads=[1, 1, 1, 1],
            strides=[1, 1],
        ),
        helper.make_node(
            op_type="BatchNormalization",
            inputs=["conv1_conv/Conv2D:0", "bn1_scale", "bn1_bias", "bn1_mean", "bn1_var"],
            outputs=["bn1_batchnorm/BatchNormalization:0"],
            name="bn1_batchnorm/BatchNormalization",
        ),
        helper.make_node(
            op_type="Relu",
            inputs=["bn1_batchnorm/BatchNormalization:0"],
            outputs=["relu1_relu/Relu:0"],
            name="relu1_relu/Relu",
        ),
    ]

    # The MaxPool consumes either the Relu output directly or the optional Reshape output.
    pool_input = "relu1_relu/Relu:0"
    if include_reshape_node:
        pool_input = "reshape1_reshape/Reshape:0"
        nodes.append(
            helper.make_node(
                op_type="Reshape",
                inputs=["relu1_relu/Relu:0", "shape_1"],
                outputs=[pool_input],
                name="reshape1_reshape/Reshape",
            ),
        )

    nodes.extend(
        [
            helper.make_node(
                op_type="MaxPool",
                inputs=[pool_input],
                outputs=["maxpool1_maxpool/MaxPool2D:0"],
                name="maxpool1_maxpool/MaxPool2D",
                ceil_mode=False,
                kernel_shape=[3, 3],
                pads=[1, 1, 1, 1],
                strides=[2, 2],
            ),
            helper.make_node(
                op_type="Conv",
                inputs=["maxpool1_maxpool/MaxPool2D:0", "weights_2"],
                outputs=["output_0"],
                name="conv2_conv/Conv2D",
                dilations=[1, 1],
                group=1,
                kernel_shape=[3, 3],
                pads=[1, 1, 1, 1],
                strides=[1, 1],
            ),
        ]
    )

    # Create the ONNX initializers (creation order is kept so that seeded runs draw the
    # same random values for the same tensors).
    initializers = [
        _random_float_tensor("weights_1", (128, 64, 3, 3)),
        _random_float_tensor("bias_1", (128,)),
        _random_float_tensor("bn1_scale", (128,)),
        _random_float_tensor("bn1_bias", (128,)),
        _random_float_tensor("bn1_mean", (128,)),
        _random_float_tensor("bn1_var", (128,)),
        _random_float_tensor("weights_2", (128, 128, 3, 3)),
    ]
    if include_reshape_node:
        # Target shape of the optional Reshape node.
        initializers.append(
            helper.make_tensor(
                name="shape_1",
                data_type=onnx.TensorProto.INT64,
                dims=(4,),
                vals=(32, 128, 256, 256),
            ),
        )

    # Create the ONNX graph with the nodes and initializers
    graph = helper.make_graph(nodes, "conv_act_pool", inputs, outputs, initializer=initializers)

    # Create the ONNX model
    model = helper.make_model(graph)
    model.opset_import[0].version = 13
    model.ir_version = 10

    # Infer shapes and check the ONNX model
    model_inferred = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(model_inferred)

    return model_inferred

tests/unit/onnx/test_qdq_rules_int8.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@
1818
import numpy as np
1919
import onnx
2020
import onnx_graphsurgeon as gs
21+
import pytest
2122
from _test_utils.onnx.quantization.lib_test_models import (
23+
build_conv_act_pool_model,
2224
build_conv_batchnorm_sig_mul_model,
25+
build_convtranspose_conv_residual_model,
2326
build_r1a_model,
2427
build_resnet_block,
2528
build_resnet_block_with_downsample,
@@ -40,7 +43,7 @@ def assert_nodes_are_quantized(nodes):
4043
return True
4144

4245

43-
def _assert_nodes_are_not_quantized(nodes):
46+
def assert_nodes_are_not_quantized(nodes):
4447
for node in nodes:
4548
for inp_idx, inp in enumerate(node.inputs):
4649
if isinstance(inp, gs.Variable) and inp.inputs:
@@ -76,7 +79,7 @@ def test_bias_add_rule(tmp_path):
7679
other_nodes = [
7780
n for n in graph.nodes if n.op not in ["Conv", "QuantizeLinear", "DequantizeLinear"]
7881
]
79-
assert _assert_nodes_are_not_quantized(other_nodes)
82+
assert assert_nodes_are_not_quantized(other_nodes)
8083

8184

8285
def _check_resnet_residual_connection(onnx_path):
@@ -106,7 +109,7 @@ def _check_resnet_residual_connection(onnx_path):
106109
other_nodes = [
107110
n for n in graph.nodes if n.op not in ["Conv", "Add", "QuantizeLinear", "DequantizeLinear"]
108111
]
109-
assert _assert_nodes_are_not_quantized(other_nodes)
112+
assert assert_nodes_are_not_quantized(other_nodes)
110113

111114

112115
def test_resnet_residual_connections(tmp_path):
@@ -123,6 +126,35 @@ def test_resnet_residual_connection_with_downsample(tmp_path):
123126
_check_resnet_residual_connection(onnx_path)
124127

125128

129+
def test_convtranspose_conv_residual_int8(tmp_path):
    """Check INT8 Q/DQ placement on the ConvTranspose/Conv residual test model.

    Quantizes the model built by ``build_convtranspose_conv_residual_model`` and verifies that
    every node whose op contains "Conv" is quantized and that each Add node has exactly one
    input produced by a DequantizeLinear node.

    Args:
        tmp_path: Directory in which the input and quantized ONNX files are written.
    """
    onnx_model = build_convtranspose_conv_residual_model()
    onnx_path = os.path.join(tmp_path, "convtranspose_conv_residual_model.onnx")
    save_onnx(onnx_model, onnx_path)

    quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")

    # Output model should be produced in the same tmp_path
    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")

    # Check that quantized explicit model is generated
    assert os.path.isfile(output_onnx_path)

    # Load the output model and check QDQ node placements
    graph = gs.import_onnx(onnx.load(output_onnx_path))

    # Check that Conv and ConvTranspose are quantized
    conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
    assert assert_nodes_are_quantized(conv_nodes)

    # Check that only 1 input of Add is quantized
    add_nodes = [n for n in graph.nodes if n.op == "Add"]
    for node in add_nodes:
        # Tensors without a producer (graph inputs, constants) have an empty ``inputs`` list;
        # guard before indexing so they count as "not quantized" instead of raising IndexError.
        quantized_inputs = [
            inp for inp in node.inputs if inp.inputs and inp.inputs[0].op == "DequantizeLinear"
        ]
        assert len(quantized_inputs) == 1, (
            f"More than one input of {node.name} is being quantized, but only one should be quantized!"
        )
156+
157+
126158
def test_conv_batchnorm_sig_mul_int8(tmp_path):
127159
onnx_model = build_conv_batchnorm_sig_mul_model()
128160
onnx_path = os.path.join(tmp_path, "conv_batchnorm_sig_mul_model.onnx")
@@ -150,3 +182,29 @@ def test_conv_batchnorm_sig_mul_int8(tmp_path):
150182
assert len(quantized_inputs) == 1, (
151183
f"More than one input of {node.name} is being quantized, but only one should be quantized!"
152184
)
185+
186+
187+
@pytest.mark.parametrize("include_reshape_node", [False, True])
def test_conv_act_pool_int8(tmp_path, include_reshape_node):
    """Check INT8 Q/DQ placement on the Conv-Act-Pool test model.

    Quantizes the model built by ``build_conv_act_pool_model`` and verifies that the Conv nodes
    are quantized while the MaxPool node is left unquantized.

    Args:
        tmp_path: Directory in which the input and quantized ONNX files are written.
        include_reshape_node: Forwarded to ``build_conv_act_pool_model``; selects the variant
            with a Reshape between the activation and the MaxPool.
    """
    onnx_model = build_conv_act_pool_model(include_reshape_node)
    onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx")
    save_onnx(onnx_model, onnx_path)

    quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")

    # Output model should be produced in the same tmp_path
    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")

    # Check that quantized explicit model is generated
    assert os.path.isfile(output_onnx_path)

    # Load the output model and check QDQ node placements
    graph = gs.import_onnx(onnx.load(output_onnx_path))

    # Check that Conv is quantized (and that there is something to check at all)
    conv_nodes = [n for n in graph.nodes if n.op == "Conv"]
    assert conv_nodes
    assert assert_nodes_are_quantized(conv_nodes)

    # Check that MaxPool is not quantized (and that it is still present in the graph)
    pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"]
    assert pool_nodes
    assert assert_nodes_are_not_quantized(pool_nodes)

0 commit comments

Comments
 (0)