|
13 | 13 | from parameterized import parameterized
|
14 | 14 | from tf2onnx import constants, utils
|
15 | 15 |
|
16 |
| -__all__ = ["TestConfig", "get_test_config", "unittest_main", |
17 |
| - "check_tf_min_version", "skip_tf_versions", |
| 16 | +__all__ = ["TestConfig", "get_test_config", "unittest_main", "check_onnxruntime_backend", |
| 17 | + "check_tf_min_version", "skip_tf_versions", "check_onnxruntime_min_version", |
18 | 18 | "check_opset_min_version", "check_target", "skip_caffe2_backend", "skip_onnxruntime_backend",
|
19 | 19 | "skip_opset", "check_onnxruntime_incompatibility", "validate_const_node",
|
20 | 20 | "group_nodes_by_type", "test_ms_domain", "check_node_domain"]
|
@@ -177,6 +177,21 @@ def skip_onnxruntime_backend(message=""):
|
177 | 177 | return unittest.skipIf(config.is_onnxruntime_backend, reason)
|
178 | 178 |
|
179 | 179 |
|
| 180 | +def check_onnxruntime_backend(message=""): |
| 181 | + """ Skip if backend is NOT onnxruntime """ |
| 182 | + config = get_test_config() |
| 183 | + reason = _append_message("only supported by onnxruntime", message) |
| 184 | + return unittest.skipIf(not config.is_onnxruntime_backend, reason) |
| 185 | + |
| 186 | + |
| 187 | +def check_onnxruntime_min_version(min_required_version, message=""): |
| 188 | + """ Skip if onnxruntime version < min_required_version """ |
| 189 | + config = get_test_config() |
| 190 | + reason = _append_message("conversion requires onnxruntime >= {}".format(min_required_version), message) |
| 191 | + return unittest.skipIf(config.is_onnxruntime_backend and |
| 192 | + config.backend_version < LooseVersion(min_required_version), reason) |
| 193 | + |
| 194 | + |
180 | 195 | def skip_caffe2_backend(message=""):
|
181 | 196 | """ Skip if backend is caffe2 """
|
182 | 197 | config = get_test_config()
|
@@ -252,6 +267,7 @@ def check_gru_count(graph, expected_count):
|
252 | 267 | def test_ms_domain(versions=None):
|
253 | 268 | """ Parameterize test case to apply ms opset(s) as extra_opset. """
|
254 | 269 |
|
| 270 | + @check_onnxruntime_backend() |
255 | 271 | def _custom_name_func(testcase_func, param_num, param):
|
256 | 272 | del param_num
|
257 | 273 | arg = param.args[0]
|
|
0 commit comments