Skip to content

Commit d43b585

Browse files
authored
Merge pull request #426 from nbcsm/ms
support ms domain
2 parents 9fcb069 + 271545e commit d43b585

File tree

14 files changed

+291
-93
lines changed

14 files changed

+291
-93
lines changed

ci_build/azure_pipelines/templates/setup.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
steps:
44
- bash: |
55
set -ex
6-
pip install pytest pytest-cov pytest-runner graphviz requests pyyaml pillow pandas
6+
pip install pytest pytest-cov pytest-runner graphviz requests pyyaml pillow pandas parameterized
77
pip install $(CI_PIP_TF_NAME) $(CI_PIP_ONNX_NAME) $(CI_PIP_ONNX_BACKEND_NAME)
88
99
# TF 1.10 requires numpy <=1.14.5 and >=1.13.3, but onnxruntime 0.2.1 does not work with numpy <= 1.14.5

examples/custom_op_via_python.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def print_handler(ctx, node, name, args):
2525
x_ = tf.Print(x, [x], "hello")
2626
_ = tf.identity(x_, name="output")
2727
onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph,
28-
custom_op_handlers={"Print": print_handler},
28+
custom_op_handlers={"Print": (print_handler, [])},
2929
extra_opset=[helper.make_opsetid(_TENSORFLOW_DOMAIN, 1)],
3030
input_names=["input:0"],
3131
output_names=["output:0"])

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def run(self):
7474
version=VersionInfo.version,
7575
description='Tensorflow to ONNX converter',
7676
setup_requires=['pytest-runner'],
77-
tests_require=['requests', 'pytest', 'pytest-cov', 'graphviz', 'pyyaml'],
77+
tests_require=['graphviz', 'requests', 'parameterized', 'pytest', 'pytest-cov', 'pyyaml'],
7878
cmdclass=cmdclass,
7979
packages=find_packages(),
8080

tests/common.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010
from collections import defaultdict
1111

1212
from distutils.version import LooseVersion
13-
from tf2onnx import utils
14-
from tf2onnx.tfonnx import DEFAULT_TARGET, POSSIBLE_TARGETS
13+
from parameterized import parameterized
14+
from tf2onnx import constants, utils
1515

1616
__all__ = ["TestConfig", "get_test_config", "unittest_main",
1717
"check_tf_min_version", "skip_tf_versions",
1818
"check_opset_min_version", "check_target", "skip_caffe2_backend", "skip_onnxruntime_backend",
1919
"skip_opset", "check_onnxruntime_incompatibility", "validate_const_node",
20-
"group_nodes_by_type"]
20+
"group_nodes_by_type", "test_ms_domain", "check_node_domain"]
2121

2222

