Skip to content

Commit 6dda2bb

Browse files
authored
Fix wrong shapes in loop body inputs if shape invariances are set in TF (#2203)
* fix wrong shapes in loop body inputs if shape invariances are set in TF * fix and enable test for TF2 --------- Signed-off-by: f-salvetti <[email protected]>
1 parent d0ba20e commit 6dda2bb

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

tests/test_loops.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import tensorflow as tf
88

99
from backend_test_base import Tf2OnnxBackendTestBase
10-
from common import unittest_main, check_tf_min_version, check_tf_max_version, \
10+
from common import unittest_main, check_tf_min_version, \
1111
check_onnxruntime_min_version, check_tfjs_max_version, skip_tflite
1212
from tf2onnx.tf_loader import is_tf2
1313

@@ -286,15 +286,13 @@ def func(x, y):
286286
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-5)
287287

288288
@check_tf_min_version("1.9")
289-
@check_tf_max_version("1.15")
289+
@skip_tflite("infinite loop with tflite")
290290
def test_simple_while_loop_var_shape(self):
291291
# test for while_loop with variant shape variables
292-
# may not meet ONNX Loop spec
293-
# Note: this is not working on tf2 itself.
294292
def func(i):
295293
const = tf.constant(np.array([2], dtype=np.int32))
296294
c = lambda i: tf.reduce_all(tf.shape(i) < 10)
297-
b = lambda i: tf.concat([i, const], 0)
295+
b = lambda i: [tf.concat([i, const], 0)]
298296
r = tf.while_loop(c, b, [i], shape_invariants=[tf.TensorShape([None])])
299297
return tf.identity(r, name="output")
300298
input_names_with_port = ["input_1:0"]

tf2onnx/onnx_opset/controlflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,8 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_
571571
g.set_dtype(func_inputs[0], onnx_pb.TensorProto.INT64)
572572
g.inputs = [g.get_node_by_output(inp) for inp in func_inputs]
573573

574-
for p, c in zip(loop_node.input, func_inputs):
574+
# we should use outputs shape, not inputs, since there may be shape invariants
575+
for p, c in zip(loop_node.output, func_inputs[2:]):
575576
g.copy_shape(p, c)
576577

577578
for i, node in enumerate(g.inputs):

0 commit comments

Comments
 (0)