1818import numpy as np
1919import onnx
2020import onnx_graphsurgeon as gs
21+ import pytest
2122from _test_utils .onnx .quantization .lib_test_models import (
2223 build_conv_act_pool_model ,
2324 build_conv_batchnorm_sig_mul_model ,
@@ -42,7 +43,7 @@ def assert_nodes_are_quantized(nodes):
4243 return True
4344
4445
45- def _assert_nodes_are_not_quantized (nodes ):
46+ def assert_nodes_are_not_quantized (nodes ):
4647 for node in nodes :
4748 for inp_idx , inp in enumerate (node .inputs ):
4849 if isinstance (inp , gs .Variable ) and inp .inputs :
@@ -78,7 +79,7 @@ def test_bias_add_rule(tmp_path):
7879 other_nodes = [
7980 n for n in graph .nodes if n .op not in ["Conv" , "QuantizeLinear" , "DequantizeLinear" ]
8081 ]
81- assert _assert_nodes_are_not_quantized (other_nodes )
82+ assert assert_nodes_are_not_quantized (other_nodes )
8283
8384
8485def _check_resnet_residual_connection (onnx_path ):
@@ -108,7 +109,7 @@ def _check_resnet_residual_connection(onnx_path):
108109 other_nodes = [
109110 n for n in graph .nodes if n .op not in ["Conv" , "Add" , "QuantizeLinear" , "DequantizeLinear" ]
110111 ]
111- assert _assert_nodes_are_not_quantized (other_nodes )
112+ assert assert_nodes_are_not_quantized (other_nodes )
112113
113114
114115def test_resnet_residual_connections (tmp_path ):
@@ -143,7 +144,7 @@ def test_convtranspose_conv_residual_int8(tmp_path):
143144
144145 # Check that Conv and ConvTransposed are quantized
145146 conv_nodes = [n for n in graph .nodes if "Conv" in n .op ]
146- assert _assert_nodes_are_quantized (conv_nodes )
147+ assert assert_nodes_are_quantized (conv_nodes )
147148
148149 # Check that only 1 input of Add is quantized
149150 add_nodes = [n for n in graph .nodes if n .op == "Add" ]
@@ -202,8 +203,8 @@ def test_conv_act_pool_int8(tmp_path, include_reshape_node):
202203
203204 # Check that Conv is quantized
204205 conv_nodes = [n for n in graph .nodes if n .op == "Conv" ]
205- assert _assert_nodes_are_quantized (conv_nodes )
206+ assert assert_nodes_are_quantized (conv_nodes )
206207
207208 # Check that MaxPool is not quantized
208209 pool_nodes = [n for n in graph .nodes if n .op == "MaxPool" ]
209- assert _assert_nodes_are_not_quantized (pool_nodes )
210+ assert assert_nodes_are_not_quantized (pool_nodes )
0 commit comments