|
8 | 8 |
|
9 | 9 | from backend_test_base import Tf2OnnxBackendTestBase
|
10 | 10 | from common import unittest_main, check_tf_min_version, check_tf_max_version, \
|
11 |
| - check_onnxruntime_min_version, check_tfjs_max_version |
| 11 | + check_onnxruntime_min_version, check_tfjs_max_version, skip_tflite |
12 | 12 | from tf2onnx.tf_loader import is_tf2
|
13 | 13 |
|
14 | 14 |
|
@@ -302,6 +302,23 @@ def func(i):
|
302 | 302 | output_names_with_port = ["output:0"]
|
303 | 303 | self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
|
304 | 304 |
|
| 305 | + @check_tf_min_version("2") |
| 306 | + @skip_tflite("cond_graph conversion fails with tflite") |
| 307 | + def test_while_loop_cond_subgraphs(self): |
| 308 | + # test for while_loop with subgraphs in cond |
| 309 | + # Note: this is not working on tf1 |
| 310 | + def func(x): |
| 311 | + x_dim = tf.shape(x)[0] |
| 312 | + r = tf.cast(tf.zeros(1), x.dtype) |
| 313 | + for i in tf.range(10): |
| 314 | + if i == x_dim: |
| 315 | + break |
| 316 | + r += x[i] |
| 317 | + return tf.identity(r, name="output") |
| 318 | + input_names_with_port = ["input_1:0"] |
| 319 | + feed_dict = {"input_1:0": np.arange(0, 15, dtype=np.int32)} |
| 320 | + output_names_with_port = ["output:0"] |
| 321 | + self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port) |
305 | 322 |
|
306 | 323 | if __name__ == '__main__':
|
307 | 324 | unittest_main()
|
0 commit comments