Skip to content

Commit 5005fe5

Browse files
infer shape for tensorflow graph
1 parent 348b4dd commit 5005fe5

File tree

9 files changed

+187
-363
lines changed

9 files changed

+187
-363
lines changed

tests/common.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
class TestConfig(object):
4444
def __init__(self):
4545
self.platform = sys.platform
46-
self.tf_version = self._get_tf_version()
46+
self.tf_version = utils.get_tf_version()
4747
self.opset = int(os.environ.get("TF2ONNX_TEST_OPSET", constants.PREFERRED_OPSET))
4848
self.target = os.environ.get("TF2ONNX_TEST_TARGET", ",".join(constants.DEFAULT_TARGET)).split(',')
4949
self.backend = os.environ.get("TF2ONNX_TEST_BACKEND", "onnxruntime")
@@ -67,10 +67,6 @@ def is_caffe2_backend(self):
6767
def is_debug_mode(self):
6868
return utils.is_debug_mode()
6969

70-
def _get_tf_version(self):
71-
import tensorflow as tf
72-
return LooseVersion(tf.__version__)
73-
7470
def _get_backend_version(self):
7571
version = None
7672
if self.backend == "onnxruntime":

tests/test_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1271,7 +1271,7 @@ def test_randomuniform_int(self):
12711271
def test_randomuniform_dyn_shape(self):
12721272
# test for dynamic shape coming from a shape op
12731273
x_val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
1274-
x = tf.placeholder(x_val.dtype, name=_TFINPUT)
1274+
x = tf.placeholder(x_val.dtype, [None, 3], name=_TFINPUT)
12751275
x_ = tf.stack([x, x])
12761276
x_ = tf.identity(x_)
12771277
x_ = tf.shape(x_, name="shape")

tests/test_loops.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,32 @@
1212
import tensorflow as tf
1313

1414
from backend_test_base import Tf2OnnxBackendTestBase
15-
from common import unittest_main
15+
from common import unittest_main, check_tf_min_version
1616

1717

1818
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
1919

2020

2121
class LoopTests(Tf2OnnxBackendTestBase):
2222

23+
@check_tf_min_version("1.9")
24+
def test_simple_while_loop_var_shape(self):
25+
# test for while_loop with variant shape variables
26+
# may not meet ONNX Loop spec
27+
i = tf.placeholder(tf.int32, (1), name="input_1")
28+
const = tf.constant(np.array([2], dtype=np.int32))
29+
30+
c = lambda i: tf.reduce_all(tf.shape(i) < 10)
31+
b = lambda i: tf.concat([i, const], 0)
32+
r = tf.while_loop(c, b, [i], shape_invariants=[tf.TensorShape([None])])
33+
34+
_ = tf.identity(r, name="output")
35+
input_names_with_port = ["input_1:0"]
36+
feed_dict = {"input_1:0": np.array([0], dtype=np.int32)}
37+
38+
output_names_with_port = ["output:0"]
39+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
40+
2341
def test_simple_while_loop(self):
2442
i = tf.placeholder(tf.int32, (), name="input_1")
2543
c = lambda i: tf.less(i, 10)

tf2onnx/loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def freeze_session(sess, keep_var_names=None, output_names=None, clear_devices=T
2929
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
3030
output_names = output_names or []
3131
output_names += [v.op.name for v in tf.global_variables()]
32-
input_graph_def = graph.as_graph_def()
32+
input_graph_def = graph.as_graph_def(add_shapes=True)
3333
if clear_devices:
3434
for node in input_graph_def.node:
3535
node.device = ""

tf2onnx/onnx_opset/controlflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def version_8(cls, ctx, node, **kwargs):
268268
make_sure(true_data_shape is not None, "select true data shape cannot be None")
269269

270270
condition_shape = ctx.get_shape(node.input[0])
271-
utils.make_sure(condition_shape is not None, "condition shape is None")
271+
utils.make_sure(condition_shape is not None, "Shape of {} is None".format(node.input[0]))
272272
rank = len(condition_shape)
273273

274274
utils.make_sure(rank >= 0, "rank should be >= 0")

tf2onnx/rewriter/loop_rewriter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def rewrite(self, context):
5858
body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
5959
body_outputs = cond_g_info.outputs + cell_g_info.outputs
6060
for out_tensor_value_info in body_outputs:
61-
out_tensor_value_info.shape = utils.create_vague_shape_like(out_tensor_value_info.shape)
61+
shape = out_tensor_value_info.shape
62+
utils.make_sure(shape is not None, "Shape of {} is None".format(out_tensor_value_info.id))
63+
out_tensor_value_info.shape = utils.create_vague_shape_like(shape)
6264

6365
loop_body_g = LoopRewriterBase.construct_graph_from_nodes(self.g, body_nodes, body_outputs)
6466

0 commit comments

Comments
 (0)