Skip to content

Commit 4bc7752

Browse files
committed
support test for extra opset, add ms range test
1 parent 0138549 commit 4bc7752

File tree

6 files changed

+93
-24
lines changed

6 files changed

+93
-24
lines changed

tests/common.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
from collections import defaultdict
1111

1212
from distutils.version import LooseVersion
13+
from parameterized import parameterized
1314
from tf2onnx import constants, utils
1415

1516
__all__ = ["TestConfig", "get_test_config", "unittest_main",
1617
"check_tf_min_version", "skip_tf_versions",
1718
"check_opset_min_version", "check_target", "skip_caffe2_backend", "skip_onnxruntime_backend",
1819
"skip_opset", "check_onnxruntime_incompatibility", "validate_const_node",
19-
"group_nodes_by_type"]
20+
"group_nodes_by_type", "test_ms_domain", "check_node_domain"]
2021

2122

2223
# pylint: disable=missing-docstring
@@ -243,3 +244,31 @@ def check_lstm_count(graph, expected_count):
243244

244245
def check_gru_count(graph, expected_count):
245246
return check_op_count(graph, "GRU", expected_count)
247+
248+
249+
_MAX_MS_OPSET_VERSION = 1
250+
251+
252+
def test_ms_domain(versions=None):
253+
""" Parameterize test case to apply ms opset(s) as extra_opset. """
254+
255+
def _custom_name_func(testcase_func, param_num, param):
256+
del param_num
257+
arg = param.args[0]
258+
return "%s_%s" % (testcase_func.__name__, arg.version)
259+
260+
# Test all opset versions in ms domain if versions is not specified
261+
if versions is None:
262+
versions = list(range(1, _MAX_MS_OPSET_VERSION + 1))
263+
264+
opsets = []
265+
for version in versions:
266+
opsets.append([utils.make_opsetid(constants.MICROSOFT_DOMAIN, version)])
267+
return parameterized.expand(opsets, testcase_func_name=_custom_name_func)
268+
269+
270+
def check_node_domain(node, domain):
271+
# None or empty string means onnx domain
272+
if not domain:
273+
return not node.domain
274+
return node.domain == domain

tests/test_backend.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from backend_test_base import Tf2OnnxBackendTestBase
1717
# pylint reports unused-wildcard-import which is false positive, __all__ is defined in common
1818
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
19+
from tf2onnx import constants
1920

2021
# pylint: disable=missing-docstring,invalid-name,unused-argument
2122

@@ -905,60 +906,95 @@ def test_sqrt(self):
905906
_ = tf.identity(x_, name=_TFOUTPUT)
906907
self._run_test_case([_OUTPUT], {_INPUT: x_val})
907908

908-
@check_opset_min_version(7, "cast")
909-
def test_range_const(self):
909+
def _test_range_const(self, extra_opset=None):
910+
process_args = {}
911+
if extra_opset is not None:
912+
process_args["extra_opset"] = [extra_opset]
913+
910914
x = tf.range(5)
911915
_ = tf.identity(x, name=_TFOUTPUT)
912-
self._run_test_case([_OUTPUT], {})
916+
self._run_test_case([_OUTPUT], {}, process_args=process_args)
913917
tf.reset_default_graph()
914918

915919
x = tf.range(3, 3, 5)
916920
_ = tf.identity(x, name=_TFOUTPUT)
917-
self._run_test_case([_OUTPUT], {})
921+
self._run_test_case([_OUTPUT], {}, process_args=process_args)
918922
tf.reset_default_graph()
919923

920924
x = tf.range(0, -5, -2)
921925
_ = tf.identity(x, name=_TFOUTPUT)
922-
self._run_test_case([_OUTPUT], {})
926+
self._run_test_case([_OUTPUT], {}, process_args=process_args)
923927
tf.reset_default_graph()
924928

925929
x = tf.range(-5.0, 5.0, 1.5)
926930
_ = tf.identity(x, name=_TFOUTPUT)
927-
self._run_test_case([_OUTPUT], {})
931+
self._run_test_case([_OUTPUT], {}, process_args=process_args)
928932
tf.reset_default_graph()
929933

930934
x = tf.range(2.5, 5.0, 10.0)
931935
_ = tf.identity(x, name=_TFOUTPUT)
932-
self._run_test_case([_OUTPUT], {})
936+
self._run_test_case([_OUTPUT], {}, process_args=process_args)
937+
938+
def _test_range_non_const(self, extra_opset=None):
939+
process_args = {}
940+
if extra_opset is not None:
941+
process_args["extra_opset"] = [extra_opset]
933942

934-
def test_range_non_const(self):
935943
x = tf.range(5.0)
936944
_ = tf.identity(x, name=_TFOUTPUT)
937-
self._run_test_case([_OUTPUT], {})
945+
g = self._run_test_case([_OUTPUT], {}, process_args=process_args)
946+
self.assertTrue(extra_opset is None
947+
or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain))
938948
tf.reset_default_graph()
939949

940950
x = tf.range(0, -5.0, -2)
941951
_ = tf.identity(x, name=_TFOUTPUT)
942-
self._run_test_case([_OUTPUT], {})
952+
g = self._run_test_case([_OUTPUT], {}, process_args=process_args)
953+
self.assertTrue(extra_opset is None
954+
or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain))
943955
tf.reset_default_graph()
944956

