
Commit bf5c8fb

Moved unittest to qdq_rules script
Signed-off-by: gcunhase <[email protected]>
1 parent: 3f8c84c

2 files changed: 27 additions & 27 deletions


tests/unit/onnx/test_qdq_rules_int8.py

Lines changed: 27 additions & 0 deletions
@@ -19,6 +19,7 @@
 import onnx
 import onnx_graphsurgeon as gs
 from _test_utils.onnx.quantization.lib_test_models import (
+    build_conv_act_pool_model,
     build_conv_batchnorm_sig_mul_model,
     build_r1a_model,
     build_resnet_block,
@@ -150,3 +151,29 @@ def test_conv_batchnorm_sig_mul_int8(tmp_path):
         assert len(quantized_inputs) == 1, (
             f"More than one input of {node.name} is being quantized, but only one should be quantized!"
         )
+
+
+@pytest.mark.parametrize("include_reshape_node", [False, True])
+def test_conv_act_pool_int8(tmp_path, include_reshape_node):
+    onnx_model = build_conv_act_pool_model(include_reshape_node)
+    onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx")
+    save_onnx(onnx_model, onnx_path)
+
+    quantize.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
+
+    # Output model should be produced in the same tmp_path
+    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
+
+    # Check that quantized explicit model is generated
+    assert os.path.isfile(output_onnx_path)
+
+    # Load the output model and check QDQ node placements
+    graph = gs.import_onnx(onnx.load(output_onnx_path))
+
+    # Check that Conv is quantized
+    conv_nodes = [n for n in graph.nodes if n.op == "Conv"]
+    assert _assert_nodes_are_quantized(conv_nodes)
+
+    # Check that MaxPool is not quantized
+    pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"]
+    assert _assert_nodes_are_not_quantized(pool_nodes)
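
For context, here is a minimal sketch of what the two assertion helpers referenced above might look like. This is a hypothetical reconstruction, not part of this diff: the real helpers are defined elsewhere in the test module, and the exact logic is assumed. The idea is that a node counts as quantized when its variable inputs are produced by DequantizeLinear nodes:

    import onnx_graphsurgeon as gs

    def _assert_nodes_are_quantized(nodes):
        # Assumed behavior: every variable input of each node must be fed by a
        # DequantizeLinear node (i.e., a QDQ pair precedes the op).
        for node in nodes:
            for inp in node.inputs:
                if isinstance(inp, gs.Variable) and inp.inputs:
                    assert inp.inputs[0].op == "DequantizeLinear", (
                        f"Input '{inp.name}' of node '{node.name}' is not quantized!"
                    )
        return True

    def _assert_nodes_are_not_quantized(nodes):
        # Assumed behavior: the inverse check, for ops like MaxPool that
        # should stay in high precision.
        for node in nodes:
            for inp in node.inputs:
                if isinstance(inp, gs.Variable) and inp.inputs:
                    assert inp.inputs[0].op != "DequantizeLinear", (
                        f"Input '{inp.name}' of node '{node.name}' should not be quantized!"
                    )
        return True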

tests/unit/onnx/test_quantize_int8.py

Lines changed: 0 additions & 27 deletions
@@ -21,7 +21,6 @@
 import torch
 from _test_utils.onnx.quantization.lib_test_models import (
     SimpleMLP,
-    build_conv_act_pool_model,
     build_convtranspose_conv_residual_model,
     export_as_onnx,
 )
@@ -95,29 +94,3 @@ def test_convtranspose_conv_residual_int8(tmp_path):
         assert len(quantized_inputs) == 1, (
             f"More than one input of {node.name} is being quantized, but only one should be quantized!"
         )
-
-
-@pytest.mark.parametrize("include_reshape_node", [False, True])
-def test_conv_act_pool_int8(tmp_path, include_reshape_node):
-    onnx_model = build_conv_act_pool_model(include_reshape_node)
-    onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx")
-    save_onnx(onnx_model, onnx_path)
-
-    moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
-
-    # Output model should be produced in the same tmp_path
-    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
-
-    # Check that quantized explicit model is generated
-    assert os.path.isfile(output_onnx_path)
-
-    # Load the output model and check QDQ node placements
-    graph = gs.import_onnx(onnx.load(output_onnx_path))
-
-    # Check that Conv is quantized
-    conv_nodes = [n for n in graph.nodes if n.op == "Conv"]
-    assert _assert_nodes_quantization(conv_nodes)
-
-    # Check that MaxPool is not quantized
-    pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"]
-    assert _assert_nodes_quantization(pool_nodes, should_be_quantized=False)
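
After the move, the test runs from its new location with standard pytest selection, e.g. pytest tests/unit/onnx/test_qdq_rules_int8.py -k test_conv_act_pool_int8.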
