Skip to content

Commit 9f19751

Browse files
committed
Support for QuantizeAndDequantize operation
1 parent 3383ff9 commit 9f19751

File tree

4 files changed

+123
-2
lines changed

4 files changed

+123
-2
lines changed

tests/test_backend.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
random_uniform = tf.compat.v1.random_uniform
7878
fused_batch_norm = tf.compat.v1.nn.fused_batch_norm
7979
dropout = tf.compat.v1.nn.dropout
80+
quantize_and_dequantize = tf.compat.v1.quantization.quantize_and_dequantize
8081
resize_nearest_neighbor = tf.compat.v1.image.resize_nearest_neighbor
8182
resize_bilinear = tf.compat.v1.image.resize_bilinear
8283
is_nan = tf.math.is_nan
@@ -1915,7 +1916,25 @@ def graph_validator(g):
19151916
return True
19161917

19171918
self._run_test_case(func_fusedbn, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05, graph_validator=graph_validator)
1918-
1919+
1920+
@check_opset_min_version(10, "quantize_and_dequantize")
def test_qdq_unsigned_input(self):
    """Convert quantize_and_dequantize with signed_input=False and a given range of [1.0, 6.0]."""
    shape = [3, 3, 2]
    count = np.prod(shape)
    # Input holds the values 1..18, so part of it lies above the given max of 6.0.
    x_val = np.arange(1, 1 + count).astype("float32").reshape(shape)

    def func(x):
        return tf.identity(
            quantize_and_dequantize(x, 1.0, 6.0, signed_input=False, narrow_range=False, range_given=True),
            name=_TFOUTPUT)

    self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
1928+
1929+
@check_opset_min_version(10, "quantize_and_dequantize")
def test_qdq_signed_input(self):
    """Convert quantize_and_dequantize with signed_input=True, narrow_range=True and range [-6.0, 6.0]."""
    shape = [3, 3, 2]
    half = np.prod(shape) / 2
    # Input holds the values -9..8, so both ends lie outside the given range.
    x_val = np.arange(-half, half).astype("float32").reshape(shape)

    def func(x):
        return tf.identity(
            quantize_and_dequantize(x, -6.0, 6.0, signed_input=True, narrow_range=True, range_given=True),
            name=_TFOUTPUT)

    self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
1937+
19191938
@skip_caffe2_backend()
19201939
@check_opset_min_version(7, "resize_nearest_neighbor")
19211940
def test_resize_nearest_neighbor(self):

tf2onnx/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from tf2onnx.rewriter.thresholded_relu_rewriter import rewrite_thresholded_relu
2222
from tf2onnx.rewriter.transpose_rewriter import rewrite_transpose
2323
from tf2onnx.rewriter.conv2d_with_add_rewriter import rewrite_biasadd_with_conv2d
24+
from tf2onnx.rewriter.quantization_ops_rewriter import rewrite_quantize_and_dequantize
2425

2526

2627
__all__ = [
@@ -43,4 +44,5 @@
4344
"rewrite_custom_rnn_cell",
4445
"rewrite_generic_loop",
4546
"rewrite_biasadd_with_conv2d",
47+
"rewrite_quantize_and_dequantize"
4648
]
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3 ops
6+
"""
7+
8+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
9+
10+
import numpy as np
11+
import struct
12+
13+
14+
from tf2onnx.handler import tf_op
15+
from tf2onnx import utils
16+
from tf2onnx import constants
17+
18+
# pylint: disable=missing-docstring
19+
20+
def extract_numpy_array(node):
    """Decode the raw bytes of ``node``'s "value" attribute tensor as a float32 array."""
    tensor = node.attr["value"].t
    raw_bytes = tensor.raw_data
    # The buffer is always interpreted as float32, whatever dtype the tensor declares.
    return np.frombuffer(raw_bytes, dtype="float32")
