|
2 | 2 | # Licensed under the MIT license.
|
3 | 3 |
|
4 | 4 | """
|
5 |
| -tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV3 op |
| 5 | +tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV2|QuantizeAndDequantizeV3 op |
6 | 6 | """
|
7 | 7 |
|
8 | 8 | import numpy as np
|
@@ -32,47 +32,61 @@ def create_qdq_nodes(g, match_results):
|
32 | 32 | if not signed_input:
|
33 | 33 | min_quantized, max_quantized = [0, 255]
|
34 | 34 |
|
| 35 | + # Get axis attribute for per channel implementation. |
| 36 | + axis = qdq_node.attr['axis'].i |
| 37 | + |
35 | 38 | # Get the min and max value of the inputs to QDQ op
|
36 | 39 | min_value = extract_numpy_array(qdq_node.inputs[1])
|
37 | 40 | max_value = extract_numpy_array(qdq_node.inputs[2])
|
38 | 41 |
|
39 |
| - # Calculate scales from the min and max values |
40 |
| - scale_from_min_side = min_quantized/min_value if min_quantized*min_value > 0 else max_quantized |
41 |
| - scale_from_max_side = max_quantized/max_value if max_quantized*max_value > 0 else max_quantized |
| 42 | + num_channels = min_value.shape[0] |
| 43 | + scales = np.zeros(num_channels, dtype=np.float32) |
| 44 | + zero_point_dtype = np.int8 if signed_input else np.uint8 |
| 45 | + zero_point = np.zeros(num_channels, dtype=zero_point_dtype) |
| 46 | + |
| 47 | + for i in range(num_channels): |
| 48 | + # Calculate scales from the min and max values |
| 49 | + scale_from_min_side = min_quantized/min_value[i] if min_quantized*min_value[i] > 0 else max_quantized |
| 50 | + scale_from_max_side = max_quantized/max_value[i] if max_quantized*max_value[i] > 0 else max_quantized |
42 | 51 |
|
43 |
| - if scale_from_min_side < scale_from_max_side: |
44 |
| - scale = scale_from_min_side |
45 |
| - else: |
46 |
| - scale = scale_from_max_side |
| 52 | + if scale_from_min_side < scale_from_max_side: |
| 53 | + scale = scale_from_min_side |
| 54 | + else: |
| 55 | + scale = scale_from_max_side |
47 | 56 |
|
48 |
| - utils.make_sure(scale > 0, "Quantize/Dequantize scale must be greater than zero") |
| 57 | + utils.make_sure(scale > 0, "Quantize/Dequantize scale must be greater than zero") |
| 58 | + scales[i] = np.float32(scale) |
49 | 59 |
|
50 |
| - if signed_input: |
51 |
| - zero_point = np.int8(0) |
52 |
| - else: |
53 |
| - zero_point = np.uint8(0) |
| 60 | + # Set scalars for scale and zero point for per layer quantization |
| 61 | + if num_channels == 1: |
| 62 | + scales = scales[0] |
| 63 | + zero_point = zero_point[0] |
| 64 | + axis = np.int64(1) # Default value of axis |
54 | 65 |
|
55 | 66 | # Split it into QuantizeLinear and DequantizeLinear and remove the QDQ node reference
|
56 |
| - y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=1/scale) |
| 67 | + inverse_scale = (1/scales).astype(np.float32) |
| 68 | + y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=inverse_scale) |
57 | 69 | y_zero_point = g.make_const(name=utils.make_name("y_zero_point"), np_val=zero_point)
|
58 | 70 | quant_node = g.make_node(op_type="QuantizeLinear",
|
59 | 71 | inputs=[qdq_node.input[0], y_quant_scale.output[0],
|
60 | 72 | y_zero_point.output[0]],
|
61 | 73 | shapes=[qdq_node_output_shape],
|
| 74 | + attr={'axis': axis}, |
62 | 75 | dtypes=[qdq_node_output_dtype],
|
63 | 76 | name=utils.make_name("QuantLinearNode"))
|
64 | 77 |
|
65 | 78 | g.set_shape(quant_node.output[0], qdq_node_output_shape)
|
66 | 79 |
|
67 | 80 | g.remove_node(qdq_node.name)
|
68 | 81 |
|
69 |
| - y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=1/scale) |
| 82 | + y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=inverse_scale) |
70 | 83 | y_inv_zero_point = g.make_const(name=utils.make_name("y_inv_zero_point"), np_val=zero_point)
|
71 | 84 | dequant_node = g.make_node(op_type="DequantizeLinear",
|
72 | 85 | inputs=[quant_node.output[0], y_dequant_scale.output[0],
|
73 | 86 | y_inv_zero_point.output[0]],
|
74 | 87 | outputs=[qdq_node.output[0]],
|
75 | 88 | shapes=[qdq_node_output_shape],
|
| 89 | + attr={'axis': axis}, |
76 | 90 | dtypes=[qdq_node_output_dtype],
|
77 | 91 | name=utils.make_name("DequantLinearNode"))
|
78 | 92 | g.set_shape(dequant_node.output[0], qdq_node_output_shape)
|
|
0 commit comments