Commit 5ebd5d3

Moved unittest to qdq_rules script

Signed-off-by: gcunhase <[email protected]>

1 parent: 77904b3

File tree

2 files changed: +29 -27 lines

tests/unit/onnx/test_qdq_rules_int8.py

29 additions & 0 deletions
@@ -18,14 +18,17 @@
 import numpy as np
 import onnx
 import onnx_graphsurgeon as gs
+import pytest
 from _test_utils.onnx_quantization.lib_test_models import (
+    build_conv_act_pool_model,
     build_r1a_model,
     build_resnet_block,
     build_resnet_block_with_downsample,
     export_as_onnx,
 )

 from modelopt.onnx.quantization.quantize import quantize
+from modelopt.onnx.utils import save_onnx


 def _assert_nodes_are_quantized(nodes):
@@ -119,3 +122,29 @@ def test_resnet_residual_connection_with_downsample(tmp_path):
     onnx_path = os.path.join(tmp_path, "model.onnx")
     export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path)
     _check_resnet_residual_connection(onnx_path)
+
+
+@pytest.mark.parametrize("include_reshape_node", [False, True])
+def test_conv_act_pool_int8(tmp_path, include_reshape_node):
+    onnx_model = build_conv_act_pool_model(include_reshape_node)
+    onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx")
+    save_onnx(onnx_model, onnx_path)
+
+    quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
+
+    # Output model should be produced in the same tmp_path
+    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
+
+    # Check that the quantized explicit model is generated
+    assert os.path.isfile(output_onnx_path)
+
+    # Load the output model and check QDQ node placements
+    graph = gs.import_onnx(onnx.load(output_onnx_path))
+
+    # Check that Conv is quantized
+    conv_nodes = [n for n in graph.nodes if n.op == "Conv"]
+    assert _assert_nodes_are_quantized(conv_nodes)
+
+    # Check that MaxPool is not quantized
+    pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"]
+    assert _assert_nodes_are_not_quantized(pool_nodes)
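
The moved test relies on two helpers from test_qdq_rules_int8.py: _assert_nodes_are_quantized (its def line is visible as context in the first hunk) and _assert_nodes_are_not_quantized. Their bodies sit outside this diff; below is a minimal sketch of the expected logic, assuming the usual onnx-graphsurgeon convention that a node counts as quantized when its variable inputs are produced by DequantizeLinear nodes. The exact assertion messages and edge-case handling in the real helpers may differ.

import onnx_graphsurgeon as gs


def _assert_nodes_are_quantized(nodes):
    # A node counts as quantized when each variable input with a producer
    # is fed by a DequantizeLinear node.
    for node in nodes:
        for inp in node.inputs:
            if isinstance(inp, gs.Variable) and inp.inputs:
                assert inp.inputs[0].op == "DequantizeLinear", (
                    f"Input '{inp.name}' of node '{node.name}' is not quantized!"
                )
    return True


def _assert_nodes_are_not_quantized(nodes):
    # The inverse check: no variable input may be produced by DequantizeLinear.
    for node in nodes:
        for inp in node.inputs:
            if isinstance(inp, gs.Variable):
                producer_ops = [producer.op for producer in inp.inputs]
                assert "DequantizeLinear" not in producer_ops, (
                    f"Input '{inp.name}' of node '{node.name}' is quantized but should not be!"
                )
    return True

Both sketches return True so that they compose with the test's assert _assert_...(nodes) pattern.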

tests/unit/onnx/test_quantize_int8.py

0 additions & 27 deletions
@@ -21,7 +21,6 @@
 import torch
 from _test_utils.onnx_quantization.lib_test_models import (
     SimpleMLP,
-    build_conv_act_pool_model,
     build_convtranspose_conv_residual_model,
     export_as_onnx,
 )
@@ -95,29 +94,3 @@ def test_convtranspose_conv_residual_int8(tmp_path):
     assert len(quantized_inputs) == 1, (
         f"More than one input of {node.name} is being quantized, but only one should be quantized!"
     )
-
-
-@pytest.mark.parametrize("include_reshape_node", [False, True])
-def test_conv_act_pool_int8(tmp_path, include_reshape_node):
-    onnx_model = build_conv_act_pool_model(include_reshape_node)
-    onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx")
-    save_onnx(onnx_model, onnx_path)
-
-    moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
-
-    # Output model should be produced in the same tmp_path
-    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
-
-    # Check that quantized explicit model is generated
-    assert os.path.isfile(output_onnx_path)
-
-    # Load the output model and check QDQ node placements
-    graph = gs.import_onnx(onnx.load(output_onnx_path))
-
-    # Check that Conv is quantized
-    conv_nodes = [n for n in graph.nodes if n.op == "Conv"]
-    assert _assert_nodes_quantization(conv_nodes)
-
-    # Check that MaxPool is not quantized
-    pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"]
-    assert _assert_nodes_quantization(pool_nodes, should_be_quantized=False)
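
For reference, build_conv_act_pool_model comes from _test_utils.onnx_quantization.lib_test_models and is not shown in this commit. The sketch below is a hypothetical stand-in illustrating the graph shape the test exercises (all tensor shapes, the Relu activation, and the attribute values are assumptions): a Conv feeding an activation, optionally followed by a Reshape, then a MaxPool that the int8 QDQ rules should leave unquantized.

import numpy as np
import onnx
from onnx import TensorProto, helper


def build_conv_act_pool_model(include_reshape_node: bool) -> onnx.ModelProto:
    """Hypothetical Conv -> Relu -> [Reshape ->] MaxPool graph for the QDQ placement test."""
    inp = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 16, 16])
    out = helper.make_tensor_value_info("output", TensorProto.FLOAT, None)

    # Conv weight as an initializer so the quantizer has a weight tensor to quantize.
    weight = helper.make_tensor(
        "conv_w", TensorProto.FLOAT, [8, 3, 3, 3],
        np.random.randn(8, 3, 3, 3).astype(np.float32).tobytes(), raw=True,
    )
    initializers = [weight]

    nodes = [
        helper.make_node("Conv", ["input", "conv_w"], ["conv_out"], pads=[1, 1, 1, 1]),
        helper.make_node("Relu", ["conv_out"], ["act_out"]),
    ]
    pool_input = "act_out"
    if include_reshape_node:
        # Identity-shaped Reshape between the activation and the pool.
        shape = helper.make_tensor("new_shape", TensorProto.INT64, [4], [1, 8, 16, 16])
        initializers.append(shape)
        nodes.append(helper.make_node("Reshape", ["act_out", "new_shape"], ["reshape_out"]))
        pool_input = "reshape_out"
    nodes.append(helper.make_node("MaxPool", [pool_input], ["output"], kernel_shape=[2, 2]))

    graph = helper.make_graph(nodes, "conv_act_pool", [inp], [out], initializer=initializers)
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

Saving such a model with save_onnx and running quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16") reproduces the scenario the parametrized test checks.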