22+
23+
def create_qdq_nodes(g, match_results):
    """Replace each matched QuantizeAndDequantize node by QuantizeLinear + DequantizeLinear.

    Args:
        g: graph being rewritten; nodes are added to and removed from it in place.
        match_results: matches whose op named 'output' is the QuantizeAndDequantize
            node to replace. Inputs 1 and 2 of that node are read through
            ``extract_numpy_array`` as the min and max of the range.

    Returns:
        The nodes of ``g`` after every match has been rewritten.

    Raises:
        ValueError: if the scale derived from the min/max inputs is not positive.
    """
    # A side of the range that cannot determine the scale must never win the
    # "smaller scale" comparison below, so it falls back to the largest finite
    # float, as the TensorFlow kernel does.
    no_scale_limit = np.finfo(np.float32).max

    for match in match_results:
        qdq_node = match.get_op('output')
        qdq_node_output_dtype = g.get_dtype(qdq_node.output[0])
        qdq_node_output_shape = g.get_shape(qdq_node.output[0])

        # Get the attributes of qdq node
        narrow_range = qdq_node.attr['narrow_range'].i
        signed_input = qdq_node.attr['signed_input'].i

        # 8-bit quantized range: [0, 255] unsigned, [-127, 127] signed narrow,
        # [-128, 127] signed full.
        if not signed_input:
            min_quantized, max_quantized = 0, 255
        elif narrow_range:
            min_quantized, max_quantized = -127, 127
        else:
            min_quantized, max_quantized = -128, 127

        # Get the min and max value of the inputs to QDQ op
        min_value = extract_numpy_array(qdq_node.inputs[1])
        max_value = extract_numpy_array(qdq_node.inputs[2])

        # Calculate scales from the min and max values. A side only constrains
        # the scale when its float bound and quantized bound share a sign.
        if min_quantized * min_value > 0:
            scale_from_min_side = min_quantized / min_value
        else:
            scale_from_min_side = no_scale_limit
        if max_quantized * max_value > 0:
            scale_from_max_side = max_quantized / max_value
        else:
            scale_from_max_side = no_scale_limit

        if scale_from_min_side < scale_from_max_side:
            scale = scale_from_min_side
        else:
            scale = scale_from_max_side

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not scale > 0:
            raise ValueError("QuantizeAndDequantize scale must be positive, got %s for node %s"
                             % (scale, qdq_node.name))

        if signed_input:
            zero_point = np.int8(0)
        else:
            zero_point = np.uint8(0)

        # The ONNX ops take the float step size, i.e. the inverse of the TF scale.
        inverse_scale = np.asarray(1 / scale, dtype=np.float32)

        # Split it into QuantizeLinear and DequantizeLinear and remove the QDQ node reference
        y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=inverse_scale)
        y_zero_point = g.make_const(name=utils.make_name("y_zero_point"), np_val=zero_point)
        # QuantizeLinear yields int8/uint8 (the zero point's type), not the float
        # dtype of the original QDQ output.
        # NOTE(review): relies on make_const recording the constant's dtype - confirm.
        quantized_dtype = g.get_dtype(y_zero_point.output[0])
        quant_node = g.make_node(op_type="QuantizeLinear",
                                 inputs=[qdq_node.input[0], y_quant_scale.output[0], y_zero_point.output[0]],
                                 shapes=[qdq_node_output_shape],
                                 dtypes=[quantized_dtype],
                                 name=utils.make_name("QuantLinearNode"))
        g.set_shape(quant_node.output[0], qdq_node_output_shape)

        # The QDQ node is removed before the DequantizeLinear node is created,
        # because that node takes over the QDQ node's output name.
        g.remove_node(qdq_node.name)

        y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=inverse_scale)
        y_inv_zero_point = g.make_const(name=utils.make_name("y_inv_zero_point"), np_val=zero_point)
        dequant_node = g.make_node(op_type="DequantizeLinear",
                                   inputs=[quant_node.output[0], y_dequant_scale.output[0],
                                           y_inv_zero_point.output[0]],
                                   outputs=[qdq_node.output[0]],
                                   shapes=[qdq_node_output_shape],
                                   dtypes=[qdq_node_output_dtype],
                                   name=utils.make_name("DequantLinearNode"))
        g.set_shape(dequant_node.output[0], qdq_node_output_shape)

    return g.get_nodes()
75+
76+
def rewrite_quantize_and_dequantize(g, ops):
    """Find QuantizeAndDequantizeV2/V3 nodes in ``ops`` and hand them to ``create_qdq_nodes``.

    Args:
        g: graph that is rewritten.
        ops: nodes searched for the QuantizeAndDequantize patterns.

    Returns:
        Whatever ``create_qdq_nodes`` returns for the collected matches.
    """
    # (op type, number of OpTypePattern(None) inputs after the "*" input).
    # V3 is matched before V2, so its matches come first in the result list.
    qdq_variants = (
        ('QuantizeAndDequantizeV3', 3),
        ('QuantizeAndDequantizeV2', 2),
    )

    match_results = []
    for op_type, extra_inputs in qdq_variants:
        operands = [OpTypePattern("*")]
        operands.extend(OpTypePattern(None) for _ in range(extra_inputs))
        qdq_pattern = OpTypePattern(op_type, name='output', inputs=operands)
        match_results.extend(GraphMatcher(qdq_pattern).match_ops(ops))

    return create_qdq_nodes(g, match_results)

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def compat_handler(ctx, node, **kwargs):
459459

460460
# pre-processing graph rewrites
461461
# bi-directional re-writer should be placed after single directional re-writer
462-
rewriters = [rewrite_transpose, rewrite_flatten, rewrite_gemm,
462+
rewriters = [rewrite_quantize_and_dequantize, rewrite_transpose, rewrite_flatten, rewrite_gemm,
463463
rewrite_random_uniform, rewrite_random_uniform_fold_const,
464464
rewrite_random_normal, rewrite_dropout, rewrite_eye,
465465
rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,

0 commit comments

Comments
 (0)