Skip to content

Commit b442a28

Browse files
committed
[5274346] Skip copy ops in CASK patterns, added unittest
Signed-off-by: gcunhase <[email protected]>
1 parent be0470c commit b442a28

File tree

5 files changed

+66
-32
lines changed

5 files changed

+66
-32
lines changed

modelopt/onnx/op_types.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ def is_fusible_scaling_op(op_type: str):
9696
]
9797

9898

99-
def is_copy_op(op_type: str):
100-
"""Returns whether the given op is a copy operator or not."""
101-
return op_type in [
99+
def copy_ops():
100+
"""Returns list of copy operators."""
101+
return [
102102
"Flatten",
103103
"Transpose",
104104
"Concat",
@@ -118,6 +118,11 @@ def is_copy_op(op_type: str):
118118
]
119119

120120

121+
def is_copy_op(op_type: str):
122+
"""Returns whether the given op is a copy operator or not."""
123+
return op_type in copy_ops()
124+
125+
121126
def is_linear_op(op_type: str):
122127
"""Returns whether the given op type is of Linear category or not."""
123128
return op_type in ["Conv", "ConvTranspose", "Gemm", "MatMul"]

modelopt/onnx/quantization/graph_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
3131

3232
from modelopt.onnx.logging_config import logger
33-
from modelopt.onnx.op_types import is_copy_op, is_linear_op
33+
from modelopt.onnx.op_types import copy_ops, is_copy_op, is_linear_op
3434
from modelopt.onnx.quantization.ort_utils import create_inference_session
3535
from modelopt.onnx.utils import (
3636
find_lowest_common_ancestor,
@@ -203,7 +203,7 @@ def _get_backbone(root: Node):
203203
["MaxPool", "Relu", "BatchNormalization", "BiasAdd", conv_type],
204204
]
205205
for idx, path_type in enumerate(fusible_linear_path_types):
206-
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=[]):
206+
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=copy_ops()):
207207
return _get_backbone(node)
208208

209209
return None

modelopt/onnx/quantization/partitioning.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _build_fusible_partition(
4444
"""Traverses the graph starting from cur_node and updates the fusible_partition list.
4545
4646
Add a nodes to the partition if any of these holds:
47-
1. The node is a unary or binary pointwise operation and fusible by cask
47+
1. The node is a unary or binary pointwise operation or a copy op and fusible by cask
4848
2. The node is BN and/or Relu and fusible with preceding Conv op (Conv-Act fusion)
4949
3. The node is MaxPool following a Conv-Act pattern (Conv-Act-Pool fusion)
5050
4. The node is a residual Add and fusible with current partition
@@ -132,6 +132,10 @@ def _is_fusible_mul(mul_node: Node) -> bool:
132132

133133
if (
134134
(
135+
is_copy_op(consumer_node.op)
136+
and _is_cask_fusible(consumer_node, partition_node_outputs)
137+
)
138+
or (
135139
is_pointwise_or_elementwise_op(consumer_node.op)
136140
and _is_cask_fusible(consumer_node, partition_node_outputs)
137141
)

tests/_test_utils/onnx_quantization/lib_test_models.py

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ def build_convtranspose_conv_residual_model():
556556
return model_inferred
557557

558558

559-
def build_conv_act_pool_model():
559+
def build_conv_act_pool_model(include_reshape_node=False):
560560
# Define your model inputs and outputs
561561
input_names = ["input_0"]
562562
output_names = ["output_0"]
@@ -582,7 +582,7 @@ def build_conv_act_pool_model():
582582
dilations=[1, 1],
583583
group=1,
584584
kernel_shape=[3, 3],
585-
pads=[0, 0, 0, 0],
585+
pads=[1, 1, 1, 1],
586586
strides=[1, 1],
587587
),
588588
helper.make_node(
@@ -597,28 +597,43 @@ def build_conv_act_pool_model():
597597
outputs=["relu1_relu/Relu:0"],
598598
name="relu1_relu/Relu",
599599
),
600-
helper.make_node(
601-
op_type="MaxPool",
602-
inputs=["relu1_relu/Relu:0"],
603-
outputs=["maxpool1_maxpool/MaxPool2D:0"],
604-
name="maxpool1_maxpool/MaxPool2D",
605-
ceil_mode=False,
606-
kernel_shape=[3, 3],
607-
pads=[0, 0, 0, 0],
608-
strides=[2, 2],
609-
),
610-
helper.make_node(
611-
op_type="Conv",
612-
inputs=["maxpool1_maxpool/MaxPool2D:0", "weights_2"],
613-
outputs=["output_0"],
614-
name="conv2_conv/Conv2D",
615-
dilations=[1, 1],
616-
group=1,
617-
kernel_shape=[3, 3],
618-
pads=[0, 0, 0, 0],
619-
strides=[1, 1],
620-
),
621600
]
601+
if include_reshape_node:
602+
nodes.append(
603+
helper.make_node(
604+
op_type="Reshape",
605+
inputs=["relu1_relu/Relu:0", "shape_1"],
606+
outputs=["reshape1_reshape/Reshape:0"],
607+
name="reshape1_reshape/Reshape",
608+
),
609+
)
610+
nodes.extend(
611+
[
612+
helper.make_node(
613+
op_type="MaxPool",
614+
inputs=[
615+
"reshape1_reshape/Reshape:0" if include_reshape_node else "relu1_relu/Relu:0"
616+
],
617+
outputs=["maxpool1_maxpool/MaxPool2D:0"],
618+
name="maxpool1_maxpool/MaxPool2D",
619+
ceil_mode=False,
620+
kernel_shape=[3, 3],
621+
pads=[1, 1, 1, 1],
622+
strides=[2, 2],
623+
),
624+
helper.make_node(
625+
op_type="Conv",
626+
inputs=["maxpool1_maxpool/MaxPool2D:0", "weights_2"],
627+
outputs=["output_0"],
628+
name="conv2_conv/Conv2D",
629+
dilations=[1, 1],
630+
group=1,
631+
kernel_shape=[3, 3],
632+
pads=[1, 1, 1, 1],
633+
strides=[1, 1],
634+
),
635+
]
636+
)
622637

623638
# Create the ONNX initializers
624639
initializers = [
@@ -665,6 +680,15 @@ def build_conv_act_pool_model():
665680
vals=np.random.uniform(low=0.5, high=1.0, size=128 * 128 * 3 * 3),
666681
),
667682
]
683+
if include_reshape_node:
684+
initializers.append(
685+
helper.make_tensor(
686+
name="shape_1",
687+
data_type=onnx.TensorProto.INT64,
688+
dims=(4,),
689+
vals=(32, 128, 256, 256),
690+
),
691+
)
668692

669693
# Create the ONNX graph with the nodes and initializers
670694
graph = helper.make_graph(nodes, "conv_act_pool", inputs, outputs, initializer=initializers)

tests/unit/onnx/test_quantize_int8.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,10 @@ def test_convtranspose_conv_residual_int8(tmp_path):
9797
)
9898

9999

100-
def test_conv_act_pool_int8(tmp_path):
101-
onnx_model = build_conv_act_pool_model()
102-
onnx_path = os.path.join(tmp_path, "conv_act_pool_model.onnx")
100+
@pytest.mark.parametrize("include_reshape_node", [False, True])
101+
def test_conv_act_pool_int8(tmp_path, include_reshape_node):
102+
onnx_model = build_conv_act_pool_model(include_reshape_node)
103+
onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx")
103104
save_onnx(onnx_model, onnx_path)
104105

105106
moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")

0 commit comments

Comments
 (0)