5
5
tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3 ops into QuantizeLinear/DequantizeLinear pairs
6
6
"""
7
7
8
- from tf2onnx .graph_matcher import OpTypePattern , GraphMatcher
9
-
10
8
import numpy as np
11
- import struct
12
-
13
-
14
- from tf2onnx .handler import tf_op
9
+ from tf2onnx .graph_matcher import OpTypePattern , GraphMatcher
15
10
from tf2onnx import utils
16
- from tf2onnx import constants
17
11
18
12
# pylint: disable=missing-docstring
19
13
20
14
def extract_numpy_array(node):
    """Return the raw tensor bytes of ``node``'s "value" attribute as float32.

    Reads ``node.attr["value"].t.raw_data`` and reinterprets that buffer as a
    one-dimensional float32 array.

    Args:
        node: object exposing ``attr["value"].t.raw_data`` as a bytes-like
            buffer whose length is a multiple of four bytes.

    Returns:
        numpy.ndarray of dtype float32 viewing the raw buffer.
    """
    return np.frombuffer(node.attr["value"].t.raw_data, dtype="float32")
22
16
23
17
def create_qdq_nodes(g, match_results):
    """Replace each matched QDQ op with a QuantizeLinear/DequantizeLinear pair.

    For every match, the op registered under the name 'output' is read for its
    ``narrow_range`` and ``signed_input`` attributes and its min/max inputs
    (inputs 1 and 2). A scale is derived from them, the original node is
    removed from the graph, and a QuantizeLinear node followed by a
    DequantizeLinear node is inserted. The DequantizeLinear node reuses the
    original node's output name.

    Args:
        g: graph being rewritten; must provide get_dtype, get_shape,
            set_shape, make_const, make_node, remove_node and get_nodes.
        match_results: iterable of match objects exposing ``get_op('output')``.

    Returns:
        The result of ``g.get_nodes()`` after all matches were rewritten.

    Raises:
        ValueError: if the scale computed for a matched node is not positive.
    """
    for match in match_results:
        qdq_node = match.get_op('output')
        qdq_node_output_dtype = g.get_dtype(qdq_node.output[0])
        # NOTE(review): the statement defining the output shape lies in a part
        # of the file not visible during this review; g.get_shape is assumed
        # as the counterpart of the g.set_shape calls below - confirm.
        qdq_node_output_shape = g.get_shape(qdq_node.output[0])

        # Get the attributes of qdq node
        narrow_range = qdq_node.attr['narrow_range'].i
        signed_input = qdq_node.attr['signed_input'].i

        # Default to the symmetric signed range; widen it by one step on the
        # negative side when narrow_range is off, or switch to the unsigned
        # 8-bit range when the input is not signed.
        min_quantized, max_quantized = [-127, 127]
        if not narrow_range and signed_input:
            min_quantized = -128

        if not signed_input:
            min_quantized, max_quantized = [0, 255]

        # Get the min and max value of the inputs to QDQ op
        min_value = extract_numpy_array(qdq_node.inputs[1])
        max_value = extract_numpy_array(qdq_node.inputs[2])

        # Calculate scales from the min and max values. A side whose product
        # with its quantized bound is not positive falls back to
        # max_quantized, which also avoids dividing by a zero bound.
        scale_from_min_side = min_quantized / min_value if min_quantized * min_value > 0 else max_quantized
        scale_from_max_side = max_quantized / max_value if max_quantized * max_value > 0 else max_quantized

        # Keep the smaller of the two candidate scales.
        if scale_from_min_side < scale_from_max_side:
            scale = scale_from_min_side
        else:
            scale = scale_from_max_side

        # Explicit check instead of `assert`, which is stripped under -O.
        if not scale > 0:
            raise ValueError("non-positive quantization scale %s computed for node %s"
                             % (scale, qdq_node.name))

        if signed_input:
            zero_point = np.int8(0)
        else:
            zero_point = np.uint8(0)

        # Split it into QuantizeLinear and DequantizeLinear and remove the QDQ node reference
        y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=1 / scale)
        y_zero_point = g.make_const(name=utils.make_name("y_zero_point"), np_val=zero_point)
        # NOTE(review): the quantized intermediate is declared with the QDQ
        # node's own output dtype although zero_point is int8/uint8 - confirm
        # this is the intended dtype for the QuantizeLinear output.
        quant_node = g.make_node(op_type="QuantizeLinear",
                                 inputs=[qdq_node.input[0], y_quant_scale.output[0],
                                         y_zero_point.output[0]],
                                 shapes=[qdq_node_output_shape],
                                 dtypes=[qdq_node_output_dtype],
                                 name=utils.make_name("QuantLinearNode"))

        g.set_shape(quant_node.output[0], qdq_node_output_shape)

        # The original node must be gone before its output name is reused by
        # the DequantizeLinear node created below.
        g.remove_node(qdq_node.name)

        y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=1 / scale)
        y_inv_zero_point = g.make_const(name=utils.make_name("y_inv_zero_point"), np_val=zero_point)
        dequant_node = g.make_node(op_type="DequantizeLinear",
                                   inputs=[quant_node.output[0], y_dequant_scale.output[0],
                                           y_inv_zero_point.output[0]],
                                   outputs=[qdq_node.output[0]],
                                   shapes=[qdq_node_output_shape],
                                   dtypes=[qdq_node_output_dtype],
                                   name=utils.make_name("DequantLinearNode"))
        g.set_shape(dequant_node.output[0], qdq_node_output_shape)

    return g.get_nodes()
75
81
76
82
def rewrite_quantize_and_dequantize (g , ops ):
77
-
83
+
78
84
pattern_for_qdq_v2 = \
79
85
OpTypePattern ('QuantizeAndDequantizeV2' , name = 'output' , inputs = [
80
86
OpTypePattern ("*" ),
@@ -88,7 +94,7 @@ def rewrite_quantize_and_dequantize(g, ops):
88
94
OpTypePattern (None ),
89
95
OpTypePattern (None ),
90
96
])
91
-
97
+
92
98
# Match all the patterns for QDQ ops
93
99
patterns = [pattern_for_qdq_v3 , pattern_for_qdq_v2 ]
94
100
match_results = []
@@ -97,4 +103,4 @@ def rewrite_quantize_and_dequantize(g, ops):
97
103
results = list (matcher .match_ops (ops ))
98
104
match_results .extend (results )
99
105
100
- return create_qdq_nodes (g , match_results )
106
+ return create_qdq_nodes (g , match_results )
0 commit comments