945-
x = tf.range(3.0, 3.0, 5)
946-
_ = tf.identity(x, name=_TFOUTPUT)
947-
self._run_test_case([_OUTPUT], {})
948-
tf.reset_default_graph()
957+
# disable this case for ms domain due to onnxruntime range-1 issue
958+
# https://github.com/Microsoft/onnxruntime/issues/730
959+
if not (extra_opset and extra_opset.domain == constants.MICROSOFT_DOMAIN):
960+
x = tf.range(3.0, 3.0, 5)
961+
_ = tf.identity(x, name=_TFOUTPUT)
962+
g = self._run_test_case([_OUTPUT], {}, process_args=process_args)
963+
self.assertTrue(extra_opset is None
964+
or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain))
965+
tf.reset_default_graph()
949966

950967
delta_val = np.array(1.5, dtype=np.float32)
951968
delta = tf.placeholder(tf.float32, shape=(), name=_TFINPUT)
952969
x = tf.range(-5.0, 5.0, delta)
953970
_ = tf.identity(x, name=_TFOUTPUT)
954-
self._run_test_case([_OUTPUT], {_INPUT: delta_val})
971+
g = self._run_test_case([_OUTPUT], {_INPUT: delta_val}, process_args=process_args)
972+
self.assertTrue(extra_opset is None
973+
or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain))
955974
tf.reset_default_graph()
956975

957976
start_val = np.array(2.5, dtype=np.float32)
958977
start = tf.placeholder(tf.float32, shape=(), name=_TFINPUT)
959978
x = tf.range(start, 5.0, 10.0)
960979
_ = tf.identity(x, name=_TFOUTPUT)
961-
self._run_test_case([_OUTPUT], {_INPUT: start_val})
980+
g = self._run_test_case([_OUTPUT], {_INPUT: start_val}, process_args=process_args)
981+
self.assertTrue(extra_opset is None
982+
or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain))
983+
984+
@check_opset_min_version(7, "cast")
985+
def test_range_const(self):
986+
self._test_range_const()
987+
988+
def test_range_non_const(self):
989+
self._test_range_non_const()
990+
991+
@test_ms_domain()
992+
def test_ms_range_const(self, extra_opset):
993+
self._test_range_const(extra_opset)
994+
995+
@test_ms_domain()
996+
def test_ms_range_non_const(self, extra_opset):
997+
self._test_range_non_const(extra_opset)
962998

963999
@check_onnxruntime_incompatibility("Sqrt")
9641000
def test_rsqrt(self):

tests/test_graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from onnx import helper
1919

2020
import tf2onnx
21-
from tf2onnx import constants
21+
from tf2onnx import constants, utils
2222
from tf2onnx.graph import GraphUtil
2323
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2424
from tf2onnx.tfonnx import process_tf_graph
@@ -362,8 +362,8 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
362362

363363
def test_extra_opset(self):
364364
extra_opset = [
365-
helper.make_opsetid(constants.MICROSOFT_DOMAIN, 1),
366-
helper.make_opsetid("my.domain", 1024),
365+
utils.make_opsetid(constants.MICROSOFT_DOMAIN, 1),
366+
utils.make_opsetid("my.domain", 1024),
367367
]
368368
with tf.Session() as sess:
369369
x = tf.placeholder(tf.float32, [2, 3], name="input1")

tf2onnx/constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
common constants
66
"""
77

8-
from onnx import helper
8+
from . import utils
99

1010
# Built-in supported domains
1111
ONNX_DOMAIN = ""
@@ -16,7 +16,7 @@
1616
PREFERRED_OPSET = 7
1717

1818
# Default opset for custom ops
19-
DEFAULT_CUSTOM_OP_OPSET = helper.make_opsetid("ai.onnx.converters.tensorflow", 1)
19+
DEFAULT_CUSTOM_OP_OPSET = utils.make_opsetid("ai.onnx.converters.tensorflow", 1)
2020

2121
# Target for the generated onnx graph. It possible targets:
2222
# onnx-1.1 = onnx at v1.1 (winml in rs4 is based on this)

tf2onnx/convert.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from __future__ import unicode_literals
1111

1212
import argparse
13-
from onnx import helper
1413
import tensorflow as tf
1514

1615
from tf2onnx import constants, loader, utils
@@ -69,7 +68,7 @@ def get_args():
6968
tokens = args.extra_opset.split(':')
7069
if len(tokens) != 2:
7170
raise ValueError("invalid extra_opset argument")
72-
args.extra_opset = [helper.make_opsetid(tokens[0], int(tokens[1]))]
71+
args.extra_opset = [utils.make_opsetid(tokens[0], int(tokens[1]))]
7372
return args
7473

7574

tf2onnx/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,3 +426,8 @@ def create_vague_shape_like(shape):
426426

427427
def get_onnx_version():
428428
return onnx.__version__
429+
430+
431+
def make_opsetid(domain, version):
432+
make_sure(isinstance(version, int), "version must be an integer")
433+
return helper.make_opsetid(domain, version)

0 commit comments

Comments
 (0)