2323
# pylint: disable=missing-docstring
@@ -26,8 +26,8 @@ class TestConfig(object):
2626
def __init__(self):
2727
self.platform = sys.platform
2828
self.tf_version = self._get_tf_version()
29-
self.opset = int(os.environ.get("TF2ONNX_TEST_OPSET", 7))
30-
self.target = os.environ.get("TF2ONNX_TEST_TARGET", ",".join(DEFAULT_TARGET)).split(',')
29+
self.opset = int(os.environ.get("TF2ONNX_TEST_OPSET", constants.PREFERRED_OPSET))
30+
self.target = os.environ.get("TF2ONNX_TEST_TARGET", ",".join(constants.DEFAULT_TARGET)).split(',')
3131
self.backend = os.environ.get("TF2ONNX_TEST_BACKEND", "onnxruntime")
3232
self.backend_version = self._get_backend_version()
3333
self.is_debug_mode = False
@@ -83,7 +83,7 @@ def load():
8383
choices=["caffe2", "onnxmsrtnext", "onnxruntime"],
8484
help="backend to test against")
8585
parser.add_argument("--opset", type=int, default=config.opset, help="opset to test against")
86-
parser.add_argument("--target", default=",".join(config.target), choices=POSSIBLE_TARGETS,
86+
parser.add_argument("--target", default=",".join(config.target), choices=constants.POSSIBLE_TARGETS,
8787
help="target platform")
8888
parser.add_argument("--debug", help="output debugging information", action="store_true")
8989
parser.add_argument("--temp_dir", help="temp dir")
@@ -244,3 +244,31 @@ def check_lstm_count(graph, expected_count):
244244

245245
def check_gru_count(graph, expected_count):
246246
return check_op_count(graph, "GRU", expected_count)
247+
248+
249+
_MAX_MS_OPSET_VERSION = 1
250+
251+
252+
def test_ms_domain(versions=None):
253+
""" Parameterize test case to apply ms opset(s) as extra_opset. """
254+
255+
def _custom_name_func(testcase_func, param_num, param):
256+
del param_num
257+
arg = param.args[0]
258+
return "%s_%s" % (testcase_func.__name__, arg.version)
259+
260+
# Test all opset versions in ms domain if versions is not specified
261+
if versions is None:
262+
versions = list(range(1, _MAX_MS_OPSET_VERSION + 1))
263+
264+
opsets = []
265+
for version in versions:
266+
opsets.append([utils.make_opsetid(constants.MICROSOFT_DOMAIN, version)])
267+
return parameterized.expand(opsets, testcase_func_name=_custom_name_func)
268+
269+
270+
def check_node_domain(node, domain):
271+
# None or empty string means onnx domain
272+
if not domain:
273+
return not node.domain
274+
return node.domain == domain

tests/test_backend.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from backend_test_base import Tf2OnnxBackendTestBase
1717
# pylint reports unused-wildcard-import which is false positive, __all__ is defined in common
1818
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
19+
from tf2onnx import constants
1920

2021
# pylint: disable=missing-docstring,invalid-name,unused-argument
2122

@@ -905,60 +906,95 @@ def test_sqrt(self):
905906
_ = tf.identity(x_, name=_TFOUTPUT)
906907
self._run_test_case([_OUTPUT], {_INPUT: x_val})
907908

908-
@check_opset_min_version(7, "cast")
909-
def test_range_const(self):
909+
def _test_range_const(self, extra_opset=None):
910+
process_args = {}
911+
if extra_opset is not None:
912+
process_args["extra_opset"] = [extra_opset]
913+
910914
x = tf.range(5)
911915
_ = tf.identity(x, name=_TFOUTPUT)
912-
self._run_test_case([_OUTPUT], {})
916+
self._run_test_case([_OUTPUT], {}, process_args=process_args)
913917
tf.reset_default_graph()
914918

915919
x = tf.range(3, 3, 5)
916920
_ = tf.identity(x, name=_TFOUTPUT)
917-
self._run_test_case([_OUTPUT], {})
921+
self._run_test_case([_OUTPUT], {}, process_args=process_args)
918922
tf.reset_default_graph()
919923

920924
x = tf.range(0, -5, -2)
921925
_ = tf.identity(x, name=_TFOUTPUT)
922-
self._run_test_case([_OUTPUT], {})
926+
self._run_test_case([_OUTPUT], {}, process_args=process_args)
923927
tf.reset_default_graph()
924928

925929
x = tf.range(-5.0, 5.0, 1.5)
926930
_ = tf.identity(x, name=_TFOUTPUT)
927-
self._run_test_case([_OUTPUT], {})
931+
self._run_test_case([_OUTPUT], {}, process_args=process_args)
928932
tf.reset_default_graph()
929933

930934
x = tf.range(2.5, 5.0, 10.0)
931935
_ = tf.identity(x, name=_TFOUTPUT)
932-
self._run_test_case([_OUTPUT], {})
936+
self._run_test_case([_OUTPUT], {}, process_args=process_args)
937+
938+
def _test_range_non_const(self, extra_opset=None):
939+
process_args = {}
940+
if extra_opset is not None:
941+
process_args["extra_opset"] = [extra_opset]
933942

934-
def test_range_non_const(self):
935943
x = tf.range(5.0)
936944
_ = tf.identity(x, name=_TFOUTPUT)
937-
self._run_test_case([_OUTPUT], {})
945+
g = self._run_test_case([_OUTPUT], {}, process_args=process_args)
946+
self.assertTrue(extra_opset is None
947+
or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain))
938948
tf.reset_default_graph()
939949

940950
x = tf.range(0, -5.0, -2)
941951
_ = tf.identity(x, name=_TFOUTPUT)
942-
self._run_test_case([_OUTPUT], {})
952+
g = self._run_test_case([_OUTPUT], {}, process_args=process_args)
953+
self.assertTrue(extra_opset is None
954+
or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain))
943955
tf.reset_default_graph()
944956

945-
x = tf.range(3.0, 3.0, 5)
946-
_ = tf.identity(x, name=_TFOUTPUT)
947-
self._run_test_case([_OUTPUT], {})
948-
tf.reset_default_graph()
957+
# disable this case for ms domain due to onnxruntime range-1 issue
958+
# https://github.com/Microsoft/onnxruntime/issues/730
959+
if not (extra_opset and extra_opset.domain == constants.MICROSOFT_DOMAIN):
960+
x = tf.range(3.0, 3.0, 5)
961+
_ = tf.identity(x, name=_TFOUTPUT)
962+
g = self._run_test_case([_OUTPUT], {}, process_args=process_args)
963+
self.assertTrue(extra_opset is None
964+
or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain))
965+
tf.reset_default_graph()
949966

