Skip to content

Commit 0b15fe1

Browse files
authored
Merge pull request #1081 from peri044/qdq_per_channel
Add functionality for QDQ per channel
2 parents 6ec695b + b18e633 commit 0b15fe1

File tree

2 files changed

+48
-17
lines changed

2 files changed

+48
-17
lines changed

tests/test_backend.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2044,6 +2044,19 @@ def func(x):
20442044
return tf.identity(x_, name=_TFOUTPUT)
20452045
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
20462046

2047+
@check_tf_min_version("2.0")
2048+
@check_opset_min_version(13, "quantize_and_dequantize")
2049+
def test_qdq_per_channel_signed_input(self):
2050+
x_shape = [3, 3, 2]
2051+
x_val = np.arange(-np.prod(x_shape)/2, np.prod(x_shape)/2).astype("float32").reshape(x_shape)
2052+
def func(x):
2053+
x_ = quantize_and_dequantize(x, np.array([-1.72, -3.89]).astype(np.float32), \
2054+
np.array([5.12, 2.36]).astype(np.float32), \
2055+
signed_input=True, narrow_range=False, \
2056+
range_given=True, axis=-1)
2057+
return tf.identity(x_, name=_TFOUTPUT)
2058+
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
2059+
20472060
@skip_caffe2_backend()
20482061
@check_opset_min_version(7, "resize_nearest_neighbor")
20492062
def test_resize_nearest_neighbor(self):

tf2onnx/rewriter/quantization_ops_rewriter.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Licensed under the MIT license.
33

44
"""
5-
tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV3 op
5+
tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV2|QuantizeAndDequantizeV3 op
66
"""
77

88
import numpy as np
@@ -32,47 +32,65 @@ def create_qdq_nodes(g, match_results):
3232
if not signed_input:
3333
min_quantized, max_quantized = [0, 255]
3434

35+
# Get axis attribute for per channel implementation.
36+
if 'axis' in qdq_node.attr:
37+
axis = qdq_node.attr['axis'].i
38+
3539
# Get the min and max value of the inputs to QDQ op
3640
min_value = extract_numpy_array(qdq_node.inputs[1])
3741
max_value = extract_numpy_array(qdq_node.inputs[2])
3842

39-
# Calculate scales from the min and max values
40-
scale_from_min_side = min_quantized/min_value if min_quantized*min_value > 0 else max_quantized
41-
scale_from_max_side = max_quantized/max_value if max_quantized*max_value > 0 else max_quantized
42-
43-
if scale_from_min_side < scale_from_max_side:
44-
scale = scale_from_min_side
45-
else:
46-
scale = scale_from_max_side
47-
48-
utils.make_sure(scale > 0, "Quantize/Dequantize scale must be greater than zero")
49-
50-
if signed_input:
51-
zero_point = np.int8(0)
43+
num_channels = min_value.shape[0]
44+
scales = np.zeros(num_channels, dtype=np.float32)
45+
zero_point_dtype = np.int8 if signed_input else np.uint8
46+
zero_point = np.zeros(num_channels, dtype=zero_point_dtype)
47+
48+
for i in range(num_channels):
49+
# Calculate scales from the min and max values
50+
scale_from_min_side = min_quantized/min_value[i] if min_quantized*min_value[i] > 0 else max_quantized
51+
scale_from_max_side = max_quantized/max_value[i] if max_quantized*max_value[i] > 0 else max_quantized
52+
53+
if scale_from_min_side < scale_from_max_side:
54+
scale = scale_from_min_side
55+
else:
56+
scale = scale_from_max_side
57+
58+
utils.make_sure(scale > 0, "Quantize/Dequantize scale must be greater than zero")
59+
scales[i] = np.float32(scale)
60+
61+
# Set scalars for scale and zero point for per layer quantization
62+
if num_channels == 1:
63+
scales = scales[0]
64+
zero_point = zero_point[0]
65+
attrs = {}
5266
else:
53-
zero_point = np.uint8(0)
67+
utils.make_sure(axis, "Axis must be specified for per channel quantization")
68+
attrs = {'axis': axis}
5469

5570
# Split it into QuantizeLinear and DequantizeLinear and remove the QDQ node reference
56-
y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=1/scale)
71+
inverse_scale = (1/scales).astype(np.float32)
72+
y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=inverse_scale)
5773
y_zero_point = g.make_const(name=utils.make_name("y_zero_point"), np_val=zero_point)
5874
quant_node = g.make_node(op_type="QuantizeLinear",
5975
inputs=[qdq_node.input[0], y_quant_scale.output[0],
6076
y_zero_point.output[0]],
6177
shapes=[qdq_node_output_shape],
78+
attr=attrs,
6279
dtypes=[qdq_node_output_dtype],
6380
name=utils.make_name("QuantLinearNode"))
6481

6582
g.set_shape(quant_node.output[0], qdq_node_output_shape)
6683

6784
g.remove_node(qdq_node.name)
6885

69-
y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=1/scale)
86+
y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=inverse_scale)
7087
y_inv_zero_point = g.make_const(name=utils.make_name("y_inv_zero_point"), np_val=zero_point)
7188
dequant_node = g.make_node(op_type="DequantizeLinear",
7289
inputs=[quant_node.output[0], y_dequant_scale.output[0],
7390
y_inv_zero_point.output[0]],
7491
outputs=[qdq_node.output[0]],
7592
shapes=[qdq_node_output_shape],
93+
attr=attrs,
7694
dtypes=[qdq_node_output_dtype],
7795
name=utils.make_name("DequantLinearNode"))
7896
g.set_shape(dequant_node.output[0], qdq_node_output_shape)

0 commit comments

Comments
 (0)