|
9 | 9 | from tensorflow.python.ops import init_ops
|
10 | 10 | from tensorflow.python.ops import variable_scope
|
11 | 11 | from backend_test_base import Tf2OnnxBackendTestBase
|
12 |
| -from common import unittest_main, check_gru_count, check_opset_after_tf_version, check_op_count, check_tf_min_version |
| 12 | +from common import * # pylint: disable=wildcard-import,unused-wildcard-import |
13 | 13 | from tf2onnx.tf_loader import is_tf2
|
14 | 14 |
|
15 | 15 | # pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop
|
@@ -607,6 +607,7 @@ def func(x):
|
607 | 607 | graph_validator=lambda g: check_gru_count(g, 1))
|
608 | 608 |
|
609 | 609 | @check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
|
| 610 | + @skip_tf_versions(["2.1"], "TF fails to correctly add output_2 node.") |
610 | 611 | def test_dynamic_multi_bigru_with_same_input_hidden_size(self):
|
611 | 612 | batch_size = 10
|
612 | 613 | x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
|
@@ -660,6 +661,7 @@ def func(x):
|
660 | 661 | # graph_validator=lambda g: check_gru_count(g, 2))
|
661 | 662 |
|
662 | 663 | @check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
|
| 664 | + @skip_tf_versions(["2.1"], "TF fails to correctly add output_2 node.") |
663 | 665 | def test_dynamic_multi_bigru_with_same_input_seq_len(self):
|
664 | 666 | units = 5
|
665 | 667 | batch_size = 10
|
@@ -714,7 +716,7 @@ def func(x, y1, y2):
|
714 | 716 | self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
|
715 | 717 | # graph_validator=lambda g: check_gru_count(g, 2))
|
716 | 718 |
|
717 |
| - @check_tf_min_version("2.0") |
| 719 | + @check_tf_min_version("2.2") |
718 | 720 | def test_keras_gru(self):
|
719 | 721 | in_shape = [10, 3]
|
720 | 722 | x_val = np.random.uniform(size=[2, 10, 3]).astype(np.float32)
|
|
0 commit comments