
Commit a4607b3

Merge pull request #487 from zhijxu-MS/push_branch
enhance shape inference, fix bugs
2 parents: 9aad343 + 2263e91

11 files changed: +115 additions, -79 deletions

tests/common.py

Lines changed: 3 additions & 1 deletion
@@ -11,6 +11,7 @@
 
 from distutils.version import LooseVersion
 from parameterized import parameterized
+import numpy as np
 from tf2onnx import constants, logging, utils
 
 __all__ = [
@@ -280,7 +281,8 @@ def check_onnxruntime_incompatibility(op):
 def validate_const_node(node, expected_val):
     if node.is_const():
         node_val = node.get_tensor_value()
-        return node_val == expected_val
+        np.testing.assert_allclose(expected_val, node_val)
+        return True
     return False
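The switch from `==` to np.testing.assert_allclose matters once expected_val can be array-valued: with numpy arrays, `==` compares element-wise and returns an array rather than a single bool, and exact float comparison is brittle in any case. A minimal standalone illustration (values invented for the example):

import numpy as np

expected = np.array([1.0, 2.0, 3.0])
actual = np.array([1.0, 2.0, 3.0 + 1e-9])

print(expected == actual)                     # [ True  True False], element-wise rather than a single bool
np.testing.assert_allclose(expected, actual)  # passes: the difference is within the default rtol of 1e-7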

tf2onnx/graph_builder.py

Lines changed: 5 additions & 0 deletions
@@ -70,6 +70,11 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
         while inputs[-1] == "":
             inputs = inputs[:-1]
 
+        if self.graph.opset >= 10:
+            dtype = self.graph.get_dtype(inputs[1])
+            for input_data in inputs[1:]:
+                utils.make_sure(dtype == self.graph.get_dtype(input_data), "dtype should be same")
+
         return self.graph.make_node(op_type="Slice", inputs=inputs, attr=attr, name=name,
                                     outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]
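From opset 10 onward, ONNX Slice takes starts/ends/axes/steps as tensor inputs rather than attributes, and the Slice-10 spec constrains them to one index type (Tind: int32 or int64), so mixing the two is rejected here at build time instead of surfacing later as a runtime type error. A rough standalone sketch of the invariant being enforced (check_slice_index_dtypes is a made-up name, not tf2onnx API):

import numpy as np

def check_slice_index_dtypes(*index_tensors):
    # every provided starts/ends/axes/steps tensor must share one dtype
    dtypes = {t.dtype for t in index_tensors if t is not None}
    if len(dtypes) > 1:
        raise ValueError("dtype should be same, got %s" % sorted(map(str, dtypes)))

check_slice_index_dtypes(np.array([0], np.int64), np.array([3], np.int64))    # ok
try:
    check_slice_index_dtypes(np.array([0], np.int32), np.array([3], np.int64))
except ValueError as err:
    print(err)                                # dtype should be same, got ['int32', 'int64']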

tf2onnx/onnx_opset/tensor.py

Lines changed: 6 additions & 5 deletions
@@ -534,11 +534,12 @@ def version_4(cls, ctx, node, **kwargs):
             attr = node.get_attr(attr_name)
             if attr is not None and attr.i != 0:
                 raise ValueError("StridedSlice: attribute " + attr_name + " not supported")
-        input_shape = ctx.get_shape(node.input[0])
-        begin = node.inputs[1].get_tensor_value(as_list=False)
-        end = node.inputs[2].get_tensor_value(as_list=False)
-        strides = node.inputs[3].get_tensor_value(as_list=False)
-        max_size = np.iinfo(begin.dtype).max
+        onnx_dtype = ctx.get_dtype(node.input[1])
+        np_dtype = utils.ONNX_TO_NUMPY_DTYPE[onnx_dtype]
+        max_size = np.iinfo(np_dtype).max
+        begin = node.inputs[1].get_tensor_value()
+        end = node.inputs[2].get_tensor_value()
+        strides = node.inputs[3].get_tensor_value()
         end_mask = node.get_attr("end_mask")
         end_mask = end_mask.i if end_mask is not None else 0
         begin_mask = node.get_attr("begin_mask")
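max_size serves as the "slice to the end" sentinel for dims selected by end_mask, so it has to saturate at the limit of whatever integer type the begin/end inputs use. Deriving it from the graph-recorded dtype of the begin input (mapped through utils.ONNX_TO_NUMPY_DTYPE) keeps that working now that get_tensor_value() is called without as_list=False and hands back plain Python lists, which carry no .dtype. A small standalone illustration of the iinfo lookup (the mapping dict below is a stand-in, not the real utils.ONNX_TO_NUMPY_DTYPE):

import numpy as np
from onnx import TensorProto

# stand-in for the ONNX-dtype -> numpy-dtype mapping used by the patch
onnx_to_numpy = {TensorProto.INT32: np.int32, TensorProto.INT64: np.int64}

onnx_dtype = TensorProto.INT32                        # dtype recorded for the "begin" input
max_size = np.iinfo(onnx_to_numpy[onnx_dtype]).max
print(max_size)                                       # 2147483647 (INT64 would give 9223372036854775807)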

tf2onnx/rewriter/bigru_rewriter.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 import logging
 import numpy as np
 from tf2onnx import utils
-from tf2onnx.rewriter.rnn_utils import is_reverse_op
+from tf2onnx.utils import is_reverse_op
 from tf2onnx.rewriter.bilstm_rewriter import slice_bilstm_for_original_lstm_consumers,\
     get_reverse_nodes_after_y_output, get_np_val_for_const, _process_single_init_node

tf2onnx/rewriter/bilstm_rewriter.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 import logging
 import numpy as np
 from tf2onnx import utils
-from tf2onnx.rewriter.rnn_utils import is_reverse_op
+from tf2onnx.utils import is_reverse_op
 from tf2onnx.graph_builder import GraphBuilder
 
 logger = logging.getLogger(__name__)

tf2onnx/rewriter/loop_rewriter_base.py

Lines changed: 2 additions & 2 deletions
@@ -12,8 +12,8 @@
 from collections import OrderedDict
 from tf2onnx import utils
 from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
-from tf2onnx.rewriter.rnn_utils import is_loopcond_op, is_tensor_array_op
-from tf2onnx.rewriter.rnn_utils import is_tensor_array_gather_op, is_tensor_array_write_op
+from tf2onnx.utils import is_loopcond_op, is_tensor_array_op
+from tf2onnx.utils import is_tensor_array_gather_op, is_tensor_array_write_op
 from tf2onnx.rewriter.rnn_utils import REWRITER_RESULT
 from tf2onnx.utils import TensorValueInfo

tf2onnx/rewriter/lstm_rewriter.py

Lines changed: 2 additions & 2 deletions
@@ -13,8 +13,8 @@
 import numpy as np
 from tf2onnx import utils
 from tf2onnx.graph_builder import GraphBuilder
-from tf2onnx.rewriter.rnn_utils import RNNUnitType, RnnWeight, \
-    is_concat_op, is_slice_op, get_weights_from_const_node
+from tf2onnx.rewriter.rnn_utils import RNNUnitType, RnnWeight, get_weights_from_const_node
+from tf2onnx.utils import is_concat_op, is_slice_op
 
 from tf2onnx.rewriter.unit_rnn_rewriter_base import UnitRnnRewriterBase

tf2onnx/rewriter/rnn_utils.py

Lines changed: 0 additions & 32 deletions
@@ -263,35 +263,3 @@ def get_weights_from_const_node(g, node):
         return None
 
     return RnnWeight(node, val, dtype)
-
-
-def is_reverse_op(op):
-    return op.type in ("ReverseV2", "ReverseSequence")
-
-
-def is_concat_op(op):
-    return op.type in ("Concat", "ConcatV2", "ConcatV3")
-
-
-def is_tensor_array_gather_op(op):
-    return op.type in ("TensorArrayGatherV2", "TensorArrayGatherV3")
-
-
-def is_tensor_array_write_op(op):
-    return op.type in ("TensorArrayWriteV2", "TensorArrayWriteV3")
-
-
-def is_tensor_array_op(op):
-    return op.type in ("TensorArrayV2", "TensorArrayV3")
-
-
-def is_loopcond_op(op):
-    return op.type == "LoopCond"
-
-
-def is_select_op(op):
-    return op.type == "Select"
-
-
-def is_slice_op(op):
-    return op.type == "Slice"

tf2onnx/rewriter/unit_rnn_rewriter_base.py

Lines changed: 2 additions & 3 deletions
@@ -13,12 +13,11 @@
 from tf2onnx.graph_builder import GraphBuilder
 from tf2onnx.rewriter.loop_rewriter_base import LoopRewriterBase, Context
 from tf2onnx.rewriter.rnn_utils import REWRITER_RESULT, get_pattern, \
-    get_rnn_scope_name, parse_rnn_loop, is_select_op, is_tensor_array_write_op, \
-    seq_len_pattern
+    get_rnn_scope_name, parse_rnn_loop, seq_len_pattern
+from tf2onnx.utils import is_select_op, is_tensor_array_write_op
 from tf2onnx.graph_matcher import GraphMatcher
 
 
-
 logger = logging.getLogger(__name__)

tf2onnx/shape_inference.py

Lines changed: 61 additions & 32 deletions
@@ -9,13 +9,13 @@
 from __future__ import print_function
 from __future__ import unicode_literals
 import logging
+import numpy as np
 from onnx import onnx_pb
 from tf2onnx import utils
 
 # pylint: disable=logging-not-lazy,missing-docstring,consider-swap-variables
 
 
-
 logger = logging.getLogger(__name__)
 
 direct_ops = [
@@ -115,43 +115,19 @@ def infer_shape_for_node(g, node):
             return False
         return set_shape_from_input(g, shape_node.input[0], node.output[0])
 
-    if node.type == "ConcatV2":
-        axis_node = node.inputs[-1]
-        if not axis_node.is_const():
-            return False
-
-        axis = axis_node.get_tensor_value()
-        val = 0
-        data_inputs = node.input[:-1]
-        for i in data_inputs:
-            s = g.get_shape(i)
-            if s is None:
-                return False
-
-            if s[axis] == -1:
-                val = -1
-                break
-            val += s[axis]
-
-        s1 = g.get_shape(node.input[0])
-        if axis < 0:
-            axis += len(s1)
-        new_shape = s1[:axis] + [val]
-        if axis < len(s1) - 1:
-            new_shape += s1[axis + 1:]
-
-        g.set_shape(node.output[0], new_shape)
-        logger.debug("set ConcatV2 node [%s] with new shape %s", node.output[0], new_shape)
-        return True
-
     if node.type == "Gather":
         # uses the follwing link to know how to infer shape of output
         # https://www.tensorflow.org/api_docs/python/tf/gather
         shape_params = g.get_shape(node.input[0])
         shape_indices = g.get_shape(node.input[1])
-        axis = node.input[2].get_tensor_value()
+        # gather can only have 2 inputs
+        # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/gather.html
+        if len(node.input) == 3:
+            axis = node.input[2].get_tensor_value()
+        else:
+            axis = 0
 
-        shape = shape_params[:axis] + shape_indices + shape_indices[axis + 1:]
+        shape = shape_params[:axis] + shape_indices + shape_params[axis + 1:]
         g.set_shape(node.output[0], shape)
         return True
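Two fixes to the Gather rule: axis now defaults to 0 when there is no third input (plain tf.gather passes only params and indices), and the tail of the output shape is taken from the params shape instead of the indices shape. The rule itself is simple list arithmetic; a standalone sketch with made-up shapes:

def gather_output_shape(shape_params, shape_indices, axis=0):
    # tf.gather: output rank = rank(params) + rank(indices) - 1
    return shape_params[:axis] + shape_indices + shape_params[axis + 1:]

print(gather_output_shape([5, 7, 9], [2, 3], axis=0))   # [2, 3, 7, 9]
print(gather_output_shape([5, 7, 9], [2, 3], axis=1))   # [5, 2, 3, 9]

With the old shape_indices[axis + 1:] tail, the axis=0 case above would have come out as [2, 3, 3].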

@@ -194,6 +170,29 @@ def infer_shape_for_node(g, node):
         logger.debug("set [%s] with new shape %s", node.output[0], new_shape)
         return True
 
+    if node.type == "Unpack":
+        input_shape = g.get_shape(node.input[0])
+        if input_shape is None:
+            return False
+
+        axis = node.get_attr("axis").i
+        axis = axis if axis >= 0 else axis + len(input_shape)
+        # the link below says that the rank of output is "rank(input) -1",
+        # from this statement "num" must equal to input_shape[axis], and if not tf will throw a runtime error
+        # https://www.tensorflow.org/api_docs/python/tf/unstack
+        new_shape = input_shape[:axis] + input_shape[axis + 1:]
+        for output in node.output:
+            g.set_shape(output, new_shape)
+            logger.debug("set %s node [%s] with new shape %s", node.type, output, new_shape)
+        return True
+
+    if node.type in ["Minimum", "Maximum"]:
+        # ops that are elementwise and support broadcasting
+        input_shapes = [g.get_shape(node) for node in node.input]
+        new_shape = broadcast_shape_inference(*input_shapes)
+        g.set_shape(node.output[0], new_shape)
+        return True
+
     return False
 
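Two new rules land in this hunk. Unpack (tf.unstack) drops the unpacked axis, so every output of a [3, 4, 5] input unstacked on axis=1 has shape [3, 5]. Minimum and Maximum are elementwise with broadcasting, so their output shape follows numpy-style broadcasting of the input shapes. A standalone sketch of both, using -1 for unknown dims as tf2onnx does (broadcast_shape below is a simplified stand-in, not the broadcast_shape_inference implementation referenced above):

def unpack_output_shape(input_shape, axis):
    # each Unpack/unstack output keeps the input shape minus the unpacked axis
    axis = axis if axis >= 0 else axis + len(input_shape)
    return input_shape[:axis] + input_shape[axis + 1:]

def broadcast_shape(shape_a, shape_b, unknown=-1):
    # numpy-style broadcasting, aligned from the trailing dims; unknown dims stay unknown
    merged = []
    for a, b in zip(reversed(shape_a), reversed(shape_b)):
        merged.append(unknown if unknown in (a, b) else max(a, b))  # valid inputs: equal, or one side is 1
    longer = shape_a if len(shape_a) >= len(shape_b) else shape_b
    return list(longer[:len(longer) - len(merged)]) + list(reversed(merged))

print(unpack_output_shape([3, 4, 5], axis=1))    # [3, 5]
print(broadcast_shape([2, 1, 5], [3, 5]))        # [2, 3, 5]
print(broadcast_shape([2, -1, 5], [3, 5]))       # [2, -1, 5]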

@@ -213,6 +212,36 @@
 
 
 def infer_output_shapes_with_partial_inputs(g, node):
+    # output shape of concat op: only the dim val of concatenated dim will be changed
+    # so only partial(at least one) input shapes need to be known to infer output shape of concat node
+    if utils.is_concat_op(node):
+        data_inputs = node.input[:-1]
+        input_shapes = [g.get_shape(node) for node in data_inputs]
+        input_shapes = [shape for shape in input_shapes if shape is not None]
+        if not input_shapes:
+            logger.debug("all input shapes of concat node %s are None, can't infer its output shape", node.name)
+            return False
+
+        new_shape = input_shapes[0]
+        axis_node = node.inputs[-1]
+        rank = len(new_shape)
+        if not axis_node.is_const():
+            g.set_shape(node.output[0], [-1] * rank)
+            return True
+
+        axis = axis_node.get_tensor_value()
+        axis = axis if axis >= 0 else axis + rank
+        new_shape[axis] = -1
+        if len(input_shapes) == len(data_inputs):  # all input shapes are known
+            concat_dim_vals = list(np.array(input_shapes)[:, axis])
+            # only when inputs' shape are known, then val of concat dim can be calculated
+            if concat_dim_vals.count(-1) == 0:
+                new_shape[axis] = sum(concat_dim_vals)
+
+        g.set_shape(node.output[0], new_shape)
+        logger.debug("set Concat node [%s] with new shape %s", node.output[0], new_shape)
+        return True
+
     if node.type == "Merge":
         s1 = g.get_shape(node.input[0])
         s2 = g.get_shape(node.input[1])
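The ConcatV2 shape rule moves here from infer_shape_for_node because it only needs partial information: every dim except the concat axis can be copied from any input whose shape is already known, and the concat-axis dim is the sum of the per-input sizes only when every input shape is known and none is -1 on that axis; otherwise it stays -1. A standalone sketch of the same arithmetic (not the tf2onnx API):

def concat_output_shape(input_shapes, axis):
    # input_shapes: one entry per data input, None when that input's shape is unknown;
    # at least one entry must be known (mirrors the guard in the patch)
    known = [s for s in input_shapes if s is not None]
    new_shape = list(known[0])
    rank = len(new_shape)
    axis = axis if axis >= 0 else axis + rank
    dims = [s[axis] for s in known]
    if len(known) == len(input_shapes) and -1 not in dims:
        new_shape[axis] = sum(dims)     # all inputs known: concat dim is the sum
    else:
        new_shape[axis] = -1            # otherwise the concat dim stays unknown
    return new_shape

print(concat_output_shape([[2, 3, 5], [2, 4, 5]], axis=1))    # [2, 7, 5]
print(concat_output_shape([[2, 3, 5], None], axis=1))         # [2, -1, 5]
print(concat_output_shape([[2, -1, 5], [2, 4, 5]], axis=1))   # [2, -1, 5]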
