Skip to content

Commit a8328b5

Browse files
committed
add shape utilities and test
1 parent 68c1b27 commit a8328b5

File tree

2 files changed

+88
-11
lines changed

2 files changed

+88
-11
lines changed

tests/test_internals.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import tensorflow as tf
1919
import tf2onnx
20-
import tf2onnx.utils
20+
from tf2onnx import utils
2121
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2222
from tf2onnx.graph import GraphUtil
2323
from common import unittest_main
@@ -50,14 +50,13 @@ def onnx_pretty(g, args=None):
5050

5151

5252
class Tf2OnnxInternalTests(unittest.TestCase):
53-
5453
def setUp(self):
5554
"""Setup test."""
5655
# suppress log info of tensorflow so that result of test can be seen much easier
5756
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
5857
tf.logging.set_verbosity(tf.logging.WARN)
5958

60-
tf2onnx.utils.INTERNAL_NAME = 1
59+
utils.INTERNAL_NAME = 1
6160
arg = namedtuple("Arg", "input inputs outputs verbose")
6261
self._args0 = arg(input="test", inputs=[], outputs=["output:0"], verbose=False)
6362
self._args1 = arg(input="test", inputs=["input:0"], outputs=["output:0"], verbose=False)
@@ -142,8 +141,8 @@ def test_rewrite_subgraph(self):
142141
for match in match_results:
143142
input_node = match.get_op('input')
144143
output_node = match.get_op('output')
145-
op_name = tf2onnx.utils.make_name("ReplacedOp")
146-
out_name = tf2onnx.utils.port_name(op_name)
144+
op_name = utils.make_name("ReplacedOp")
145+
out_name = utils.port_name(op_name)
147146
new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
148147
ops = g.replace_subgraph(ops, match, [], [output_node], [], [new_node])
149148
g.topological_sort(ops)
@@ -183,10 +182,26 @@ def test_cmdarg_parse(self):
183182
arg = "input/V-1_2:0,input/X:0[1,2,3],Y:1[4,5],Z:3,A:1,B"
184183
expected_inputs = ['input/V-1_2:0', 'input/X:0', 'Y:1', 'Z:3', 'A:1', 'B']
185184
expected_shape = {'Y:1': [4, 5], 'input/X:0': [1, 2, 3]}
186-
inputs, shape_override = tf2onnx.utils.split_nodename_and_shape(arg)
185+
inputs, shape_override = utils.split_nodename_and_shape(arg)
187186
self.assertEqual(expected_inputs, inputs)
188187
self.assertEqual(expected_shape, shape_override)
189188

189+
def test_shape_utils(self):
190+
self.assertEqual(utils.merge_shapes(None, None), None)
191+
self.assertEqual(utils.merge_shapes([], None), [])
192+
self.assertEqual(utils.merge_shapes(None, [1, 2, 3]), [1, 2, 3])
193+
self.assertEqual(utils.merge_shapes([1, 3], [None, 3]), [1, 3])
194+
self.assertEqual(utils.merge_shapes([1, None, 3], (-1, 2, "unk")), [1, 2, 3])
195+
196+
self.assertTrue(utils.are_shapes_compatible(None, []))
197+
self.assertTrue(utils.are_shapes_compatible([1, None, 3], (-1, 2, "unk")))
198+
self.assertFalse(utils.are_shapes_compatible([1, 2, 3], (2, 3)))
199+
self.assertFalse(utils.are_shapes_compatible([1, 2, 3], (4, 5, 6)))
200+
201+
self.assertTrue(utils.are_shapes_equal(None, None))
202+
self.assertFalse(utils.are_shapes_equal(None, []))
203+
self.assertTrue(utils.are_shapes_equal([1, 2, 3], (1, 2, 3)))
204+
190205

191206
if __name__ == '__main__':
192207
unittest_main()

tf2onnx/utils.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -342,11 +342,6 @@ def tf_name_scope(name):
342342
return '/'.join(name.split('/')[:-1])
343343

344344

345-
def create_vague_shape_like(shape):
346-
make_sure(len(shape) >= 0, "rank should be >= 0")
347-
return [-1 for i in enumerate(shape)]
348-
349-
350345
def get_temp_directory():
351346
return os.environ.get("TF2ONNX_TEMP_DIRECTORY", tempfile.mkdtemp())
352347

@@ -364,3 +359,70 @@ def save_protobuf(path, message, as_text=False):
364359
else:
365360
with open(path, "wb") as f:
366361
f.write(message.SerializeToString())
362+
363+
364+
def is_list_or_tuple(obj):
365+
return type(obj) in [list, tuple]
366+
367+
368+
def is_unknown_dimension(dim):
369+
""" Return true if dim is not a positive integer value. """
370+
if dim is None or not isinstance(dim, int):
371+
return True
372+
return dim <= 0
373+
374+
375+
def merge_shapes(shape1, shape2):
376+
"""
377+
Merge 2 shapes, return merged shape, choose more specific dimension value from either side.
378+
Raise exception for mismatch.
379+
"""
380+
if shape1 is None:
381+
return shape2
382+
if shape2 is None:
383+
return shape1
384+
385+
make_sure(is_list_or_tuple(shape1), "invalid type for shape1")
386+
make_sure(is_list_or_tuple(shape2), "invalid type for shape2")
387+
make_sure(len(shape1) == len(shape2), "shapes rank mismatch: shape1=%s, shape2=%s", shape1, shape2)
388+
389+
merged = []
390+
for d1, d2 in zip(shape1, shape2):
391+
d = d1
392+
if is_unknown_dimension(d1):
393+
d = d2
394+
elif not is_unknown_dimension(d2):
395+
make_sure(d1 == d2, "shapes dimension mismatch: shape1=%s, shape2=%s", shape1, shape2)
396+
merged.append(d)
397+
return merged
398+
399+
400+
def are_shapes_compatible(src, dest):
401+
"""
402+
Returns True iff src is compatible with dest.
403+
None is compatible with all shapes, different ranks are not considered as compatible
404+
"""
405+
try:
406+
merge_shapes(src, dest)
407+
return True
408+
except Exception:
409+
return False
410+
411+
412+
def are_shapes_equal(src, dest):
413+
if src is None:
414+
return dest is None
415+
if dest is None:
416+
return src is None
417+
418+
make_sure(is_list_or_tuple(src), "invalid type for src")
419+
make_sure(is_list_or_tuple(dest), "invalid type for dest")
420+
421+
if len(src) != len(dest):
422+
return False
423+
return all(i == j for i, j in zip(src, dest))
424+
425+
426+
def create_vague_shape_like(shape):
427+
make_sure(len(shape) >= 0, "rank should be >= 0")
428+
return [-1 for i in enumerate(shape)]

0 commit comments

Comments
 (0)