Skip to content

Commit d2ece0c

Browse files
authored
Merge pull request #350 from nbcsm/shape
shape inference support LogicalOr, relax shape check for cond rewriter
2 parents cd34b38 + db4b9a7 commit d2ece0c

File tree

4 files changed

+102
-19
lines changed

4 files changed

+102
-19
lines changed

tests/test_internals.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
from onnx import helper
1717

1818
import tensorflow as tf
19-
import tf2onnx
20-
import tf2onnx.utils
19+
from tf2onnx import utils
2120
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2221
from tf2onnx.graph import GraphUtil
2322
from common import unittest_main
@@ -50,14 +49,13 @@ def onnx_pretty(g, args=None):
5049

5150

5251
class Tf2OnnxInternalTests(unittest.TestCase):
53-
5452
def setUp(self):
5553
"""Setup test."""
5654
# suppress log info of tensorflow so that result of test can be seen much easier
5755
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
5856
tf.logging.set_verbosity(tf.logging.WARN)
5957

60-
tf2onnx.utils.INTERNAL_NAME = 1
58+
utils.INTERNAL_NAME = 1
6159
arg = namedtuple("Arg", "input inputs outputs verbose")
6260
self._args0 = arg(input="test", inputs=[], outputs=["output:0"], verbose=False)
6361
self._args1 = arg(input="test", inputs=["input:0"], outputs=["output:0"], verbose=False)
@@ -142,8 +140,8 @@ def test_rewrite_subgraph(self):
142140
for match in match_results:
143141
input_node = match.get_op('input')
144142
output_node = match.get_op('output')
145-
op_name = tf2onnx.utils.make_name("ReplacedOp")
146-
out_name = tf2onnx.utils.port_name(op_name)
143+
op_name = utils.make_name("ReplacedOp")
144+
out_name = utils.port_name(op_name)
147145
new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
148146
ops = g.replace_subgraph(ops, match, [], [output_node], [], [new_node])
149147
g.topological_sort(ops)
@@ -183,10 +181,26 @@ def test_cmdarg_parse(self):
183181
arg = "input/V-1_2:0,input/X:0[1,2,3],Y:1[4,5],Z:3,A:1,B"
184182
expected_inputs = ['input/V-1_2:0', 'input/X:0', 'Y:1', 'Z:3', 'A:1', 'B']
185183
expected_shape = {'Y:1': [4, 5], 'input/X:0': [1, 2, 3]}
186-
inputs, shape_override = tf2onnx.utils.split_nodename_and_shape(arg)
184+
inputs, shape_override = utils.split_nodename_and_shape(arg)
187185
self.assertEqual(expected_inputs, inputs)
188186
self.assertEqual(expected_shape, shape_override)
189187

188+
def test_shape_utils(self):
189+
self.assertEqual(utils.merge_shapes(None, None), None)
190+
self.assertEqual(utils.merge_shapes([], None), [])
191+
self.assertEqual(utils.merge_shapes(None, [1, 2, 3]), [1, 2, 3])
192+
self.assertEqual(utils.merge_shapes([1, 3], [None, 3]), [1, 3])
193+
self.assertEqual(utils.merge_shapes([1, None, 3], (-1, 2, "unk")), [1, 2, 3])
194+
195+
self.assertTrue(utils.are_shapes_compatible(None, []))
196+
self.assertTrue(utils.are_shapes_compatible([1, None, 3], (-1, 2, "unk")))
197+
self.assertFalse(utils.are_shapes_compatible([1, 2, 3], (2, 3)))
198+
self.assertFalse(utils.are_shapes_compatible([1, 2, 3], (4, 5, 6)))
199+
200+
self.assertTrue(utils.are_shapes_equal(None, None))
201+
self.assertFalse(utils.are_shapes_equal(None, []))
202+
self.assertTrue(utils.are_shapes_equal([1, 2, 3], (1, 2, 3)))
203+
190204

191205
if __name__ == '__main__':
192206
unittest_main()

tf2onnx/rewriter/cond_rewriter.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
logging.basicConfig(level=logging.INFO)
1717
log = logging.getLogger("tf2onnx.rewriter.cond_rewriter_base")
1818

19+
1920
# pylint: disable=missing-docstring,unused-argument,broad-except
2021

2122
class BranchType(Enum):
@@ -29,6 +30,7 @@ class BranchType(Enum):
2930

3031
class CondBranchContext:
3132
"""Context for each branch graph"""
33+
3234
def __init__(self):
3335
self.output = []
3436
self.nodes = set()
@@ -37,12 +39,12 @@ def __init__(self):
3739
class CondContext:
3840
def __init__(self, cond_scope, pred_input, true_branch_context,
3941
false_branch_context, switchs, merges):
40-
self.cond_scope = cond_scope # name scope for this tf.cond
41-
self.pred_input = pred_input # condition input
42+
self.cond_scope = cond_scope # name scope for this tf.cond
43+
self.pred_input = pred_input # condition input
4244
self.true_branch_context = true_branch_context
4345
self.false_branch_context = false_branch_context
4446
self.switchs = set(switchs)
45-
self.merges = merges # list of merges in order
47+
self.merges = merges # list of merges in order
4648

4749

