|
12 | 12 | import tensorflow as tf
|
13 | 13 |
|
14 | 14 | from backend_test_base import Tf2OnnxBackendTestBase
|
15 |
| -from common import unittest_main |
| 15 | +from common import unittest_main, check_tf_min_version |
16 | 16 |
|
17 | 17 |
|
18 | 18 | # pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
|
19 | 19 |
|
20 | 20 |
|
21 | 21 | class LoopTests(Tf2OnnxBackendTestBase):
|
22 | 22 |
|
| 23 | + @check_tf_min_version("1.9") |
| 24 | + def test_simple_while_loop_var_shape(self): |
| 25 | + # test for while_loop with variant shape variables |
| 26 | + # may not meet ONNX Loop spec |
| 27 | + i = tf.placeholder(tf.int32, (1), name="input_1") |
| 28 | + const = tf.constant(np.array([2], dtype=np.int32)) |
| 29 | + |
| 30 | + c = lambda i: tf.reduce_all(tf.shape(i) < 10) |
| 31 | + b = lambda i: tf.concat([i, const], 0) |
| 32 | + r = tf.while_loop(c, b, [i], shape_invariants=[tf.TensorShape([None])]) |
| 33 | + |
| 34 | + _ = tf.identity(r, name="output") |
| 35 | + input_names_with_port = ["input_1:0"] |
| 36 | + feed_dict = {"input_1:0": np.array([0], dtype=np.int32)} |
| 37 | + |
| 38 | + output_names_with_port = ["output:0"] |
| 39 | + self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06) |
| 40 | + |
23 | 41 | def test_simple_while_loop(self):
|
24 | 42 | i = tf.placeholder(tf.int32, (), name="input_1")
|
25 | 43 | c = lambda i: tf.less(i, 10)
|
|
0 commit comments