|
17 | 17 |
|
18 | 18 | import tensorflow as tf
|
19 | 19 | import tf2onnx
|
20 |
| -import tf2onnx.utils |
| 20 | +from tf2onnx import utils |
21 | 21 | from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
|
22 | 22 | from tf2onnx.graph import GraphUtil
|
23 | 23 | from common import unittest_main
|
@@ -50,14 +50,13 @@ def onnx_pretty(g, args=None):
|
50 | 50 |
|
51 | 51 |
|
52 | 52 | class Tf2OnnxInternalTests(unittest.TestCase):
|
53 |
| - |
54 | 53 | def setUp(self):
|
55 | 54 | """Setup test."""
|
56 | 55 | # suppress log info of tensorflow so that result of test can be seen much easier
|
57 | 56 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
58 | 57 | tf.logging.set_verbosity(tf.logging.WARN)
|
59 | 58 |
|
60 |
| - tf2onnx.utils.INTERNAL_NAME = 1 |
| 59 | + utils.INTERNAL_NAME = 1 |
61 | 60 | arg = namedtuple("Arg", "input inputs outputs verbose")
|
62 | 61 | self._args0 = arg(input="test", inputs=[], outputs=["output:0"], verbose=False)
|
63 | 62 | self._args1 = arg(input="test", inputs=["input:0"], outputs=["output:0"], verbose=False)
|
@@ -142,8 +141,8 @@ def test_rewrite_subgraph(self):
|
142 | 141 | for match in match_results:
|
143 | 142 | input_node = match.get_op('input')
|
144 | 143 | output_node = match.get_op('output')
|
145 |
| - op_name = tf2onnx.utils.make_name("ReplacedOp") |
146 |
| - out_name = tf2onnx.utils.port_name(op_name) |
| 144 | + op_name = utils.make_name("ReplacedOp") |
| 145 | + out_name = utils.port_name(op_name) |
147 | 146 | new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
|
148 | 147 | ops = g.replace_subgraph(ops, match, [], [output_node], [], [new_node])
|
149 | 148 | g.topological_sort(ops)
|
@@ -183,10 +182,26 @@ def test_cmdarg_parse(self):
|
183 | 182 | arg = "input/V-1_2:0,input/X:0[1,2,3],Y:1[4,5],Z:3,A:1,B"
|
184 | 183 | expected_inputs = ['input/V-1_2:0', 'input/X:0', 'Y:1', 'Z:3', 'A:1', 'B']
|
185 | 184 | expected_shape = {'Y:1': [4, 5], 'input/X:0': [1, 2, 3]}
|
186 |
| - inputs, shape_override = tf2onnx.utils.split_nodename_and_shape(arg) |
| 185 | + inputs, shape_override = utils.split_nodename_and_shape(arg) |
187 | 186 | self.assertEqual(expected_inputs, inputs)
|
188 | 187 | self.assertEqual(expected_shape, shape_override)
|
189 | 188 |
|
| 189 | + def test_shape_utils(self): |
| 190 | + self.assertEqual(utils.merge_shapes(None, None), None) |
| 191 | + self.assertEqual(utils.merge_shapes([], None), []) |
| 192 | + self.assertEqual(utils.merge_shapes(None, [1, 2, 3]), [1, 2, 3]) |
| 193 | + self.assertEqual(utils.merge_shapes([1, 3], [None, 3]), [1, 3]) |
| 194 | + self.assertEqual(utils.merge_shapes([1, None, 3], (-1, 2, "unk")), [1, 2, 3]) |
| 195 | + |
| 196 | + self.assertTrue(utils.are_shapes_compatible(None, [])) |
| 197 | + self.assertTrue(utils.are_shapes_compatible([1, None, 3], (-1, 2, "unk"))) |
| 198 | + self.assertFalse(utils.are_shapes_compatible([1, 2, 3], (2, 3))) |
| 199 | + self.assertFalse(utils.are_shapes_compatible([1, 2, 3], (4, 5, 6))) |
| 200 | + |
| 201 | + self.assertTrue(utils.are_shapes_equal(None, None)) |
| 202 | + self.assertFalse(utils.are_shapes_equal(None, [])) |
| 203 | + self.assertTrue(utils.are_shapes_equal([1, 2, 3], (1, 2, 3))) |
| 204 | + |
190 | 205 |
|
191 | 206 | if __name__ == '__main__':
|
192 | 207 | unittest_main()
|
0 commit comments