950967
delta_val = np.array(1.5, dtype=np.float32)
951968
delta = tf.placeholder(tf.float32, shape=(), name=_TFINPUT)
952969
x = tf.range(-5.0, 5.0, delta)
953970
_ = tf.identity(x, name=_TFOUTPUT)
954-
self._run_test_case([_OUTPUT], {_INPUT: delta_val})
971+
g = self._run_test_case([_OUTPUT], {_INPUT: delta_val}, process_args=process_args)
972+
self.assertTrue(extra_opset is None
973+
or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain))
955974
tf.reset_default_graph()
956975

957976
start_val = np.array(2.5, dtype=np.float32)
958977
start = tf.placeholder(tf.float32, shape=(), name=_TFINPUT)
959978
x = tf.range(start, 5.0, 10.0)
960979
_ = tf.identity(x, name=_TFOUTPUT)
961-
self._run_test_case([_OUTPUT], {_INPUT: start_val})
980+
g = self._run_test_case([_OUTPUT], {_INPUT: start_val}, process_args=process_args)
981+
self.assertTrue(extra_opset is None
982+
or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain))
983+
984+
@check_opset_min_version(7, "cast")
985+
def test_range_const(self):
986+
self._test_range_const()
987+
988+
def test_range_non_const(self):
989+
self._test_range_non_const()
990+
991+
@test_ms_domain()
992+
def test_ms_range_const(self, extra_opset):
993+
self._test_range_const(extra_opset)
994+
995+
@test_ms_domain()
996+
def test_ms_range_non_const(self, extra_opset):
997+
self._test_range_non_const(extra_opset)
962998

963999
@check_onnxruntime_incompatibility("Sqrt")
9641000
def test_rsqrt(self):

tests/test_graph.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
from onnx import helper
1919

2020
import tf2onnx
21+
from tf2onnx import constants, utils
22+
from tf2onnx.graph import GraphUtil
2123
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2224
from tf2onnx.tfonnx import process_tf_graph
2325
from common import get_test_config, unittest_main
2426

25-
_TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
26-
2727

2828
# pylint: disable=missing-docstring
2929

@@ -71,6 +71,10 @@ def get_attribute_value(attr):
7171
if "broadcast" in attr:
7272
kwarg["broadcast"] = "{}".format(int(attr["broadcast"].i))
7373

