
Commit 010829c

Moved unittest to qdq_rules script

Signed-off-by: gcunhase <[email protected]>

1 parent 57c440b

File tree (2 files changed: +31 -30 lines)

- tests/unit/onnx/test_qdq_rules_int8.py (+31 -0)
- tests/unit/onnx/test_quantize_int8.py (+0 -30)

tests/unit/onnx/test_qdq_rules_int8.py

Lines changed: 31 additions & 0 deletions
@@ -19,13 +19,15 @@
 import onnx
 import onnx_graphsurgeon as gs
 from _test_utils.onnx_quantization.lib_test_models import (
+    build_conv_batchnorm_sig_mul_model,
     build_r1a_model,
     build_resnet_block,
     build_resnet_block_with_downsample,
     export_as_onnx,
 )
 
 from modelopt.onnx.quantization.quantize import quantize
+from modelopt.onnx.utils import save_onnx
 
 
 def _assert_nodes_are_quantized(nodes):

@@ -119,3 +121,32 @@ def test_resnet_residual_connection_with_downsample(tmp_path):
     onnx_path = os.path.join(tmp_path, "model.onnx")
     export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path)
     _check_resnet_residual_connection(onnx_path)
+
+
+def test_conv_batchnorm_sig_mul_int8(tmp_path):
+    onnx_model = build_conv_batchnorm_sig_mul_model()
+    onnx_path = os.path.join(tmp_path, "conv_batchnorm_sig_mul_model.onnx")
+    save_onnx(onnx_model, onnx_path)
+
+    quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
+
+    # Output model should be produced in the same tmp_path
+    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
+
+    # Check that the quantized explicit model is generated
+    assert os.path.isfile(output_onnx_path)
+
+    # Load the output model and check QDQ node placements
+    graph = gs.import_onnx(onnx.load(output_onnx_path))
+
+    # Check that Conv and ConvTranspose nodes are quantized
+    conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
+    assert _assert_nodes_are_quantized(conv_nodes)
+
+    # Check that only 1 input of Add is quantized
+    add_nodes = [n for n in graph.nodes if n.op == "Add"]
+    for node in add_nodes:
+        quantized_inputs = [inp for inp in node.inputs if inp.inputs[0].op == "DequantizeLinear"]
+        assert len(quantized_inputs) == 1, (
+            f"More than one input of {node.name} is being quantized, but only one should be quantized!"
+        )
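Note: the relocated test relies on the _assert_nodes_are_quantized helper already defined in test_qdq_rules_int8.py, whose body lies outside these hunks. The following is only a minimal sketch of what such a helper plausibly checks, reconstructed from how the test calls it (it must return a truthy value and verify QDQ placement); the actual implementation may differ.

# Hypothetical sketch only -- not the actual helper from this repo.
import onnx_graphsurgeon as gs

def _assert_nodes_are_quantized(nodes):
    for node in nodes:
        for inp in node.inputs:
            # Only variable inputs produced by another node can be preceded
            # by a DequantizeLinear; constants have no producer to inspect.
            if isinstance(inp, gs.Variable) and inp.inputs:
                assert inp.inputs[0].op == "DequantizeLinear", (
                    f"Input {inp.name} of node {node.name} is not quantized!"
                )
    return True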

tests/unit/onnx/test_quantize_int8.py

Lines changed: 0 additions & 30 deletions
@@ -21,7 +21,6 @@
 import torch
 from _test_utils.onnx_quantization.lib_test_models import (
     SimpleMLP,
-    build_conv_batchnorm_sig_mul_model,
     build_convtranspose_conv_residual_model,
     export_as_onnx,
 )

@@ -90,32 +89,3 @@ def test_convtranspose_conv_residual_int8(tmp_path):
         assert len(quantized_inputs) == 1, (
             f"More than one input of {node.name} is being quantized, but only one should be quantized!"
         )
-
-
-def test_conv_batchnorm_sig_mul_int8(tmp_path="./"):
-    onnx_model = build_conv_batchnorm_sig_mul_model()
-    onnx_path = os.path.join(tmp_path, "conv_batchnorm_sig_mul_model.onnx")
-    save_onnx(onnx_model, onnx_path)
-
-    moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
-
-    # Output model should be produced in the same tmp_path
-    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
-
-    # Check that quantized explicit model is generated
-    assert os.path.isfile(output_onnx_path)
-
-    # Load the output model and check QDQ node placements
-    graph = gs.import_onnx(onnx.load(output_onnx_path))
-
-    # Check that Conv and ConvTransposed are quantized
-    conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
-    assert _assert_nodes_are_quantized(conv_nodes)
-
-    # Check that only 1 input of Add is quantized
-    add_nodes = [n for n in graph.nodes if n.op == "Add"]
-    for node in add_nodes:
-        quantized_inputs = [inp for inp in node.inputs if inp.inputs[0].op == "DequantizeLinear"]
-        assert len(quantized_inputs) == 1, (
-            f"More than one input of {node.name} is being quantized, but only one should be quantized!"
-        )
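To run just the relocated test locally, a standard pytest invocation along these lines should work (assuming the repository's usual unit-test setup):

pytest tests/unit/onnx/test_qdq_rules_int8.py -k test_conv_batchnorm_sig_mul_int8

Unlike the removed copy, which defaulted to tmp_path="./" and wrote its artifacts into the current directory, the moved test takes tmp_path as a regular pytest fixture, so the ONNX files land in a per-test temporary directory.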
