Skip to content

Commit 68d7b88

Browse files
authored
Merge pull request #919 from peri044/qdq_rewriter
Support for QuantizeAndDequantize operation
2 parents dd85f83 + 1996a56 commit 68d7b88

File tree

4 files changed

+131
-1
lines changed

4 files changed

+131
-1
lines changed

tests/test_backend.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
fused_batch_norm = tf.compat.v1.nn.fused_batch_norm
6363
dropout = tf.compat.v1.nn.dropout
6464
resize_nearest_neighbor = tf.compat.v1.image.resize_nearest_neighbor
65+
quantize_and_dequantize = tf.quantization.quantize_and_dequantize
6566
resize_bilinear = tf.compat.v1.image.resize_bilinear
6667
is_nan = tf.math.is_nan
6768
is_inf = tf.math.is_inf
@@ -77,6 +78,7 @@
7778
random_uniform = tf.compat.v1.random_uniform
7879
fused_batch_norm = tf.compat.v1.nn.fused_batch_norm
7980
dropout = tf.compat.v1.nn.dropout
81+
quantize_and_dequantize = tf.compat.v1.quantization.quantize_and_dequantize
8082
resize_nearest_neighbor = tf.compat.v1.image.resize_nearest_neighbor
8183
resize_bilinear = tf.compat.v1.image.resize_bilinear
8284
is_nan = tf.math.is_nan
@@ -1916,6 +1918,26 @@ def graph_validator(g):
19161918

19171919
self._run_test_case(func_fusedbn, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05, graph_validator=graph_validator)
19181920

1921+
@check_tf_min_version("1.15")
1922+
@check_opset_min_version(10, "quantize_and_dequantize")
1923+
def test_qdq_unsigned_input(self):
1924+
x_shape = [3, 3, 2]
1925+
x_val = np.arange(1, 1+np.prod(x_shape)).astype("float32").reshape(x_shape)
1926+
def func(x):
1927+
x_ = quantize_and_dequantize(x, 1.0, 6.0, signed_input=False, range_given=True)
1928+
return tf.identity(x_, name=_TFOUTPUT)
1929+
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
1930+
1931+
@check_tf_min_version("1.15")
1932+
@check_opset_min_version(10, "quantize_and_dequantize")
1933+
def test_qdq_signed_input(self):
1934+
x_shape = [3, 3, 2]
1935+
x_val = np.arange(-np.prod(x_shape)/2, np.prod(x_shape)/2).astype("float32").reshape(x_shape)
1936+
def func(x):
1937+
x_ = quantize_and_dequantize(x, -6.0, 6.0, signed_input=True, narrow_range=True, range_given=True)
1938+
return tf.identity(x_, name=_TFOUTPUT)
1939+
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
1940+
19191941
@skip_caffe2_backend()
19201942
@check_opset_min_version(7, "resize_nearest_neighbor")
19211943
def test_resize_nearest_neighbor(self):

tf2onnx/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from tf2onnx.rewriter.thresholded_relu_rewriter import rewrite_thresholded_relu
2222
from tf2onnx.rewriter.transpose_rewriter import rewrite_transpose
2323
from tf2onnx.rewriter.conv2d_with_add_rewriter import rewrite_biasadd_with_conv2d
24+
from tf2onnx.rewriter.quantization_ops_rewriter import rewrite_quantize_and_dequantize
2425

2526

