
Commit 45aad6d

Moved unittest to qdq_rules script

Moves the parametrized test_conv_act_pool_int8 test from tests/unit/onnx/test_quantize_int8.py into tests/unit/onnx/test_qdq_rules_int8.py, switching it from the moq.quantize entry point to quantize.quantize and to that file's _assert_nodes_are_quantized / _assert_nodes_are_not_quantized assertion helpers.

Signed-off-by: gcunhase <[email protected]>
1 parent: 4ed1713

2 files changed: +28, −27 lines

tests/unit/onnx/test_qdq_rules_int8.py (28 additions, 0 deletions)

@@ -18,7 +18,9 @@
 import numpy as np
 import onnx
 import onnx_graphsurgeon as gs
+import pytest
 from _test_utils.onnx_quantization.lib_test_models import (
+    build_conv_act_pool_model,
     build_conv_batchnorm_sig_mul_model,
     build_r1a_model,
     build_resnet_block,
@@ -150,3 +152,29 @@ def test_conv_batchnorm_sig_mul_int8(tmp_path):
     assert len(quantized_inputs) == 1, (
         f"More than one input of {node.name} is being quantized, but only one should be quantized!"
     )
+
+
+@pytest.mark.parametrize("include_reshape_node", [False, True])
+def test_conv_act_pool_int8(tmp_path, include_reshape_node):
+    onnx_model = build_conv_act_pool_model(include_reshape_node)
+    onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx")
+    save_onnx(onnx_model, onnx_path)
+
+    quantize.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
+
+    # Output model should be produced in the same tmp_path
+    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
+
+    # Check that quantized explicit model is generated
+    assert os.path.isfile(output_onnx_path)
+
+    # Load the output model and check QDQ node placements
+    graph = gs.import_onnx(onnx.load(output_onnx_path))
+
+    # Check that Conv is quantized
+    conv_nodes = [n for n in graph.nodes if n.op == "Conv"]
+    assert _assert_nodes_are_quantized(conv_nodes)
+
+    # Check that MaxPool is not quantized
+    pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"]
+    assert _assert_nodes_are_not_quantized(pool_nodes)
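
The build_conv_act_pool_model helper imported above lives in _test_utils/onnx_quantization/lib_test_models.py and its body is not part of this diff. For context, here is a minimal sketch of what such a builder could look like with onnx_graphsurgeon, assuming a Conv → Relu → (optional Reshape) → MaxPool graph; the shapes, op choices, and attributes are illustrative assumptions, not the repository's implementation.

import numpy as np
import onnx
import onnx_graphsurgeon as gs


def build_conv_act_pool_model(include_reshape_node: bool = False) -> onnx.ModelProto:
    # Network input: NCHW float32 tensor (shape is an assumption).
    inp = gs.Variable("input", dtype=np.float32, shape=(1, 3, 32, 32))
    weight = gs.Constant("conv_weight", values=np.random.randn(8, 3, 3, 3).astype(np.float32))

    conv_out = gs.Variable("conv_out", dtype=np.float32)
    act_out = gs.Variable("act_out", dtype=np.float32)
    nodes = [
        gs.Node("Conv", inputs=[inp, weight], outputs=[conv_out],
                attrs={"kernel_shape": [3, 3], "pads": [1, 1, 1, 1]}),
        gs.Node("Relu", inputs=[conv_out], outputs=[act_out]),
    ]
    pool_in = act_out

    if include_reshape_node:
        # Optional shape-preserving Reshape between the activation and the pool.
        shape = gs.Constant("shape", values=np.array([1, 8, 32, 32], dtype=np.int64))
        reshape_out = gs.Variable("reshape_out", dtype=np.float32)
        nodes.append(gs.Node("Reshape", inputs=[act_out, shape], outputs=[reshape_out]))
        pool_in = reshape_out

    out = gs.Variable("output", dtype=np.float32)
    nodes.append(gs.Node("MaxPool", inputs=[pool_in], outputs=[out],
                         attrs={"kernel_shape": [2, 2], "strides": [2, 2]}))

    graph = gs.Graph(nodes=nodes, inputs=[inp], outputs=[out], opset=13)
    return gs.export_onnx(graph)

With a builder like this, the moved test can be run on its own via pytest's keyword filter: pytest tests/unit/onnx/test_qdq_rules_int8.py -k conv_act_pool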

tests/unit/onnx/test_quantize_int8.py (0 additions, 27 deletions)

@@ -21,7 +21,6 @@
 import torch
 from _test_utils.onnx_quantization.lib_test_models import (
     SimpleMLP,
-    build_conv_act_pool_model,
     build_convtranspose_conv_residual_model,
     export_as_onnx,
 )
@@ -95,29 +94,3 @@ def test_convtranspose_conv_residual_int8(tmp_path):
     assert len(quantized_inputs) == 1, (
         f"More than one input of {node.name} is being quantized, but only one should be quantized!"
     )
-
-
-@pytest.mark.parametrize("include_reshape_node", [False, True])
-def test_conv_act_pool_int8(tmp_path, include_reshape_node):
-    onnx_model = build_conv_act_pool_model(include_reshape_node)
-    onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx")
-    save_onnx(onnx_model, onnx_path)
-
-    moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
-
-    # Output model should be produced in the same tmp_path
-    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
-
-    # Check that quantized explicit model is generated
-    assert os.path.isfile(output_onnx_path)
-
-    # Load the output model and check QDQ node placements
-    graph = gs.import_onnx(onnx.load(output_onnx_path))
-
-    # Check that Conv is quantized
-    conv_nodes = [n for n in graph.nodes if n.op == "Conv"]
-    assert _assert_nodes_quantization(conv_nodes)
-
-    # Check that MaxPool is not quantized
-    pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"]
-    assert _assert_nodes_quantization(pool_nodes, should_be_quantized=False)
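
Note that the removed copy called a single helper, _assert_nodes_quantization(nodes, should_be_quantized=...), while the moved copy uses the split helpers _assert_nodes_are_quantized / _assert_nodes_are_not_quantized already defined in test_qdq_rules_int8.py. Neither helper's body appears in this diff; the sketch below is a plausible reconstruction, assuming they check whether each variable input of a node is produced by a DequantizeLinear node in the explicit-quantization (QDQ) graph.

import onnx_graphsurgeon as gs


def _assert_nodes_are_quantized(nodes):
    # Every variable input should come out of a DequantizeLinear node,
    # i.e. sit behind a Q/DQ pair in the quantized graph.
    for node in nodes:
        for inp in node.inputs:
            if isinstance(inp, gs.Variable) and inp.inputs:
                assert inp.inputs[0].op == "DequantizeLinear", (
                    f"Input {inp.name} of {node.name} is not quantized!"
                )
    return True


def _assert_nodes_are_not_quantized(nodes):
    # No variable input may be produced directly by a DequantizeLinear node.
    for node in nodes:
        for inp in node.inputs:
            if isinstance(inp, gs.Variable) and inp.inputs:
                assert inp.inputs[0].op != "DequantizeLinear", (
                    f"Input {inp.name} of {node.name} is quantized, but should not be!"
                )
    return True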
