1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT license.
3
+
4
+ """
5
+ tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV3 op
6
+ """
7
+
8
+ from tf2onnx .graph_matcher import OpTypePattern , GraphMatcher
9
+
10
+ import numpy as np
11
+ import struct
12
+
13
+
14
+ from tf2onnx .handler import tf_op
15
+ from tf2onnx import utils
16
+ from tf2onnx import constants
17
+
18
+ # pylint: disable=missing-docstring
19
+
20
+ def extract_numpy_array (node ):
21
+ return np .frombuffer (node .attr ["value" ].t .raw_data , dtype = "float32" )
22
+
23
+ def create_qdq_nodes (g , match_results ):
24
+
25
+ for match in match_results :
26
+ qdq_node = match .get_op ('output' )
27
+ qdq_node_output_dtype = g .get_dtype (qdq_node .output [0 ])
28
+ qdq_node_output_shape = g .get_shape (qdq_node .output [0 ])
29
+
30
+ # Get the attributes of qdq node
31
+ narrow_range = qdq_node .attr ['narrow_range' ].i
32
+ signed_input = qdq_node .attr ['signed_input' ].i
33
+
34
+ min_quantized , max_quantized = [- 127 ,127 ]
35
+ if not narrow_range and signed_input :
36
+ min_quantized = - 128
37
+
38
+ if not signed_input :
39
+ min_quantized , max_quantized = [0 ,255 ]
40
+
41
+ # Get the min and max value of the inputs to QDQ op
42
+ min_value = extract_numpy_array (qdq_node .inputs [1 ])
43
+ max_value = extract_numpy_array (qdq_node .inputs [2 ])
44
+
45
+ # Calculate scales from the min and max values
46
+ scale_from_min_side = min_quantized / min_value if min_quantized * min_value > 0 else max_quantized
47
+ scale_from_max_side = max_quantized / max_value if max_quantized * max_value > 0 else max_quantized
48
+
49
+ if scale_from_min_side < scale_from_max_side :
50
+ scale = scale_from_min_side
51
+ else :
52
+ scale = scale_from_max_side
53
+
54
+ assert scale > 0
55
+
56
+ if signed_input :
57
+ zero_point = np .int8 (0 )
58
+ else :
59
+ zero_point = np .uint8 (0 )
60
+
61
+ # Split it into QuantizeLinear and DequantizeLinear and remove the QDQ node reference
62
+ y_quant_scale = g .make_const (name = utils .make_name ("y_quant_scale" ), np_val = 1 / scale )
63
+ y_zero_point = g .make_const (name = utils .make_name ("y_zero_point" ), np_val = zero_point )
64
+ quant_node = g .make_node (op_type = "QuantizeLinear" , inputs = [qdq_node .input [0 ], y_quant_scale .output [0 ], y_zero_point .output [0 ]], shapes = [qdq_node_output_shape ], dtypes = [qdq_node_output_dtype ], name = utils .make_name ("QuantLinearNode" ))
65
+ g .set_shape (quant_node .output [0 ], qdq_node_output_shape )
66
+
67
+ g .remove_node (qdq_node .name )
68
+
69
+ y_dequant_scale = g .make_const (name = utils .make_name ("y_dequant_scale" ), np_val = 1 / scale )
70
+ y_inv_zero_point = g .make_const (name = utils .make_name ("y_inv_zero_point" ), np_val = zero_point )
71
+ dequant_node = g .make_node (op_type = "DequantizeLinear" , inputs = [quant_node .output [0 ], y_dequant_scale .output [0 ], y_inv_zero_point .output [0 ]], outputs = [qdq_node .output [0 ]], shapes = [qdq_node_output_shape ], dtypes = [qdq_node_output_dtype ], name = utils .make_name ("DequantLinearNode" ))
72
+ g .set_shape (dequant_node .output [0 ], qdq_node_output_shape )
73
+
74
+ return g .get_nodes ()
75
+
76
+ def rewrite_quantize_and_dequantize (g , ops ):
77
+
78
+ pattern_for_qdq_v2 = \
79
+ OpTypePattern ('QuantizeAndDequantizeV2' , name = 'output' , inputs = [
80
+ OpTypePattern ("*" ),
81
+ OpTypePattern (None ),
82
+ OpTypePattern (None ),
83
+ ])
84
+ pattern_for_qdq_v3 = \
85
+ OpTypePattern ('QuantizeAndDequantizeV3' , name = 'output' , inputs = [
86
+ OpTypePattern ("*" ),
87
+ OpTypePattern (None ),
88
+ OpTypePattern (None ),
89
+ OpTypePattern (None ),
90
+ ])
91
+
92
+ # Match all the patterns for QDQ ops
93
+ patterns = [pattern_for_qdq_v3 , pattern_for_qdq_v2 ]
94
+ match_results = []
95
+ for pattern in patterns :
96
+ matcher = GraphMatcher (pattern )
97
+ results = list (matcher .match_ops (ops ))
98
+ match_results .extend (results )
99
+
100
+ return create_qdq_nodes (g , match_results )
0 commit comments