|
2 | 2 | # Licensed under the MIT license.
|
3 | 3 |
|
4 | 4 | """
|
5 |
| -tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV3 op |
| 5 | +tf2onnx.rewriter - rewrite TensorFlow QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3 ops |
6 | 6 | """
|
7 | 7 |
|
8 | 8 | import numpy as np
|
@@ -32,47 +32,65 @@ def create_qdq_nodes(g, match_results):
|
32 | 32 | if not signed_input:
|
33 | 33 | min_quantized, max_quantized = [0, 255]
|
34 | 34 |
|
| 35 | + # Get axis attribute for per channel implementation. |
| 36 | + if 'axis' in qdq_node.attr: |
| 37 | + axis = qdq_node.attr['axis'].i |
| 38 | + |
35 | 39 | # Get the min and max value of the inputs to QDQ op
|
36 | 40 | min_value = extract_numpy_array(qdq_node.inputs[1])
|
37 | 41 | max_value = extract_numpy_array(qdq_node.inputs[2])
|
38 | 42 |
|
39 |
| - # Calculate scales from the min and max values |
40 |
| - scale_from_min_side = min_quantized/min_value if min_quantized*min_value > 0 else max_quantized |
41 |
| - scale_from_max_side = max_quantized/max_value if max_quantized*max_value > 0 else max_quantized |
42 |
| - |
43 |
| - if scale_from_min_side < scale_from_max_side: |
44 |
| - scale = scale_from_min_side |
45 |
| - else: |
46 |
| - scale = scale_from_max_side |
47 |
| - |
48 |
| - utils.make_sure(scale > 0, "Quantize/Dequantize scale must be greater than zero") |
49 |
| - |
50 |
| - if signed_input: |
51 |
| - zero_point = np.int8(0) |
| 43 | + num_channels = min_value.shape[0] |
| 44 | + scales = np.zeros(num_channels, dtype=np.float32) |
| 45 | + zero_point_dtype = np.int8 if signed_input else np.uint8 |
| 46 | + zero_point = np.zeros(num_channels, dtype=zero_point_dtype) |
| 47 | + |
| 48 | + for i in range(num_channels): |
| 49 | + # Calculate scales from the min and max values |
| 50 | + scale_from_min_side = min_quantized/min_value[i] if min_quantized*min_value[i] > 0 else max_quantized |
| 51 | + scale_from_max_side = max_quantized/max_value[i] if max_quantized*max_value[i] > 0 else max_quantized |
| 52 | + |
| 53 | + if scale_from_min_side < scale_from_max_side: |
| 54 | + scale = scale_from_min_side |
| 55 | + else: |
| 56 | + scale = scale_from_max_side |
| 57 | + |
| 58 | + utils.make_sure(scale > 0, "Quantize/Dequantize scale must be greater than zero") |
| 59 | + scales[i] = np.float32(scale) |
| 60 | + |
| 61 | + # Set scalars for scale and zero point for per layer quantization |
| 62 | + if num_channels == 1: |
| 63 | + scales = scales[0] |
| 64 | + zero_point = zero_point[0] |
| 65 | + attrs = {} |
52 | 66 | else:
|
53 |
| - zero_point = np.uint8(0) |
| 67 | + utils.make_sure(axis, "Axis must be specified for per channel quantization") |
| 68 | + attrs = {'axis': axis} |
54 | 69 |
|
55 | 70 | # Split it into QuantizeLinear and DequantizeLinear and remove the QDQ node reference
|
56 |
| - y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=1/scale) |
| 71 | + inverse_scale = (1/scales).astype(np.float32) |
| 72 | + y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=inverse_scale) |
57 | 73 | y_zero_point = g.make_const(name=utils.make_name("y_zero_point"), np_val=zero_point)
|
58 | 74 | quant_node = g.make_node(op_type="QuantizeLinear",
|
59 | 75 | inputs=[qdq_node.input[0], y_quant_scale.output[0],
|
60 | 76 | y_zero_point.output[0]],
|
61 | 77 | shapes=[qdq_node_output_shape],
|
| 78 | + attr=attrs, |
62 | 79 | dtypes=[qdq_node_output_dtype],
|
63 | 80 | name=utils.make_name("QuantLinearNode"))
|
64 | 81 |
|
65 | 82 | g.set_shape(quant_node.output[0], qdq_node_output_shape)
|
66 | 83 |
|
67 | 84 | g.remove_node(qdq_node.name)
|
68 | 85 |
|
69 |
| - y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=1/scale) |
| 86 | + y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=inverse_scale) |
70 | 87 | y_inv_zero_point = g.make_const(name=utils.make_name("y_inv_zero_point"), np_val=zero_point)
|
71 | 88 | dequant_node = g.make_node(op_type="DequantizeLinear",
|
72 | 89 | inputs=[quant_node.output[0], y_dequant_scale.output[0],
|
73 | 90 | y_inv_zero_point.output[0]],
|
74 | 91 | outputs=[qdq_node.output[0]],
|
75 | 92 | shapes=[qdq_node_output_shape],
|
| 93 | + attr=attrs, |
76 | 94 | dtypes=[qdq_node_output_dtype],
|
77 | 95 | name=utils.make_name("DequantLinearNode"))
|
78 | 96 | g.set_shape(dequant_node.output[0], qdq_node_output_shape)
|
|
0 commit comments