4850
class CondRewriter:
@@ -114,7 +116,7 @@ def _get_output_shape_dtype(self, cond_context):
114116
true_dtype = self.g.get_dtype(true_output)
115117
false_shape = self.g.get_shape(false_output)
116118
false_dtype = self.g.get_dtype(false_output)
117-
if true_shape != false_shape:
119+
if not utils.are_shapes_compatible(true_shape, false_shape):
118120
raise RuntimeError(
119121
"the shape of outputs {} and {} mismatch: {}, {}".format(
120122
true_output,
@@ -132,7 +134,7 @@ def _get_output_shape_dtype(self, cond_context):
132134
false_dtype
133135
)
134136
)
135-
output_shapes.append(true_shape)
137+
output_shapes.append(utils.merge_shapes(true_shape, false_shape))
136138
output_dtypes.append(true_dtype)
137139
return output_shapes, output_dtypes
138140

@@ -243,11 +245,13 @@ def _trace_back_from_one_merge(self, merge_node):
243245
merge_input_1 = merge_node.input[0]
244246
merge_input_2 = merge_node.input[1]
245247
switchs = set()
248+
246249
def stop_at_switch(node):
247250
if self._is_switch(node):
248251
switchs.add(node)
249252
return False
250253
return True
254+
251255
branch_nodes_1 = self.g.extract_sub_graph_nodes(
252256
[merge_input_1],
253257
stop_at_switch

tf2onnx/shape_inference.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@
3636
"GreaterEqual",
3737
"Less",
3838
"LogicalAnd",
39+
"LogicalOr",
3940
"Mul",
4041
"RealDiv",
4142
"Sub"
4243
]
4344

45+
4446
def infer_shape_for_graph(g):
4547
no_shape_updated = True
4648
while no_shape_updated:
@@ -135,7 +137,7 @@ def infer_shape_for_node(g, node):
135137
axis += len(s1)
136138
new_shape = s1[:axis] + [val]
137139
if axis < len(s1) - 1:
138-
new_shape += s1[axis+1:]
140+
new_shape += s1[axis + 1:]
139141

140142
g.set_shape(node.output[0], new_shape)
141143
log.debug("set ConcatV2 node [%s] with new shape %s", node.output[0], new_shape)
@@ -148,7 +150,7 @@ def infer_shape_for_node(g, node):
148150
shape_indices = g.get_shape(node.input[1])
149151
axis = node.input[2].get_tensor_value()
150152

151-
shape = shape_params[:axis] + shape_indices + shape_indices[axis+1:]
153+
shape = shape_params[:axis] + shape_indices + shape_indices[axis + 1:]
152154
g.set_shape(node.output[0], shape)
153155
return True
154156

tf2onnx/utils.py

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -342,11 +342,6 @@ def tf_name_scope(name):
342342
return '/'.join(name.split('/')[:-1])
343343

344344

345-
def create_vague_shape_like(shape):
346-
make_sure(len(shape) >= 0, "rank should be >= 0")
347-
return [-1 for i in enumerate(shape)]
348-
349-
350345
def get_temp_directory():
351346
return os.environ.get("TF2ONNX_TEMP_DIRECTORY", tempfile.mkdtemp())
352347

@@ -364,3 +359,71 @@ def save_protobuf(path, message, as_text=False):
364359
else:
365360
with open(path, "wb") as f:
366361
f.write(message.SerializeToString())
362+
363+
364+
def is_list_or_tuple(obj):
365+
return isinstance(obj, (list, tuple))
366+
367+
368+
def is_unknown_dimension(dim):
369+
""" Return true if dim is not a positive integer value. """
370+
if dim is None or not isinstance(dim, int):
371+
return True
372+
return dim <= 0
373+
374+
375+
def merge_shapes(shape1, shape2):
376+
"""
377+
Merge 2 shapes, return merged shape, choose more specific dimension value from either side.
378+
Raise exception for mismatch.
379+
"""
380+
if shape1 is None:
381+
return shape2
382+
if shape2 is None:
383+
return shape1
384+
385+
make_sure(is_list_or_tuple(shape1), "invalid type for shape1")
386+
make_sure(is_list_or_tuple(shape2), "invalid type for shape2")
387+
make_sure(len(shape1) == len(shape2), "shapes rank mismatch: shape1=%s, shape2=%s", shape1, shape2)
388+
389+
merged = []
390+
for d1, d2 in zip(shape1, shape2):
391+
d = d1
392+
if is_unknown_dimension(d1):
393+
d = d2
394+
elif not is_unknown_dimension(d2):
395+
make_sure(d1 == d2, "shapes dimension mismatch: shape1=%s, shape2=%s", shape1, shape2)
396+
merged.append(d)
397+
return merged
398+
399+
400+
def are_shapes_compatible(src, dest):
401+
"""
402+
Returns True iff src is compatible with dest.
403+
None is compatible with all shapes, different ranks are not considered as compatible
404+
"""
405+
try:
406+
merge_shapes(src, dest)
407+
return True
408+
except: # pylint: disable=bare-except
409+
return False
410+
411+
412+
def are_shapes_equal(src, dest):
413+
""" Check whether 2 shapes are equal. """
414+
if src is None:
415+
return dest is None
416+
if dest is None:
417+
return src is None
418+
419+
make_sure(is_list_or_tuple(src), "invalid type for src")
420+
make_sure(is_list_or_tuple(dest), "invalid type for dest")
421+
422+
if len(src) != len(dest):
423+
return False
424+
return all(i == j for i, j in zip(src, dest))
425+
426+
427+
def create_vague_shape_like(shape):
428+
make_sure(len(shape) >= 0, "rank should be >= 0")
429+
return [-1 for i in enumerate(shape)]

0 commit comments

Comments
 (0)