|
20 | 20 |
|
21 | 21 | class LoopTests(Tf2OnnxBackendTestBase):
|
22 | 22 |
|
23 |
| - @check_tf_min_version("1.9") |
24 |
| - def test_simple_while_loop_var_shape(self): |
25 |
| - # test for while_loop with variant shape variables |
26 |
| - # may not meet ONNX Loop spec |
27 |
| - i = tf.placeholder(tf.int32, (1), name="input_1") |
28 |
| - const = tf.constant(np.array([2], dtype=np.int32)) |
29 |
| - |
30 |
| - c = lambda i: tf.reduce_all(tf.shape(i) < 10) |
31 |
| - b = lambda i: tf.concat([i, const], 0) |
32 |
| - r = tf.while_loop(c, b, [i], shape_invariants=[tf.TensorShape([None])]) |
33 |
| - |
34 |
| - _ = tf.identity(r, name="output") |
35 |
| - input_names_with_port = ["input_1:0"] |
36 |
| - feed_dict = {"input_1:0": np.array([0], dtype=np.int32)} |
37 |
| - |
38 |
| - output_names_with_port = ["output:0"] |
39 |
| - self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06) |
40 |
| - |
41 | 23 | def test_simple_while_loop(self):
|
42 | 24 | i = tf.placeholder(tf.int32, (), name="input_1")
|
43 | 25 | c = lambda i: tf.less(i, 10)
|
@@ -214,6 +196,24 @@ def fn1(elem):
|
214 | 196 | self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-5)
|
215 | 197 | tf.reset_default_graph()
|
216 | 198 |
|
| 199 | + @check_tf_min_version("1.9") |
| 200 | + def test_simple_while_loop_var_shape(self): |
| 201 | + # test for while_loop with variant shape variables |
| 202 | + # may not meet ONNX Loop spec |
| 203 | + i = tf.placeholder(tf.int32, (1), name="input_1") |
| 204 | + const = tf.constant(np.array([2], dtype=np.int32)) |
| 205 | + |
| 206 | + c = lambda i: tf.reduce_all(tf.shape(i) < 10) |
| 207 | + b = lambda i: tf.concat([i, const], 0) |
| 208 | + r = tf.while_loop(c, b, [i], shape_invariants=[tf.TensorShape([None])]) |
| 209 | + |
| 210 | + _ = tf.identity(r, name="output") |
| 211 | + input_names_with_port = ["input_1:0"] |
| 212 | + feed_dict = {"input_1:0": np.array([0], dtype=np.int32)} |
| 213 | + |
| 214 | + output_names_with_port = ["output:0"] |
| 215 | + self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06) |
| 216 | + |
217 | 217 |
|
218 | 218 | if __name__ == '__main__':
|
219 | 219 | unittest_main()
|
0 commit comments