13
13
import sys
14
14
15
15
import numpy as np
16
- from onnx import onnx_pb
16
+ from onnx import onnx_pb , numpy_helper , TensorProto
17
17
from onnx .onnx_pb import TensorProto
18
18
19
19
from tf2onnx import constants , utils
@@ -37,12 +37,19 @@ def version_11(cls, ctx, node, **kwargs):
37
37
amax = node .get_attr ("max" ).f
38
38
narrow_range = node .get_attr ("narrow_range" ).i
39
39
num_bits = node .get_attr ("num_bits" ).i
40
-
40
+
41
41
if narrow_range :
42
42
raise RuntimeError (
43
43
"Unable to convert node FakeQuantWithMinMaxArgs with "
44
44
"narrow_range=%r" % narrow_range )
45
-
45
+ if num_bits != 8 :
46
+ raise RuntimeError (
47
+ "Unable to convert node FakeQuantWithMinMaxArgs with "
48
+ "num_bits=%r" % num_bits )
49
+
50
+ scale = (amax - amin ) / (2 ** num_bits - 1 )
51
+ min_adj = scale * int (amin / scale )
52
+ max_adj = amax + min_adj - amin
46
53
if 0 < amin < amax :
47
54
min_adj = 0
48
55
max_adj = amax - amin
@@ -62,18 +69,27 @@ def version_11(cls, ctx, node, **kwargs):
62
69
63
70
dtype = ctx .get_dtype (node .input [0 ])
64
71
shape = ctx .get_shape (node .input [0 ])
72
+ axis = 1
73
+ idtype = TensorProto .UINT8
74
+
75
+ pb_scale = ctx .make_const (
76
+ utils .make_name ("{}_scaley" .format (node .name )),
77
+ np .array (scale , dtype = np .float32 ))
78
+ zero_point = ctx .make_const (
79
+ utils .make_name ("{}_zpy" .format (node .name )),
80
+ np .array (min_adj , dtype = np .uint8 ))
65
81
66
82
new_node = ctx .make_node (
67
- "QuantizeLinear" , [node .input [0 ], pb_scale , y_zero_point ],
68
- op_name_scope = node .name , attr = {"axes " : [ axis ] },
83
+ "QuantizeLinear" , [node .input [0 ], pb_scale . name , zero_point . name ],
84
+ op_name_scope = node .name , attr = {"axis " : axis },
69
85
shapes = [shape ], dtypes = [idtype ])
70
86
output_name = new_node .output [0 ]
71
- node .input [i ] = output_name
87
+ node .input [0 ] = output_name
72
88
73
89
ctx .remove_node (node .name )
74
90
75
91
last_node = ctx .make_node (
76
- "DequantizeLinear" , [new_node .output [0 ], x_scale , x_zero_point ],
92
+ "DequantizeLinear" , [new_node .output [0 ], pb_scale . name , zero_point . name ],
77
93
op_name_scope = node .name , attr = {"axis" : axis },
78
94
shapes = [shape ], dtypes = [dtype ])
79
95
ctx .replace_all_inputs (ctx .get_nodes (), node .output [0 ], last_node .output [0 ])
0 commit comments