
Commit 49a9b2d

Merge remote-tracking branch 'upstream/master' into jignparm/activate_opset12
2 parents a989a0b + f333ce5 commit 49a9b2d

File tree

7 files changed: +149 −17 lines changed

tests/test_backend.py

Lines changed: 22 additions & 0 deletions

@@ -62,6 +62,7 @@
     fused_batch_norm = tf.compat.v1.nn.fused_batch_norm
     dropout = tf.compat.v1.nn.dropout
     resize_nearest_neighbor = tf.compat.v1.image.resize_nearest_neighbor
+    quantize_and_dequantize = tf.quantization.quantize_and_dequantize
     resize_bilinear = tf.compat.v1.image.resize_bilinear
     is_nan = tf.math.is_nan
     is_inf = tf.math.is_inf
@@ -77,6 +78,7 @@
     random_uniform = tf.compat.v1.random_uniform
     fused_batch_norm = tf.compat.v1.nn.fused_batch_norm
     dropout = tf.compat.v1.nn.dropout
+    quantize_and_dequantize = tf.compat.v1.quantization.quantize_and_dequantize
     resize_nearest_neighbor = tf.compat.v1.image.resize_nearest_neighbor
     resize_bilinear = tf.compat.v1.image.resize_bilinear
     is_nan = tf.math.is_nan
@@ -1916,6 +1918,26 @@ def graph_validator(g):

         self._run_test_case(func_fusedbn, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05, graph_validator=graph_validator)

+    @check_tf_min_version("1.15")
+    @check_opset_min_version(10, "quantize_and_dequantize")
+    def test_qdq_unsigned_input(self):
+        x_shape = [3, 3, 2]
+        x_val = np.arange(1, 1+np.prod(x_shape)).astype("float32").reshape(x_shape)
+        def func(x):
+            x_ = quantize_and_dequantize(x, 1.0, 6.0, signed_input=False, range_given=True)
+            return tf.identity(x_, name=_TFOUTPUT)
+        _ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
+    @check_tf_min_version("1.15")
+    @check_opset_min_version(10, "quantize_and_dequantize")
+    def test_qdq_signed_input(self):
+        x_shape = [3, 3, 2]
+        x_val = np.arange(-np.prod(x_shape)/2, np.prod(x_shape)/2).astype("float32").reshape(x_shape)
+        def func(x):
+            x_ = quantize_and_dequantize(x, -6.0, 6.0, signed_input=True, narrow_range=True, range_given=True)
+            return tf.identity(x_, name=_TFOUTPUT)
+        _ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
     @skip_caffe2_backend()
     @check_opset_min_version(7, "resize_nearest_neighbor")
     def test_resize_nearest_neighbor(self):
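A quick numpy sketch (not part of the commit) of the round trip the unsigned test exercises, assuming the zero-point-0, max-side scale rule that the new rewriter (below) derives for the range [1.0, 6.0]:

    import numpy as np

    def fake_qdq_unsigned(x, max_range=6.0):
        scale = 255.0 / max_range                  # 42.5 for this test's range
        q = np.clip(np.round(x * scale), 0, 255)   # QuantizeLinear: uint8, zero_point 0
        return q / scale                           # DequantizeLinear

    x = np.arange(1, 19, dtype=np.float32).reshape(3, 3, 2)   # the test's x_val
    y = fake_qdq_unsigned(x)   # values above 6.0 saturate to 6.0; the rest round-trip closely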

tf2onnx/graph.py

Lines changed: 3 additions & 1 deletion

@@ -1067,7 +1067,9 @@ def make_onnx_graph_io(self, ids):
             shape = self.get_shape(name)

             utils.make_sure(dtype is not None, "missing output dtype for " + name)
-            utils.make_sure(shape is not None, "missing output shape for " + name)
+            # TODO: allow None output shape or not? e.g. shape=(?,)
+            #utils.make_sure(shape is not None, "missing output shape for " + name)
+            if shape is None: logger.warning("missing output shape for %s", name)

             v = utils.make_onnx_inputs_outputs(name, dtype, shape)
             tensor_value_infos.append(v)
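With the assertion relaxed to a warning, an output whose shape is still unknown can pass through; ONNX itself permits a value_info with no shape. A minimal sketch with onnx.helper (the output name is hypothetical):

    from onnx import TensorProto, helper

    # shape=None leaves tensor_type.shape unset, which is legal ONNX
    vi = helper.make_tensor_value_info("output:0", TensorProto.FLOAT, None)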

tf2onnx/onnx_opset/tensor.py

Lines changed: 9 additions & 1 deletion

@@ -196,6 +196,7 @@ def version_1(cls, ctx, node, **kwargs):
         shape = ctx.get_shape(node.input[0])
         utils.make_sure(shape is not None, "squeeze input shape cannot be None")
         axis = [i for i, j in enumerate(shape) if j == 1]
+        if not axis: axis = [0]
         node.set_attr("axes", axis)

     @classmethod
@@ -1772,9 +1773,12 @@ class MatrixDiagPart:
     def version_11(cls, ctx, node, **kwargs):
         # MatrixDiagPart by slice and gather
         const_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero', np.array([0]).astype(np.int64))
+        const_zero_ = ctx.make_const(utils.make_name(node.name) + 'const_zero_', np.array(0).astype(np.int64))
+
         const_zero_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero_zero',
                                          np.array([0, 0]).astype(np.int64))
         const_one = ctx.make_const(utils.make_name(node.name) + 'const_one', np.array([1]).astype(np.int64))
+        const_one_ = ctx.make_const(utils.make_name(node.name) + 'const_one_', np.array(1).astype(np.int64))
         const_two = ctx.make_const(utils.make_name(node.name) + 'const_two', np.array([2]).astype(np.int64))
         const_negative_one = ctx.make_const(utils.make_name(node.name) + 'const_negative_one',
                                             np.array([-1]).astype(np.int64))
@@ -1802,7 +1806,9 @@ def version_11(cls, ctx, node, **kwargs):
                                                            const_negative_one.output[0]])
         sliced_input_shape_new = ctx.make_node('Concat', [sliced_input_shape_half.output[0], const_one.output[0]],
                                                attr={'axis': -1})
-        matrice_range = ctx.make_node('Range', [const_zero.output[0], min_matrice_dim.output[0], const_one.output[0]])
+        min_matrice_dim_ = ctx.make_node('Squeeze', [min_matrice_dim.output[0]], {'axes': [0]})
+        matrice_range = ctx.make_node('Range', [const_zero_.output[0], min_matrice_dim_.output[0],
+                                                const_one_.output[0]])
         unsqueezed_matrice_range = ctx.make_node('Unsqueeze', [matrice_range.output[0]], attr={"axes": [-1]})
         expanded_range = ctx.make_node('Expand', [unsqueezed_matrice_range.output[0], sliced_input_shape_new.output[0]])
         gathered_result = ctx.make_node('GatherElements', [sliced_input.output[0], expanded_range.output[0]],
@@ -1893,6 +1899,8 @@ def version_11(cls, ctx, node, **kwargs):
         new_width = body_graph.make_node('Slice', [processed_shape.output[0], const_neg_one.output[0],
                                                    shape_processed_shape.output[0]])
         abs_k = body_graph.make_node('Abs', [current_k.output[0]])
+
+
         range_k = body_graph.make_node('Range', [abs_k.output[0], new_width.output[0], const_one.output[0]],
                                        domain="com.microsoft")
         sliced_range = body_graph.make_node('Slice', [range_k.output[0], const_zero.output[0], new_depth.output[0]])
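The new scalar constants and the Squeeze exist because ONNX Range takes 0-d (scalar) start/limit/delta inputs, while the surrounding code builds 1-element 1-D tensors. A numpy analogue of the fix (the value 5 is made up):

    import numpy as np

    min_matrice_dim = np.array([5], dtype=np.int64)   # 1-D, shape (1,), as computed upstream
    start, delta = np.int64(0), np.int64(1)           # the scalar const_zero_ / const_one_
    limit = np.squeeze(min_matrice_dim, axis=0)       # the inserted Squeeze: now 0-d
    matrice_range = np.arange(start, limit, delta)    # Range -> [0 1 2 3 4]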

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 2 additions & 0 deletions

@@ -175,8 +175,10 @@ def _initialize_handlers(self):
             "Clip": self._simple_through_handler,
             "Concat": self._concat_handler,
             "Elu": self._simple_through_handler,
+            "Exp": self._simple_through_handler,
             "Identity": self._identity_handler,
             "LeakyRelu": self._simple_through_handler,
+            "Log": self._simple_through_handler,
             "Max": self._maxmin_handler,
             "Min": self._maxmin_handler,
             "Mul": self._mul_handler,

tf2onnx/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,7 @@
 from tf2onnx.rewriter.thresholded_relu_rewriter import rewrite_thresholded_relu
 from tf2onnx.rewriter.transpose_rewriter import rewrite_transpose
 from tf2onnx.rewriter.conv2d_with_add_rewriter import rewrite_biasadd_with_conv2d
+from tf2onnx.rewriter.quantization_ops_rewriter import rewrite_quantize_and_dequantize


 __all__ = [
@@ -43,4 +44,5 @@
     "rewrite_custom_rnn_cell",
     "rewrite_generic_loop",
     "rewrite_biasadd_with_conv2d",
+    "rewrite_quantize_and_dequantize"
 ]
tf2onnx/rewriter/quantization_ops_rewriter.py

Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+"""
+tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV3 op
+"""
+
+import numpy as np
+from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
+from tf2onnx import utils
+
+# pylint: disable=missing-docstring
+
+def extract_numpy_array(node):
+    return np.frombuffer(node.attr["value"].t.raw_data, dtype="float32")
+
+def create_qdq_nodes(g, match_results):
+
+    for match in match_results:
+        qdq_node = match.get_op('output')
+        qdq_node_output_dtype = g.get_dtype(qdq_node.output[0])
+        qdq_node_output_shape = g.get_shape(qdq_node.output[0])
+
+        # Get the attributes of qdq node
+        narrow_range = qdq_node.attr['narrow_range'].i
+        signed_input = qdq_node.attr['signed_input'].i
+
+        min_quantized, max_quantized = [-127, 127]
+        if not narrow_range and signed_input:
+            min_quantized = -128
+
+        if not signed_input:
+            min_quantized, max_quantized = [0, 255]
+
+        # Get the min and max value of the inputs to QDQ op
+        min_value = extract_numpy_array(qdq_node.inputs[1])
+        max_value = extract_numpy_array(qdq_node.inputs[2])
+
+        # Calculate scales from the min and max values
+        scale_from_min_side = min_quantized/min_value if min_quantized*min_value > 0 else max_quantized
+        scale_from_max_side = max_quantized/max_value if max_quantized*max_value > 0 else max_quantized
+
+        if scale_from_min_side < scale_from_max_side:
+            scale = scale_from_min_side
+        else:
+            scale = scale_from_max_side
+
+        utils.make_sure(scale > 0, "Quantize/Dequantize scale must be greater than zero")
+
+        if signed_input:
+            zero_point = np.int8(0)
+        else:
+            zero_point = np.uint8(0)
+
+        # Split it into QuantizeLinear and DequantizeLinear and remove the QDQ node reference
+        y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=1/scale)
+        y_zero_point = g.make_const(name=utils.make_name("y_zero_point"), np_val=zero_point)
+        quant_node = g.make_node(op_type="QuantizeLinear",
+                                 inputs=[qdq_node.input[0], y_quant_scale.output[0],
+                                         y_zero_point.output[0]],
+                                 shapes=[qdq_node_output_shape],
+                                 dtypes=[qdq_node_output_dtype],
+                                 name=utils.make_name("QuantLinearNode"))
+
+        g.set_shape(quant_node.output[0], qdq_node_output_shape)
+
+        g.remove_node(qdq_node.name)
+
+        y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=1/scale)
+        y_inv_zero_point = g.make_const(name=utils.make_name("y_inv_zero_point"), np_val=zero_point)
+        dequant_node = g.make_node(op_type="DequantizeLinear",
+                                   inputs=[quant_node.output[0], y_dequant_scale.output[0],
+                                           y_inv_zero_point.output[0]],
+                                   outputs=[qdq_node.output[0]],
+                                   shapes=[qdq_node_output_shape],
+                                   dtypes=[qdq_node_output_dtype],
+                                   name=utils.make_name("DequantLinearNode"))
+        g.set_shape(dequant_node.output[0], qdq_node_output_shape)
+
+    return g.get_nodes()
+
+def rewrite_quantize_and_dequantize(g, ops):
+
+    pattern_for_qdq_v2 = \
+        OpTypePattern('QuantizeAndDequantizeV2', name='output', inputs=[
+            OpTypePattern("*"),
+            OpTypePattern(None),
+            OpTypePattern(None),
+        ])
+    pattern_for_qdq_v3 = \
+        OpTypePattern('QuantizeAndDequantizeV3', name='output', inputs=[
+            OpTypePattern("*"),
+            OpTypePattern(None),
+            OpTypePattern(None),
+            OpTypePattern(None),
+        ])
+
+    # Match all the patterns for QDQ ops
+    patterns = [pattern_for_qdq_v3, pattern_for_qdq_v2]
+    match_results = []
+    for pattern in patterns:
+        matcher = GraphMatcher(pattern)
+        results = list(matcher.match_ops(ops))
+        match_results.extend(results)
+
+    return create_qdq_nodes(g, match_results)
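A hand-check (not part of the commit) of the scale this rewriter derives for the signed, narrow_range test case above, where (min_value, max_value) = (-6.0, 6.0):

    min_quantized, max_quantized = -127, 127        # narrow_range keeps -127
    scale_from_min_side = min_quantized / -6.0      # 21.1666..., since (-127) * (-6.0) > 0
    scale_from_max_side = max_quantized / 6.0       # 21.1666...
    scale = min(scale_from_min_side, scale_from_max_side)
    # QuantizeLinear gets y_scale = 1/scale ~= 0.04724 and zero_point = int8(0), so
    # x = 6.0 quantizes to round(6.0 / 0.04724) = 127 and dequantizes back to 6.0.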

tf2onnx/tfonnx.py

Lines changed: 5 additions & 15 deletions

@@ -54,21 +54,15 @@ def rewrite_constant_fold(g, ops):
         "Sqrt": np.sqrt,
         "Sub": np.subtract,
     }
-    ref_cnt_per_node = {}
-    for idx, op in enumerate(ops):
-        for op_input in op.inputs:
-            if op_input.name not in ref_cnt_per_node:
-                ref_cnt_per_node[op_input.name] = 0
-            ref_cnt_per_node[op_input.name] += 1

     # pylint: disable=too-many-nested-blocks
     keep_looking = True
     while keep_looking:
         keep_looking = False
         for idx, op in enumerate(ops):
             func = func_map.get(op.type)
-            if func is None:
-                continue
+            if func is None: continue
+            if set(op.output) & set(g.outputs): continue
             try:
                 inputs = []
                 for node in op.inputs:
@@ -109,18 +103,14 @@ def rewrite_constant_fold(g, ops):
                 old_node_name = op.name
                 logger.debug("create const node [%s] replacing [%s]", new_node_name, old_node_name)
                 ops[idx] = g.make_const(new_node_name, val)
-                ref_cnt_per_node[new_node_name] = ref_cnt_per_node[old_node_name]

                 logger.debug("replace old output [%s] with new output [%s]", old_output_name, new_output_name)
                 # need to re-write the consumers input name to use the const name
                 consumers = g.find_output_consumers(old_output_name)
                 if consumers:
                     for consumer in consumers:
                         g.replace_input(consumer, old_output_name, new_output_name)
-                for node in op.inputs:
-                    ref_cnt_per_node[node.name] -= 1
-                    if ref_cnt_per_node[node.name] == 0:
-                        g.remove_node(node.name)
+
                 # keep looking until there is nothing we can fold.
                 # We keep the graph in topological order so if we folded,
                 # the result might help a following op.
@@ -459,8 +449,8 @@ def compat_handler(ctx, node, **kwargs):

     # pre-processing graph rewrites
     # bi-directional re-writer should be placed after single directional re-writer
-    rewriters = [rewrite_transpose, rewrite_flatten, rewrite_gemm,
-                 rewrite_random_uniform, rewrite_random_uniform_fold_const,
+    rewriters = [rewrite_constant_fold, rewrite_quantize_and_dequantize, rewrite_transpose, rewrite_flatten,
+                 rewrite_gemm, rewrite_random_uniform, rewrite_random_uniform_fold_const,
                  rewrite_random_normal, rewrite_dropout, rewrite_eye,
                  rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,
                  rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
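The added `if set(op.output) & set(g.outputs): continue` guard skips folding any node that produces a graph output, presumably so the name recorded in g.outputs is not orphaned when the node is replaced by a renamed const. A trivial sketch of the set test (names hypothetical):

    op_outputs = ["Add__12:0"]
    graph_outputs = ["Add__12:0", "Mul__7:0"]
    skip_folding = bool(set(op_outputs) & set(graph_outputs))   # True -> leave the node alone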
