Skip to content

Commit 58305e2

Browse files
committed
Add functionality for QDQ per channel
1 parent 6ec695b commit 58305e2

File tree

2 files changed, with 42 additions and 15 deletions.

tests/test_backend.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2043,6 +2043,19 @@ def func(x):
20432043
x_ = quantize_and_dequantize(x, -6.0, 6.0, signed_input=True, narrow_range=False, range_given=True)
20442044
return tf.identity(x_, name=_TFOUTPUT)
20452045
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
2046+
2047+
@check_tf_min_version("2.0")
2048+
@check_opset_min_version(13, "quantize_and_dequantize")
2049+
def test_qdq_per_channel_signed_input(self):
2050+
x_shape = [3, 3, 2]
2051+
x_val = np.arange(-np.prod(x_shape)/2, np.prod(x_shape)/2).astype("float32").reshape(x_shape)
2052+
def func(x):
2053+
x_ = quantize_and_dequantize(x, np.array([-1.72, -3.89]).astype(np.float32), \
2054+
np.array([5.12, 2.36]).astype(np.float32), \
2055+
signed_input=True, narrow_range=False, \
2056+
range_given=True, axis=-1)
2057+
return tf.identity(x_, name=_TFOUTPUT)
2058+
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
20462059

20472060
@skip_caffe2_backend()
20482061
@check_opset_min_version(7, "resize_nearest_neighbor")

tf2onnx/rewriter/quantization_ops_rewriter.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Licensed under the MIT license.
33

44
"""
5-
tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV3 op
5+
tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV2|QuantizeAndDequantizeV3 op
66
"""
77

88
import numpy as np
@@ -32,47 +32,61 @@ def create_qdq_nodes(g, match_results):
3232
if not signed_input:
3333
min_quantized, max_quantized = [0, 255]
3434

35+
# Get axis attribute for per channel implementation.
36+
axis = qdq_node.attr['axis'].i
37+
3538
# Get the min and max value of the inputs to QDQ op
3639
min_value = extract_numpy_array(qdq_node.inputs[1])
3740
max_value = extract_numpy_array(qdq_node.inputs[2])
3841

39-
# Calculate scales from the min and max values
40-
scale_from_min_side = min_quantized/min_value if min_quantized*min_value > 0 else max_quantized
41-
scale_from_max_side = max_quantized/max_value if max_quantized*max_value > 0 else max_quantized
42+
num_channels = min_value.shape[0]
43+
scales = np.zeros(num_channels, dtype=np.float32)
44+
zero_point_dtype = np.int8 if signed_input else np.uint8
45+
zero_point = np.zeros(num_channels, dtype=zero_point_dtype)
46+
47+
for i in range(num_channels):
48+
# Calculate scales from the min and max values
49+
scale_from_min_side = min_quantized/min_value[i] if min_quantized*min_value[i] > 0 else max_quantized
50+
scale_from_max_side = max_quantized/max_value[i] if max_quantized*max_value[i] > 0 else max_quantized
4251

43-
if scale_from_min_side < scale_from_max_side:
44-
scale = scale_from_min_side
45-
else:
46-
scale = scale_from_max_side
52+
if scale_from_min_side < scale_from_max_side:
53+
scale = scale_from_min_side
54+
else:
55+
scale = scale_from_max_side
4756

48-
utils.make_sure(scale > 0, "Quantize/Dequantize scale must be greater than zero")
57+
utils.make_sure(scale > 0, "Quantize/Dequantize scale must be greater than zero")
58+
scales[i] = np.float32(scale)
4959

50-
if signed_input:
51-
zero_point = np.int8(0)
52-
else:
53-
zero_point = np.uint8(0)
60+
# Set scalars for scale and zero point for per layer quantization
61+
if num_channels == 1:
62+
scales = scales[0]
63+
zero_point = zero_point[0]
64+
axis = np.int64(1) # Default value of axis
5465

5566
# Split it into QuantizeLinear and DequantizeLinear and remove the QDQ node reference
56-
y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=1/scale)
67+
inverse_scale = (1/scales).astype(np.float32)
68+
y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=inverse_scale)
5769
y_zero_point = g.make_const(name=utils.make_name("y_zero_point"), np_val=zero_point)
5870
quant_node = g.make_node(op_type="QuantizeLinear",
5971
inputs=[qdq_node.input[0], y_quant_scale.output[0],
6072
y_zero_point.output[0]],
6173
shapes=[qdq_node_output_shape],
74+
attr={'axis': axis},
6275
dtypes=[qdq_node_output_dtype],
6376
name=utils.make_name("QuantLinearNode"))
6477

6578
g.set_shape(quant_node.output[0], qdq_node_output_shape)
6679

6780
g.remove_node(qdq_node.name)
6881

69-
y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=1/scale)
82+
y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=inverse_scale)
7083
y_inv_zero_point = g.make_const(name=utils.make_name("y_inv_zero_point"), np_val=zero_point)
7184
dequant_node = g.make_node(op_type="DequantizeLinear",
7285
inputs=[quant_node.output[0], y_dequant_scale.output[0],
7386
y_inv_zero_point.output[0]],
7487
outputs=[qdq_node.output[0]],
7588
shapes=[qdq_node_output_shape],
89+
attr={'axis': axis},
7690
dtypes=[qdq_node_output_dtype],
7791
name=utils.make_name("DequantLinearNode"))
7892
g.set_shape(dequant_node.output[0], qdq_node_output_shape)

0 commit comments

Comments
 (0)