74+
# display domain if it is not onnx domain
75+
if node.domain:
76+
kwarg["domain"] = node.domain
77+
7478
g2.node(node.name, op_type=node.type, **kwarg)
7579
for node in g.get_nodes():
7680
for i in node.input:
@@ -336,7 +340,7 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
336340
# becomes:
337341
# T output = Identity(T Input)
338342
self.assertEqual(node.type, "Identity")
339-
node.domain = _TENSORFLOW_DOMAIN
343+
node.domain = constants.TENSORFLOW_OPSET.domain
340344
self.assertEqual(args[0], "mode")
341345
del node.input[1:]
342346
return node
@@ -348,11 +352,35 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
348352
g = process_tf_graph(sess.graph,
349353
custom_op_handlers={"Print": (print_handler, ["Identity", "mode"])},
350354
opset=self.config.opset,
351-
extra_opset=helper.make_opsetid(_TENSORFLOW_DOMAIN, 1))
355+
extra_opset=[constants.TENSORFLOW_OPSET])
352356
self.assertEqual(
353-
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [op_type=Identity] '
354-
'output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
357+
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [domain="ai.onnx.converters.tensorflow" '
358+
'op_type=Identity] output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
355359
onnx_to_graphviz(g))
360+
self.assertEqual(g.opset, self.config.opset)
361+
self.assertEqual(g.extra_opset, [constants.TENSORFLOW_OPSET])
362+
363+
def test_extra_opset(self):
364+
extra_opset = [
365+
utils.make_opsetid(constants.MICROSOFT_DOMAIN, 1),
366+
utils.make_opsetid("my.domain", 1024),
367+
]
368+
with tf.Session() as sess:
369+
x = tf.placeholder(tf.float32, [2, 3], name="input1")
370+
x_ = tf.add(x, x)
371+
_ = tf.identity(x_, name="output")
372+
g = process_tf_graph(sess.graph,
373+
opset=self.config.opset,
374+
extra_opset=extra_opset)
375+
self.assertEqual(g.opset, self.config.opset)
376+
self.assertEqual(g.extra_opset, extra_opset)
377+
378+
# convert between graph and model proto, make sure extra opset is preserved
379+
model_proto = g.make_model("test")
380+
model_proto = GraphUtil.optimize_model_proto(model_proto)
381+
g = GraphUtil.create_graph_from_onnx_model(model_proto)
382+
self.assertEqual(g.opset, self.config.opset)
383+
self.assertEqual(g.extra_opset, extra_opset)
356384

357385

358386
if __name__ == '__main__':

tf2onnx/constants.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
common constants
6+
"""
7+
8+
from . import utils
9+
10+
# Built-in supported domains
11+
ONNX_DOMAIN = ""
12+
AI_ONNX_ML_DOMAIN = "ai.onnx.ml"
13+
MICROSOFT_DOMAIN = "com.microsoft"
14+
15+
# Default opset version for onnx domain
16+
PREFERRED_OPSET = 7
17+
18+
# Default opset for custom ops
19+
TENSORFLOW_OPSET = utils.make_opsetid("ai.onnx.converters.tensorflow", 1)
20+
21+
# Target for the generated onnx graph. It possible targets:
22+
# onnx-1.1 = onnx at v1.1 (winml in rs4 is based on this)
23+
# caffe2 = include some workarounds for caffe2 and winml
24+
TARGET_RS4 = "rs4"
25+
TARGET_RS5 = "rs5"
26+
TARGET_RS6 = "rs6"
27+
TARGET_CAFFE2 = "caffe2"
28+
POSSIBLE_TARGETS = [TARGET_RS4, TARGET_RS5, TARGET_RS6, TARGET_CAFFE2]
29+
DEFAULT_TARGET = []

tf2onnx/convert.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,11 @@
1010
from __future__ import unicode_literals
1111

1212
import argparse
13-
from onnx import helper
1413
import tensorflow as tf
1514

16-
from tf2onnx import utils
17-
from tf2onnx import loader
15+
from tf2onnx import constants, loader, utils
1816
from tf2onnx.graph import GraphUtil
19-
from tf2onnx.tfonnx import process_tf_graph, tf_optimize, DEFAULT_TARGET, POSSIBLE_TARGETS
20-
21-
_TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
17+
from tf2onnx.tfonnx import process_tf_graph, tf_optimize
2218

2319

2420
# pylint: disable=unused-argument
@@ -34,9 +30,12 @@ def get_args():
3430
parser.add_argument("--output", help="output model file")
3531
parser.add_argument("--inputs", help="model input_names")
3632
parser.add_argument("--outputs", help="model output_names")
37-
parser.add_argument("--opset", type=int, default=None, help="onnx opset to use")
33+
parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain")
3834
parser.add_argument("--custom-ops", help="list of custom ops")
39-
parser.add_argument("--target", default=",".join(DEFAULT_TARGET), choices=POSSIBLE_TARGETS, help="target platform")
35+
parser.add_argument("--extra_opset", default=None,
36+
help="extra opset with format like domain:version, e.g. com.microsoft:1")
37+
parser.add_argument("--target", default=",".join(constants.DEFAULT_TARGET), choices=constants.POSSIBLE_TARGETS,
38+
help="target platform")
4039
parser.add_argument("--continue_on_error", help="continue_on_error", action="store_true")
4140
parser.add_argument("--verbose", help="verbose output", action="store_true")
4241
parser.add_argument("--fold_const", help="enable tf constant_folding transformation before conversion",
@@ -65,11 +64,16 @@ def get_args():
6564
if args.target:
6665
args.target = args.target.split(",")
6766

67+
if args.extra_opset:
68+
tokens = args.extra_opset.split(':')
69+
if len(tokens) != 2:
70+
raise ValueError("invalid extra_opset argument")
71+
args.extra_opset = [utils.make_opsetid(tokens[0], int(tokens[1]))]
6872
return args
6973

7074

7175
def default_custom_op_handler(ctx, node, name, args):
72-
node.domain = _TENSORFLOW_DOMAIN
76+
node.domain = constants.TENSORFLOW_OPSET.domain
7377
return node
7478

7579

@@ -80,13 +84,12 @@ def main():
8084
# support unknown dimensions.
8185
utils.ONNX_UNKNOWN_DIMENSION = args.unknown_dim
8286

87+
extra_opset = args.extra_opset or []
88+
custom_ops = {}
8389
if args.custom_ops:
8490
# default custom ops for tensorflow-onnx are in the "tf" namespace
8591
custom_ops = {op: (default_custom_op_handler, []) for op in args.custom_ops.split(",")}
86-
extra_opset = [helper.make_opsetid(_TENSORFLOW_DOMAIN, 1)]
87-
else:
88-
custom_ops = {}
89-
extra_opset = None
92+
extra_opset.append(constants.TENSORFLOW_OPSET)
9093

9194
# get the frozen tensorflow model from graphdef, checkpoint or saved_model.
9295
if args.graphdef:

tf2onnx/custom_opsets/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
""" custom tf2onnx mapping functions. """
4+
5+
from . import ms
6+
from .. import constants
7+
8+
DOMAIN_OPSETS = {
9+
constants.MICROSOFT_DOMAIN: ms.OPSETS
10+
}

0 commit comments

Comments
 (0)