Skip to content

Commit d4b77a9

Browse files
committed
Renamed assert_nodes function
Signed-off-by: gcunhase <[email protected]>
1 parent c5a2f7e commit d4b77a9

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

tests/unit/onnx/test_qdq_rules_int8.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
import onnx
2020
import onnx_graphsurgeon as gs
21+
import pytest
2122
from _test_utils.onnx.quantization.lib_test_models import (
2223
build_conv_act_pool_model,
2324
build_conv_batchnorm_sig_mul_model,
@@ -42,7 +43,7 @@ def assert_nodes_are_quantized(nodes):
4243
return True
4344

4445

45-
def _assert_nodes_are_not_quantized(nodes):
46+
def assert_nodes_are_not_quantized(nodes):
4647
for node in nodes:
4748
for inp_idx, inp in enumerate(node.inputs):
4849
if isinstance(inp, gs.Variable) and inp.inputs:
@@ -78,7 +79,7 @@ def test_bias_add_rule(tmp_path):
7879
other_nodes = [
7980
n for n in graph.nodes if n.op not in ["Conv", "QuantizeLinear", "DequantizeLinear"]
8081
]
81-
assert _assert_nodes_are_not_quantized(other_nodes)
82+
assert assert_nodes_are_not_quantized(other_nodes)
8283

8384

8485
def _check_resnet_residual_connection(onnx_path):
@@ -108,7 +109,7 @@ def _check_resnet_residual_connection(onnx_path):
108109
other_nodes = [
109110
n for n in graph.nodes if n.op not in ["Conv", "Add", "QuantizeLinear", "DequantizeLinear"]
110111
]
111-
assert _assert_nodes_are_not_quantized(other_nodes)
112+
assert assert_nodes_are_not_quantized(other_nodes)
112113

113114

114115
def test_resnet_residual_connections(tmp_path):
@@ -143,7 +144,7 @@ def test_convtranspose_conv_residual_int8(tmp_path):
143144

144145
# Check that Conv and ConvTransposed are quantized
145146
conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
146-
assert _assert_nodes_are_quantized(conv_nodes)
147+
assert assert_nodes_are_quantized(conv_nodes)
147148

148149
# Check that only 1 input of Add is quantized
149150
add_nodes = [n for n in graph.nodes if n.op == "Add"]
@@ -202,8 +203,8 @@ def test_conv_act_pool_int8(tmp_path, include_reshape_node):
202203

203204
# Check that Conv is quantized
204205
conv_nodes = [n for n in graph.nodes if n.op == "Conv"]
205-
assert _assert_nodes_are_quantized(conv_nodes)
206+
assert assert_nodes_are_quantized(conv_nodes)
206207

207208
# Check that MaxPool is not quantized
208209
pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"]
209-
assert _assert_nodes_are_not_quantized(pool_nodes)
210+
assert assert_nodes_are_not_quantized(pool_nodes)

0 commit comments

Comments
 (0)