Skip to content

Commit dd3a410

Browse files
authored
Merge pull request #484 from lucienwang1009/opset10_bug
disable shape inference test for scan in opsets larger than 8
2 parents 8c72723 + e41fb55 commit dd3a410

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

tests/common.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,27 @@
1313
from parameterized import parameterized
1414
from tf2onnx import constants, logging, utils
1515

16-
__all__ = ["TestConfig", "get_test_config", "unittest_main", "check_onnxruntime_backend",
17-
"check_tf_min_version", "check_tf_max_version", "skip_tf_versions", "check_onnxruntime_min_version",
18-
"check_opset_min_version", "check_target", "skip_caffe2_backend", "skip_onnxruntime_backend",
19-
"skip_opset", "check_onnxruntime_incompatibility", "validate_const_node",
20-
"group_nodes_by_type", "test_ms_domain", "check_node_domain"]
16+
__all__ = [
17+
"TestConfig",
18+
"get_test_config",
19+
"unittest_main",
20+
"check_onnxruntime_backend",
21+
"check_tf_min_version",
22+
"check_tf_max_version",
23+
"skip_tf_versions",
24+
"check_onnxruntime_min_version",
25+
"check_opset_min_version",
26+
"check_opset_max_version",
27+
"check_target",
28+
"skip_caffe2_backend",
29+
"skip_onnxruntime_backend",
30+
"skip_opset",
31+
"check_onnxruntime_incompatibility",
32+
"validate_const_node",
33+
"group_nodes_by_type",
34+
"test_ms_domain",
35+
"check_node_domain"
36+
]
2137

2238

2339
# pylint: disable=missing-docstring
@@ -174,6 +190,13 @@ def check_opset_min_version(min_required_version, message=""):
174190
return unittest.skipIf(config.opset < min_required_version, reason)
175191

176192

193+
def check_opset_max_version(max_accepted_version, message=""):
194+
""" Skip if opset > max_accepted_version """
195+
config = get_test_config()
196+
reason = _append_message("conversion requires opset <= {}".format(max_accepted_version), message)
197+
return unittest.skipIf(config.opset > max_accepted_version, reason)
198+
199+
177200
def skip_opset(opset_v, message=""):
178201
""" Skip if opset = opset_v """
179202
config = get_test_config()

tests/test_onnx_shape_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def test_constant_of_shape(self):
267267

268268
# node with subgraph
269269
@check_opset_min_version(8, "Scan")
270-
@skip_opset(9, "Scan")
270+
@check_opset_max_version(8, "Scan")
271271
def test_scan(self):
272272
batch_size = 1
273273
seq_len = 16

0 commit comments

Comments
 (0)