2626# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2727# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2828
29+ import pytest
30+
2931import numpy as np
3032import onnx .parser as oprs
3133
3436from qonnx .transformation .extract_quant_scale_zeropt import ExtractQuantScaleZeroPt
3537
3638
37- def make_test_model ():
38- ishp = (1 , 10 )
39+ def make_test_model (ishp , channelwise , bitwidth , need_extraction_scale , need_extraction_zeropt ):
3940 ishp_str = str (list (ishp ))
40- channelwise = True
41- bitwidth = np .asarray (4.0 , dtype = np .float32 )
4241 if channelwise :
4342 q_attr_shp = ishp
4443 else :
45- q_attr_shp = 1
44+ q_attr_shp = ( 1 ,)
4645 attrshp_str = str (list (q_attr_shp ))
4746 np .random .seed (0 )
48- scale = np .random .rand (* q_attr_shp ).astype (np .float32 )
49- zeropt = np .random .rand (* q_attr_shp ).astype (np .float32 )
47+ if need_extraction_scale :
48+ scale = np .random .rand (* q_attr_shp ).astype (np .float32 )
49+ else :
50+ scale = np .ones (q_attr_shp , dtype = np .float32 )
51+ if need_extraction_zeropt :
52+ zeropt = np .random .rand (* q_attr_shp ).astype (np .float32 )
53+ else :
54+ zeropt = np .zeros (q_attr_shp , dtype = np .float32 )
5055 signed = 1
5156 narrow = 1
5257 rounding_mode = "ROUND"
@@ -78,8 +83,13 @@ def make_test_model():
7883 return model
7984
8085
81- def test_extract_quant_scale_zeropt ():
82- model = make_test_model ()
86+ @pytest .mark .parametrize ("need_extraction_scale" , [True , False ])
87+ @pytest .mark .parametrize ("need_extraction_zeropt" , [True , False ])
88+ @pytest .mark .parametrize ("channelwise" , [True , False ])
89+ def test_extract_quant_scale_zeropt (channelwise , need_extraction_scale , need_extraction_zeropt ):
90+ ishp = (1 , 10 )
91+ bitwidth = np .asarray (4.0 , dtype = np .float32 )
92+ model = make_test_model (ishp , channelwise , bitwidth , need_extraction_scale , need_extraction_zeropt )
8393 ishp = model .get_tensor_shape ("in0" )
8494 inp = np .random .rand (* ishp ).astype (np .float32 )
8595 y_golden = execute_onnx (model , {"in0" : inp })["out0" ]
@@ -88,6 +98,12 @@ def test_extract_quant_scale_zeropt():
8898 assert np .allclose (y_golden , y_ret )
8999 qnt_node = model_new .get_nodes_by_op_type ("Quant" )[0 ]
90100 new_scale = model_new .get_initializer (qnt_node .input [1 ])
91- assert new_scale == 1
101+ assert ( new_scale == 1 ). all ()
92102 new_zeropt = model_new .get_initializer (qnt_node .input [2 ])
93- assert new_zeropt == 0
103+ assert (new_zeropt == 0 ).all ()
104+ if need_extraction_scale :
105+ assert len (model_new .get_nodes_by_op_type ("Mul" )) == 1
106+ assert len (model_new .get_nodes_by_op_type ("Div" )) == 1
107+ if need_extraction_zeropt :
108+ assert len (model_new .get_nodes_by_op_type ("Add" )) == 1
109+ assert len (model_new .get_nodes_by_op_type ("Sub" )) == 1
0 commit comments