|
7 | 7 | import tensorflow as tf
|
8 | 8 |
|
9 | 9 | from backend_test_base import Tf2OnnxBackendTestBase
|
10 |
| -from common import unittest_main, check_tf_min_version, check_tf_max_version, \ |
| 10 | +from common import unittest_main, check_tf_min_version, \ |
11 | 11 | check_onnxruntime_min_version, check_tfjs_max_version, skip_tflite
|
12 | 12 | from tf2onnx.tf_loader import is_tf2
|
13 | 13 |
|
@@ -286,15 +286,13 @@ def func(x, y):
|
286 | 286 | self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-5)
|
287 | 287 |
|
288 | 288 | @check_tf_min_version("1.9")
|
289 |
| - @check_tf_max_version("1.15") |
| 289 | + @skip_tflite("infinite loop with tflite") |
290 | 290 | def test_simple_while_loop_var_shape(self):
|
291 | 291 | # test for while_loop with variant shape variables
|
292 |
| - # may not meet ONNX Loop spec |
293 |
| - # Note: this is not working on tf2 itself. |
294 | 292 | def func(i):
|
295 | 293 | const = tf.constant(np.array([2], dtype=np.int32))
|
296 | 294 | c = lambda i: tf.reduce_all(tf.shape(i) < 10)
|
297 |
| - b = lambda i: tf.concat([i, const], 0) |
| 295 | + b = lambda i: [tf.concat([i, const], 0)] |
298 | 296 | r = tf.while_loop(c, b, [i], shape_invariants=[tf.TensorShape([None])])
|
299 | 297 | return tf.identity(r, name="output")
|
300 | 298 | input_names_with_port = ["input_1:0"]
|
|
0 commit comments