Skip to content

Commit 496b65d

Browse files
Fix nightly tests (#1712)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 1c60588 commit 496b65d

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

tests/common.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@
4545
"group_nodes_by_type",
4646
"test_ms_domain",
4747
"check_node_domain",
48-
"check_op_count"
48+
"check_op_count",
49+
"check_gru_count",
50+
"check_lstm_count",
4951
]
5052

5153

tests/run_pretrained_models.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ googlenet_v4_slim:
210210
rtol: 0.1
211211

212212
mobilenet_v3_large_float:
213+
tf_min_version: 1.14 # explicit_paddings for Conv2D
213214
url: https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_float.tgz
214215
model: v3-large_224_1.0_float/v3-large_224_1.0_float.pb
215216
input_get: get_beach
@@ -428,7 +429,7 @@ faster_rcnn_inception_v2_coco:
428429
- num_detections:0
429430

430431
keras_resnet50:
431-
tf_min_version: 2.1
432+
tf_min_version: 2.2
432433
disabled: false
433434
url: module://tensorflow.keras.applications.resnet50/ResNet50
434435
model: ResNet50
@@ -440,7 +441,7 @@ keras_resnet50:
440441
- Identity:0
441442

442443
keras_mobilenet_v2:
443-
tf_min_version: 2.1
444+
tf_min_version: 2.2
444445
disabled: false
445446
url: module://tensorflow.keras.applications.mobilenet_v2/MobileNetV2
446447
model: MobileNetV2

tests/test_gru.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from tensorflow.python.ops import init_ops
1010
from tensorflow.python.ops import variable_scope
1111
from backend_test_base import Tf2OnnxBackendTestBase
12-
from common import unittest_main, check_gru_count, check_opset_after_tf_version, check_op_count, check_tf_min_version
12+
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
1313
from tf2onnx.tf_loader import is_tf2
1414

1515
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop
@@ -607,6 +607,7 @@ def func(x):
607607
graph_validator=lambda g: check_gru_count(g, 1))
608608

609609
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
610+
@skip_tf_versions(["2.1"], "TF fails to correctly add output_2 node.")
610611
def test_dynamic_multi_bigru_with_same_input_hidden_size(self):
611612
batch_size = 10
612613
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
@@ -660,6 +661,7 @@ def func(x):
660661
# graph_validator=lambda g: check_gru_count(g, 2))
661662

662663
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
664+
@skip_tf_versions(["2.1"], "TF fails to correctly add output_2 node.")
663665
def test_dynamic_multi_bigru_with_same_input_seq_len(self):
664666
units = 5
665667
batch_size = 10
@@ -714,7 +716,7 @@ def func(x, y1, y2):
714716
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
715717
# graph_validator=lambda g: check_gru_count(g, 2))
716718

717-
@check_tf_min_version("2.0")
719+
@check_tf_min_version("2.2")
718720
def test_keras_gru(self):
719721
in_shape = [10, 3]
720722
x_val = np.random.uniform(size=[2, 10, 3]).astype(np.float32)

0 commit comments

Comments
 (0)