Skip to content

Commit 08370e6

Browse files
make unittest only run on onnxruntime > 0.4.0
1 parent 4eda997 commit 08370e6

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

tests/common.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from parameterized import parameterized
1414
from tf2onnx import constants, utils
1515

16-
__all__ = ["TestConfig", "get_test_config", "unittest_main",
17-
"check_tf_min_version", "skip_tf_versions",
16+
__all__ = ["TestConfig", "get_test_config", "unittest_main", "check_onnxruntime_backend",
17+
"check_tf_min_version", "skip_tf_versions", "check_onnxruntime_min_version",
1818
"check_opset_min_version", "check_target", "skip_caffe2_backend", "skip_onnxruntime_backend",
1919
"skip_opset", "check_onnxruntime_incompatibility", "validate_const_node",
2020
"group_nodes_by_type", "test_ms_domain", "check_node_domain"]
@@ -177,6 +177,21 @@ def skip_onnxruntime_backend(message=""):
177177
return unittest.skipIf(config.is_onnxruntime_backend, reason)
178178

179179

180+
def check_onnxruntime_backend(message=""):
181+
""" Skip if backend is NOT onnxruntime """
182+
config = get_test_config()
183+
reason = _append_message("only supported by onnxruntime", message)
184+
return unittest.skipIf(not config.is_onnxruntime_backend, reason)
185+
186+
187+
def check_onnxruntime_min_version(min_required_version, message=""):
188+
""" Skip if onnxruntime version < min_required_version """
189+
config = get_test_config()
190+
reason = _append_message("conversion requires onnxruntime >= {}".format(min_required_version), message)
191+
return unittest.skipIf(config.is_onnxruntime_backend and
192+
config.backend_version < LooseVersion(min_required_version), reason)
193+
194+
180195
def skip_caffe2_backend(message=""):
181196
""" Skip if backend is caffe2 """
182197
config = get_test_config()
@@ -252,6 +267,7 @@ def check_gru_count(graph, expected_count):
252267
def test_ms_domain(versions=None):
253268
""" Parameterize test case to apply ms opset(s) as extra_opset. """
254269

270+
@check_onnxruntime_backend()
255271
def _custom_name_func(testcase_func, param_num, param):
256272
del param_num
257273
arg = param.args[0]

tests/test_backend.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1596,13 +1596,14 @@ def test_reverse_sequence_batch_major(self):
15961596
def test_reverse_sequence_time_major(self):
15971597
self._test_reverse_sequence_time_major()
15981598

1599+
# only support onnxruntime with version larger than 0.4.0
15991600
@test_ms_domain()
1600-
@unittest.skipIf(True, "not support in pypi onnxruntime")
1601+
@check_onnxruntime_min_version("0.4.0")
16011602
def test_ms_reverse_sequence_batch_major(self, extra_opset):
16021603
self._test_reverse_sequence_batch_major(extra_opset)
16031604

16041605
@test_ms_domain()
1605-
@unittest.skipIf(True, "not support in pypi onnxruntime")
1606+
@check_onnxruntime_min_version("0.4.0")
16061607
def test_ms_reverse_sequence_time_major(self, extra_opset):
16071608
self._test_reverse_sequence_time_major(extra_opset)
16081609

0 commit comments

Comments
 (0)