2627
__all__ = [
@@ -43,4 +44,5 @@
4344
"rewrite_custom_rnn_cell",
4445
"rewrite_generic_loop",
4546
"rewrite_biasadd_with_conv2d",
47+
"rewrite_quantize_and_dequantize"
4648
]
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV3 op
6+
"""
7+
8+
import numpy as np
9+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
10+
from tf2onnx import utils
11+
12+
# pylint: disable=missing-docstring
13+
14+
def extract_numpy_array(node):
15+
return np.frombuffer(node.attr["value"].t.raw_data, dtype="float32")
16+
17+
def create_qdq_nodes(g, match_results):
18+
19+
for match in match_results:
20+
qdq_node = match.get_op('output')
21+
qdq_node_output_dtype = g.get_dtype(qdq_node.output[0])
22+
qdq_node_output_shape = g.get_shape(qdq_node.output[0])
23+
24+
# Get the attributes of qdq node
25+
narrow_range = qdq_node.attr['narrow_range'].i
26+
signed_input = qdq_node.attr['signed_input'].i
27+
28+
min_quantized, max_quantized = [-127, 127]
29+
if not narrow_range and signed_input:
30+
min_quantized = -128
31+
32+
if not signed_input:
33+
min_quantized, max_quantized = [0, 255]
34+
35+
# Get the min and max value of the inputs to QDQ op
36+
min_value = extract_numpy_array(qdq_node.inputs[1])
37+
max_value = extract_numpy_array(qdq_node.inputs[2])
38+
39+
# Calculate scales from the min and max values
40+
scale_from_min_side = min_quantized/min_value if min_quantized*min_value > 0 else max_quantized
41+
scale_from_max_side = max_quantized/max_value if max_quantized*max_value > 0 else max_quantized
42+
43+
if scale_from_min_side < scale_from_max_side:
44+
scale = scale_from_min_side
45+
else:
46+
scale = scale_from_max_side
47+
48+
utils.make_sure(scale > 0, "Quantize/Dequantize scale must be greater than zero")
49+
50+
if signed_input:
51+
zero_point = np.int8(0)
52+
else:
53+
zero_point = np.uint8(0)
54+
55+
# Split it into QuantizeLinear and DequantizeLinear and remove the QDQ node reference
56+
y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=1/scale)
57+
y_zero_point = g.make_const(name=utils.make_name("y_zero_point"), np_val=zero_point)
58+
quant_node = g.make_node(op_type="QuantizeLinear",
59+
inputs=[qdq_node.input[0], y_quant_scale.output[0],
60+
y_zero_point.output[0]],
61+
shapes=[qdq_node_output_shape],
62+
dtypes=[qdq_node_output_dtype],
63+
name=utils.make_name("QuantLinearNode"))
64+
65+
g.set_shape(quant_node.output[0], qdq_node_output_shape)
66+
67+
g.remove_node(qdq_node.name)
68+
69+
y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=1/scale)
70+
y_inv_zero_point = g.make_const(name=utils.make_name("y_inv_zero_point"), np_val=zero_point)
71+
dequant_node = g.make_node(op_type="DequantizeLinear",
72+
inputs=[quant_node.output[0], y_dequant_scale.output[0],
73+
y_inv_zero_point.output[0]],
74+
outputs=[qdq_node.output[0]],
75+
shapes=[qdq_node_output_shape],
76+
dtypes=[qdq_node_output_dtype],
77+
name=utils.make_name("DequantLinearNode"))
78+
g.set_shape(dequant_node.output[0], qdq_node_output_shape)
79+
80+
return g.get_nodes()
81+
82+
def rewrite_quantize_and_dequantize(g, ops):
83+
84+
pattern_for_qdq_v2 = \
85+
OpTypePattern('QuantizeAndDequantizeV2', name='output', inputs=[
86+
OpTypePattern("*"),
87+
OpTypePattern(None),
88+
OpTypePattern(None),
89+
])
90+
pattern_for_qdq_v3 = \
91+
OpTypePattern('QuantizeAndDequantizeV3', name='output', inputs=[
92+
OpTypePattern("*"),
93+
OpTypePattern(None),
94+
OpTypePattern(None),
95+
OpTypePattern(None),
96+
])
97+
98+
# Match all the patterns for QDQ ops
99+
patterns = [pattern_for_qdq_v3, pattern_for_qdq_v2]
100+
match_results = []
101+
for pattern in patterns:
102+
matcher = GraphMatcher(pattern)
103+
results = list(matcher.match_ops(ops))
104+
match_results.extend(results)
105+
106+
return create_qdq_nodes(g, match_results)

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def compat_handler(ctx, node, **kwargs):
459459

460460
# pre-processing graph rewrites
461461
# bi-directional re-writer should be placed after single directional re-writer
462-
rewriters = [rewrite_transpose, rewrite_flatten, rewrite_gemm,
462+
rewriters = [rewrite_quantize_and_dequantize, rewrite_transpose, rewrite_flatten, rewrite_gemm,
463463
rewrite_random_uniform, rewrite_random_uniform_fold_const,
464464
rewrite_random_normal, rewrite_dropout, rewrite_eye,
465465
rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,

0 commit comments

Comments
 (0)