Skip to content

Commit 1fe43df

Browse files
committed
[5274346] Skip copy ops in CASK patterns, added unittest
Signed-off-by: gcunhase <[email protected]>
1 parent 37e68d4 commit 1fe43df

File tree

5 files changed

+66
-32
lines changed

5 files changed

+66
-32
lines changed

modelopt/onnx/op_types.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ def is_fusible_scaling_op(op_type: str):
9696
]
9797

9898

99-
def is_copy_op(op_type: str):
100-
"""Returns whether the given op is a copy operator or not."""
101-
return op_type in [
99+
def copy_ops():
100+
"""Returns list of copy operators."""
101+
return [
102102
"Flatten",
103103
"Transpose",
104104
"Concat",
@@ -118,6 +118,11 @@ def is_copy_op(op_type: str):
118118
]
119119

120120

121+
def is_copy_op(op_type: str):
122+
"""Returns whether the given op is a copy operator or not."""
123+
return op_type in copy_ops()
124+
125+
121126
def is_linear_op(op_type: str):
122127
"""Returns whether the given op type is of Linear category or not."""
123128
return op_type in ["Conv", "ConvTranspose", "Gemm", "MatMul"]

modelopt/onnx/quantization/graph_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from onnxruntime.quantization.calibrate import CalibrationDataReader
3030

3131
from modelopt.onnx.logging_config import logger
32-
from modelopt.onnx.op_types import is_copy_op, is_linear_op
32+
from modelopt.onnx.op_types import copy_ops, is_copy_op, is_linear_op
3333
from modelopt.onnx.quantization.ort_utils import create_inference_session
3434
from modelopt.onnx.utils import (
3535
find_lowest_common_ancestor,
@@ -203,7 +203,7 @@ def _get_backbone(root: Node):
203203
["MaxPool", "Relu", "BatchNormalization", "BiasAdd", conv_type],
204204
]
205205
for idx, path_type in enumerate(fusible_linear_path_types):
206-
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=[]):
206+
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=copy_ops()):
207207
return _get_backbone(node)
208208

209209
return None

modelopt/onnx/quantization/partitioning.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _build_fusible_partition(
4444
"""Traverses the graph starting from cur_node and updates the fusible_partition list.
4545
4646
Add a node to the partition if any of these holds:
47-
1. The node is a unary or binary pointwise operation and fusible by cask
47+
1. The node is a unary or binary pointwise operation or a copy op and fusible by cask
4848
2. The node is BN and/or Relu and fusible with preceding Conv op (Conv-Act fusion)
4949
3. The node is MaxPool following a Conv-Act pattern (Conv-Act-Pool fusion)
5050
4. The node is a residual Add and fusible with current partition
@@ -132,6 +132,10 @@ def _is_fusible_mul(mul_node: Node) -> bool:
132132

133133
if (
134134
(
135+
is_copy_op(consumer_node.op)
136+
and _is_cask_fusible(consumer_node, partition_node_outputs)
137+
)
138+
or (
135139
is_pointwise_or_elementwise_op(consumer_node.op)
136140
and _is_cask_fusible(consumer_node, partition_node_outputs)
137141
)

tests/_test_utils/onnx_quantization/lib_test_models.py

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def build_convtranspose_conv_residual_model():
557557
return model_inferred
558558

559559

560-
def build_conv_act_pool_model():
560+
def build_conv_act_pool_model(include_reshape_node=False):
561561
# Define your model inputs and outputs
562562
input_names = ["input_0"]
563563
output_names = ["output_0"]
@@ -583,7 +583,7 @@ def build_conv_act_pool_model():
583583
dilations=[1, 1],
584584
group=1,
585585
kernel_shape=[3, 3],
586-
pads=[0, 0, 0, 0],
586+
pads=[1, 1, 1, 1],
587587
strides=[1, 1],
588588
),
589589
helper.make_node(
@@ -598,28 +598,43 @@ def build_conv_act_pool_model():
598598
outputs=["relu1_relu/Relu:0"],
599599
name="relu1_relu/Relu",
600600
),
601-
helper.make_node(
602-
op_type="MaxPool",
603-
inputs=["relu1_relu/Relu:0"],
604-
outputs=["maxpool1_maxpool/MaxPool2D:0"],
605-
name="maxpool1_maxpool/MaxPool2D",
606-
ceil_mode=False,
607-
kernel_shape=[3, 3],
608-
pads=[0, 0, 0, 0],
609-
strides=[2, 2],
610-
),
611-
helper.make_node(
612-
op_type="Conv",
613-
inputs=["maxpool1_maxpool/MaxPool2D:0", "weights_2"],
614-
outputs=["output_0"],
615-
name="conv2_conv/Conv2D",
616-
dilations=[1, 1],
617-
group=1,
618-
kernel_shape=[3, 3],
619-
pads=[0, 0, 0, 0],
620-
strides=[1, 1],
621-
),
622601
]
602+
if include_reshape_node:
603+
nodes.append(
604+
helper.make_node(
605+
op_type="Reshape",
606+
inputs=["relu1_relu/Relu:0", "shape_1"],
607+
outputs=["reshape1_reshape/Reshape:0"],
608+
name="reshape1_reshape/Reshape",
609+
),
610+
)
611+
nodes.extend(
612+
[
613+
helper.make_node(
614+
op_type="MaxPool",
615+
inputs=[
616+
"reshape1_reshape/Reshape:0" if include_reshape_node else "relu1_relu/Relu:0"
617+
],
618+
outputs=["maxpool1_maxpool/MaxPool2D:0"],
619+
name="maxpool1_maxpool/MaxPool2D",
620+
ceil_mode=False,
621+
kernel_shape=[3, 3],
622+
pads=[1, 1, 1, 1],
623+
strides=[2, 2],
624+
),
625+
helper.make_node(
626+
op_type="Conv",
627+
inputs=["maxpool1_maxpool/MaxPool2D:0", "weights_2"],
628+
outputs=["output_0"],
629+
name="conv2_conv/Conv2D",
630+
dilations=[1, 1],
631+
group=1,
632+
kernel_shape=[3, 3],
633+
pads=[1, 1, 1, 1],
634+
strides=[1, 1],
635+
),
636+
]
637+
)
623638

624639
# Create the ONNX initializers
625640
initializers = [
@@ -666,6 +681,15 @@ def build_conv_act_pool_model():
666681
vals=np.random.uniform(low=0.5, high=1.0, size=128 * 128 * 3 * 3),
667682
),
668683
]
684+
if include_reshape_node:
685+
initializers.append(
686+
helper.make_tensor(
687+
name="shape_1",
688+
data_type=onnx.TensorProto.INT64,
689+
dims=(4,),
690+
vals=(32, 128, 256, 256),
691+
),
692+
)
669693

670694
# Create the ONNX graph with the nodes and initializers
671695
graph = helper.make_graph(nodes, "conv_act_pool", inputs, outputs, initializer=initializers)

tests/unit/onnx/test_quantize_int8.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,10 @@ def test_convtranspose_conv_residual_int8(tmp_path):
9797
)
9898

9999

100-
def test_conv_act_pool_int8(tmp_path):
101-
onnx_model = build_conv_act_pool_model()
102-
onnx_path = os.path.join(tmp_path, "conv_act_pool_model.onnx")
100+
@pytest.mark.parametrize("include_reshape_node", [False, True])
101+
def test_conv_act_pool_int8(tmp_path, include_reshape_node):
102+
onnx_model = build_conv_act_pool_model(include_reshape_node)
103+
onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx")
103104
save_onnx(onnx_model, onnx_path)
104105

105106
moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")

0 commit comments

Comments (0)