Commit caf4c38

Move ConvT residual test to qdq_rules
Signed-off-by: gcunhase <[email protected]>
1 parent: de601b8

2 files changed: 31 insertions(+), 35 deletions(-)


tests/unit/onnx/test_qdq_rules_int8.py

Lines changed: 30 additions & 0 deletions
@@ -21,6 +21,7 @@
 from _test_utils.onnx.quantization.lib_test_models import (
     build_conv_act_pool_model,
     build_conv_batchnorm_sig_mul_model,
+    build_convtranspose_conv_residual_model,
     build_r1a_model,
     build_resnet_block,
     build_resnet_block_with_downsample,
@@ -124,6 +125,35 @@ def test_resnet_residual_connection_with_downsample(tmp_path):
     _check_resnet_residual_connection(onnx_path)
 
 
+def test_convtranspose_conv_residual_int8(tmp_path):
+    onnx_model = build_convtranspose_conv_residual_model()
+    onnx_path = os.path.join(tmp_path, "convtranspose_conv_residual_model.onnx")
+    save_onnx(onnx_model, onnx_path)
+
+    quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
+
+    # Output model should be produced in the same tmp_path
+    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
+
+    # Check that the quantized explicit model is generated
+    assert os.path.isfile(output_onnx_path)
+
+    # Load the output model and check QDQ node placements
+    graph = gs.import_onnx(onnx.load(output_onnx_path))
+
+    # Check that Conv and ConvTranspose are quantized
+    conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
+    assert _assert_nodes_are_quantized(conv_nodes)
+
+    # Check that only 1 input of Add is quantized
+    add_nodes = [n for n in graph.nodes if n.op == "Add"]
+    for node in add_nodes:
+        quantized_inputs = [inp for inp in node.inputs if inp.inputs[0].op == "DequantizeLinear"]
+        assert len(quantized_inputs) == 1, (
+            f"More than one input of {node.name} is being quantized, but only one should be quantized!"
+        )
+
+
 def test_conv_batchnorm_sig_mul_int8(tmp_path):
     onnx_model = build_conv_batchnorm_sig_mul_model()
     onnx_path = os.path.join(tmp_path, "conv_batchnorm_sig_mul_model.onnx")
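
Note: the moved test depends on build_convtranspose_conv_residual_model, a shared helper from _test_utils that is imported above but not shown in this diff. For orientation only, here is a minimal sketch of what a ConvTranspose -> Conv residual graph of that kind might look like when built with onnx.helper; every name, shape, and attribute below is an illustrative assumption, not the repo's actual helper.

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

def build_convtranspose_conv_residual_sketch() -> onnx.ModelProto:
    # Hypothetical stand-in for build_convtranspose_conv_residual_model:
    # ConvTranspose upsamples the input, Conv refines it, and Add joins the
    # two branches as a residual connection.
    ct_w = numpy_helper.from_array(np.random.randn(4, 4, 2, 2).astype(np.float32), name="ct_w")
    conv_w = numpy_helper.from_array(np.random.randn(4, 4, 3, 3).astype(np.float32), name="conv_w")
    ct = helper.make_node("ConvTranspose", ["x", "ct_w"], ["ct_out"], strides=[2, 2])
    conv = helper.make_node("Conv", ["ct_out", "conv_w"], ["conv_out"], pads=[1, 1, 1, 1])
    add = helper.make_node("Add", ["ct_out", "conv_out"], ["y"])  # residual join
    graph = helper.make_graph(
        [ct, conv, add],
        "convtranspose_conv_residual",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4, 8, 8])],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4, 16, 16])],
        initializer=[ct_w, conv_w],
    )
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

After this commit the test runs from its new home with, for example:

pytest tests/unit/onnx/test_qdq_rules_int8.py::test_convtranspose_conv_residual_int8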

tests/unit/onnx/test_quantize_int8.py

Lines changed: 1 addition & 35 deletions
@@ -19,14 +19,9 @@
 import onnx_graphsurgeon as gs
 import pytest
 import torch
-from _test_utils.onnx.quantization.lib_test_models import (
-    SimpleMLP,
-    build_convtranspose_conv_residual_model,
-    export_as_onnx,
-)
+from _test_utils.onnx.quantization.lib_test_models import SimpleMLP, export_as_onnx
 
 import modelopt.onnx.quantization as moq
-from modelopt.onnx.utils import save_onnx
 
 
 def assert_nodes_are_quantized(nodes):
@@ -60,32 +55,3 @@ def test_int8(tmp_path, high_precision_dtype):
     # Check that all MatMul nodes are quantized
     mm_nodes = [n for n in graph.nodes if n.op == "MatMul"]
     assert assert_nodes_are_quantized(mm_nodes)
-
-
-def test_convtranspose_conv_residual_int8(tmp_path):
-    onnx_model = build_convtranspose_conv_residual_model()
-    onnx_path = os.path.join(tmp_path, "convtranspose_conv_residual_model.onnx")
-    save_onnx(onnx_model, onnx_path)
-
-    moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
-
-    # Output model should be produced in the same tmp_path
-    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
-
-    # Check that the quantized explicit model is generated
-    assert os.path.isfile(output_onnx_path)
-
-    # Load the output model and check QDQ node placements
-    graph = gs.import_onnx(onnx.load(output_onnx_path))
-
-    # Check that Conv and ConvTranspose are quantized
-    conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
-    assert assert_nodes_are_quantized(conv_nodes)
-
-    # Check that only 1 input of Add is quantized
-    add_nodes = [n for n in graph.nodes if n.op == "Add"]
-    for node in add_nodes:
-        quantized_inputs = [inp for inp in node.inputs if inp.inputs[0].op == "DequantizeLinear"]
-        assert len(quantized_inputs) == 1, (
-            f"More than one input of {node.name} is being quantized, but only one should be quantized!"
-        )
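
Both files reference a quantization-check helper (_assert_nodes_are_quantized in the new location, assert_nodes_are_quantized in the old one) whose body is outside this diff. A plausible minimal version, assuming onnx_graphsurgeon's Variable/Constant tensor model, is sketched below; it is an assumption about the test utility, not its actual source. The companion check that exactly one Add input is quantized likely reflects the TensorRT residual-fusion convention, where the elementwise Add is fused into the preceding convolution kernel and therefore only one branch needs its own Q/DQ pair.

import onnx_graphsurgeon as gs

def _assert_nodes_are_quantized(nodes):
    # Hypothetical re-implementation: once Q/DQ pairs have been inserted,
    # every variable input of a quantized node should be produced by a
    # DequantizeLinear node.
    for node in nodes:
        for inp in node.inputs:
            if isinstance(inp, gs.Constant) or not inp.inputs:
                continue  # skip constant weights and direct graph inputs
            assert inp.inputs[0].op == "DequantizeLinear", (
                f"Input {inp.name} of node {node.name} is not quantized!"
            )
    return True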
