Skip to content

Commit 74b120e

Browse files
committed
extract common constants
1 parent 343affe commit 74b120e

File tree

8 files changed

+67
-59
lines changed

8 files changed

+67
-59
lines changed

examples/custom_op_via_python.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import tensorflow as tf
55
import tf2onnx
66
from onnx import helper
7-
8-
_TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
7+
from tf2onnx import constants
98

109

1110
def print_handler(ctx, node, name, args):
@@ -14,7 +13,7 @@ def print_handler(ctx, node, name, args):
1413
# becomes:
1514
# T output = Identity(T Input)
1615
node.type = "Identity"
17-
node.domain = _TENSORFLOW_DOMAIN
16+
node.domain = constants.DEFAULT_CUSTOM_OP_OPSET.domain
1817
del node.input[1:]
1918
return node
2019

@@ -26,7 +25,7 @@ def print_handler(ctx, node, name, args):
2625
_ = tf.identity(x_, name="output")
2726
onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph,
2827
custom_op_handlers={"Print": print_handler},
29-
extra_opset=[helper.make_opsetid(_TENSORFLOW_DOMAIN, 1)],
28+
extra_opset=[constants.DEFAULT_CUSTOM_OP_OPSET],
3029
input_names=["input:0"],
3130
output_names=["output:0"])
3231
model_proto = onnx_graph.make_model("test")

tests/common.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
from collections import defaultdict
1111

1212
from distutils.version import LooseVersion
13-
from tf2onnx import utils
14-
from tf2onnx.tfonnx import DEFAULT_TARGET, POSSIBLE_TARGETS
13+
from tf2onnx import constants, utils
1514

1615
__all__ = ["TestConfig", "get_test_config", "unittest_main",
1716
"check_tf_min_version", "skip_tf_versions",
@@ -26,8 +25,8 @@ class TestConfig(object):
2625
def __init__(self):
2726
self.platform = sys.platform
2827
self.tf_version = self._get_tf_version()
29-
self.opset = int(os.environ.get("TF2ONNX_TEST_OPSET", 7))
30-
self.target = os.environ.get("TF2ONNX_TEST_TARGET", ",".join(DEFAULT_TARGET)).split(',')
28+
self.opset = int(os.environ.get("TF2ONNX_TEST_OPSET", constants.PREFERRED_OPSET))
29+
self.target = os.environ.get("TF2ONNX_TEST_TARGET", ",".join(constants.DEFAULT_TARGET)).split(',')
3130
self.backend = os.environ.get("TF2ONNX_TEST_BACKEND", "onnxruntime")
3231
self.backend_version = self._get_backend_version()
3332
self.is_debug_mode = False
@@ -83,7 +82,7 @@ def load():
8382
choices=["caffe2", "onnxmsrtnext", "onnxruntime"],
8483
help="backend to test against")
8584
parser.add_argument("--opset", type=int, default=config.opset, help="opset to test against")
86-
parser.add_argument("--target", default=",".join(config.target), choices=POSSIBLE_TARGETS,
85+
parser.add_argument("--target", default=",".join(config.target), choices=constants.POSSIBLE_TARGETS,
8786
help="target platform")
8887
parser.add_argument("--debug", help="output debugging information", action="store_true")
8988
parser.add_argument("--temp_dir", help="temp dir")

tests/test_graph.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@
1818
from onnx import helper
1919

2020
import tf2onnx
21+
from tf2onnx import constants
2122
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2223
from tf2onnx.tfonnx import process_tf_graph
2324
from common import get_test_config, unittest_main
2425

25-
_TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
26-
2726

2827
# pylint: disable=missing-docstring
2928

@@ -336,7 +335,7 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
336335
# becomes:
337336
# T output = Identity(T Input)
338337
self.assertEqual(node.type, "Identity")
339-
node.domain = _TENSORFLOW_DOMAIN
338+
node.domain = constants.DEFAULT_CUSTOM_OP_OPSET.domain
340339
self.assertEqual(args[0], "mode")
341340
del node.input[1:]
342341
return node
@@ -348,7 +347,7 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
348347
g = process_tf_graph(sess.graph,
349348
custom_op_handlers={"Print": (print_handler, ["Identity", "mode"])},
350349
opset=self.config.opset,
351-
extra_opset=helper.make_opsetid(_TENSORFLOW_DOMAIN, 1))
350+
extra_opset=[constants.DEFAULT_CUSTOM_OP_OPSET])
352351
self.assertEqual(
353352
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [op_type=Identity] '
354353
'output [op_type=Identity] input1:0 -> Print Print:0 -> output }',

tf2onnx/constants.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
common constants
6+
"""
7+
8+
from onnx import helper
9+
10+
# Built-in supported domains
11+
ONNX_DOMAIN = ""
12+
AI_ONNX_ML_DOMAIN = "ai.onnx.ml"
13+
MICROSOFT_DOMAIN = "com.microsoft"
14+
15+
# Default opset version for onnx domain
16+
PREFERRED_OPSET = 7
17+
18+
# Default opset for custom ops
19+
DEFAULT_CUSTOM_OP_OPSET = helper.make_opsetid("ai.onnx.converters.tensorflow", 1)
20+
21+
# Target for the generated onnx graph. It possible targets:
22+
# onnx-1.1 = onnx at v1.1 (winml in rs4 is based on this)
23+
# caffe2 = include some workarounds for caffe2 and winml
24+
TARGET_RS4 = "rs4"
25+
TARGET_RS5 = "rs5"
26+
TARGET_RS6 = "rs6"
27+
TARGET_CAFFE2 = "caffe2"
28+
POSSIBLE_TARGETS = [TARGET_RS4, TARGET_RS5, TARGET_RS6, TARGET_CAFFE2]
29+
DEFAULT_TARGET = []

tf2onnx/convert.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,9 @@
1313
from onnx import helper
1414
import tensorflow as tf
1515

16-
from tf2onnx import utils
17-
from tf2onnx import loader
16+
from tf2onnx import constants, loader, utils
1817
from tf2onnx.graph import GraphUtil
19-
from tf2onnx.tfonnx import process_tf_graph, tf_optimize, DEFAULT_TARGET, POSSIBLE_TARGETS
20-
21-
_TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
18+
from tf2onnx.tfonnx import process_tf_graph, tf_optimize
2219

2320

2421
# pylint: disable=unused-argument
@@ -36,7 +33,8 @@ def get_args():
3633
parser.add_argument("--outputs", help="model output_names")
3734
parser.add_argument("--opset", type=int, default=None, help="onnx opset to use")
3835
parser.add_argument("--custom-ops", help="list of custom ops")
39-
parser.add_argument("--target", default=",".join(DEFAULT_TARGET), choices=POSSIBLE_TARGETS, help="target platform")
36+
parser.add_argument("--target", default=",".join(constants.DEFAULT_TARGET), choices=constants.POSSIBLE_TARGETS,
37+
help="target platform")
4038
parser.add_argument("--continue_on_error", help="continue_on_error", action="store_true")
4139
parser.add_argument("--verbose", help="verbose output", action="store_true")
4240
parser.add_argument("--fold_const", help="enable tf constant_folding transformation before conversion",
@@ -69,7 +67,7 @@ def get_args():
6967

7068

7169
def default_custom_op_handler(ctx, node, name, args):
72-
node.domain = _TENSORFLOW_DOMAIN
70+
node.domain = constants.DEFAULT_CUSTOM_OP_OPSET.domain
7371
return node
7472

7573

@@ -83,7 +81,7 @@ def main():
8381
if args.custom_ops:
8482
# default custom ops for tensorflow-onnx are in the "tf" namespace
8583
custom_ops = {op: (default_custom_op_handler, []) for op in args.custom_ops.split(",")}
86-
extra_opset = [helper.make_opsetid(_TENSORFLOW_DOMAIN, 1)]
84+
extra_opset = [constants.DEFAULT_CUSTOM_OP_OPSET]
8785
else:
8886
custom_ops = {}
8987
extra_opset = None

tf2onnx/schemas.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from collections import defaultdict, OrderedDict
1313
from onnx import defs
1414

15-
ONNX_DOMAIN = ""
15+
from . import constants
1616

1717

1818
class OnnxOpSchema(object):
@@ -97,9 +97,9 @@ def _parse_domain_opset_versions(schemas):
9797
_domain_opset_versions = _parse_domain_opset_versions(_schemas)
9898

9999

100-
def get_schema(name, max_inclusive_opset_version, domain=ONNX_DOMAIN):
100+
def get_schema(name, max_inclusive_opset_version, domain=None):
101101
"""Get schema by name within specific version."""
102-
domain = domain or ONNX_DOMAIN
102+
domain = domain or constants.ONNX_DOMAIN
103103
domain_version_schema_map = _schemas[name]
104104
version_schema_map = domain_version_schema_map[domain]
105105
for version, schema in version_schema_map.items():
@@ -108,7 +108,7 @@ def get_schema(name, max_inclusive_opset_version, domain=ONNX_DOMAIN):
108108
return None
109109

110110

111-
def get_max_supported_opset_version(domain=ONNX_DOMAIN):
111+
def get_max_supported_opset_version(domain=None):
112112
"""Get max supported opset version by current onnx package given a domain."""
113-
domain = domain or ONNX_DOMAIN
113+
domain = domain or constants.ONNX_DOMAIN
114114
return _domain_opset_versions.get(domain, None)

tf2onnx/tfonnx.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from tensorflow.tools.graph_transforms import TransformGraph
2222

2323
import tf2onnx
24-
from tf2onnx import schemas, utils
24+
from tf2onnx import constants, schemas, utils
2525
from tf2onnx.function import * # pylint: disable=wildcard-import
2626
from tf2onnx.graph import Graph
2727
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
@@ -32,16 +32,6 @@
3232
logging.basicConfig(level=logging.INFO)
3333
log = logging.getLogger("tf2onnx")
3434

35-
# Target for the generated onnx graph. It possible targets:
36-
# onnx-1.1 = onnx at v1.1 (winml in rs4 is based on this)
37-
# caffe2 = include some workarounds for caffe2 and winml
38-
TARGET_RS4 = "rs4"
39-
TARGET_RS5 = "rs5"
40-
TARGET_RS6 = "rs6"
41-
TARGET_CAFFE2 = "caffe2"
42-
POSSIBLE_TARGETS = [TARGET_RS4, TARGET_RS5, TARGET_RS6, TARGET_CAFFE2]
43-
DEFAULT_TARGET = []
44-
4535

4636
# pylint: disable=useless-return,broad-except,logging-not-lazy,unused-argument,missing-docstring
4737
# FIXME:
@@ -178,7 +168,7 @@ def broadcast_op(ctx, node, name, args):
178168
node.set_attr("broadcast", 1)
179169
# this works around shortcomings in the broadcasting code
180170
# of caffe2 and winml/rs4.
181-
if ctx.is_target(TARGET_RS4):
171+
if ctx.is_target(constants.TARGET_RS4):
182172
# in rs4 mul and add do not support scalar correctly
183173
if not shape0:
184174
if node.inputs[0].is_const():
@@ -201,7 +191,7 @@ def broadcast_op7(ctx, node, name, args):
201191
if shape0 != shape1:
202192
# this works around shortcomings in the broadcasting code
203193
# of caffe2 and winml/rs4.
204-
if ctx.is_target(TARGET_RS4):
194+
if ctx.is_target(constants.TARGET_RS4):
205195
# in rs4 mul and add do not support scalar correctly
206196
if not shape0:
207197
if node.inputs[0].is_const():
@@ -1027,7 +1017,7 @@ def stridedslice_op(ctx, node, name, args):
10271017

10281018

10291019
def pow_op(ctx, node, name, args):
1030-
if ctx.is_target(TARGET_CAFFE2):
1020+
if ctx.is_target(constants.TARGET_CAFFE2):
10311021
# workaround a bug in caffe2 pre Feb2018, pow(a, b) becomes np.exp(np.log(a) * b)
10321022
node.type = "Log"
10331023
b = node.input[1]
@@ -1300,7 +1290,8 @@ def onehot_op9(ctx, node, name, args):
13001290
# in ONNX, op's schema is (input, depth, value, @int axis), meaning of "value" is [off-value, on-value]
13011291
# onnxruntime only supports int64
13021292
output_dtype = ctx.get_dtype(node.input[2])
1303-
if ctx.is_target(TARGET_RS6) and output_dtype not in [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]:
1293+
if ctx.is_target(constants.TARGET_RS6) \
1294+
and output_dtype not in [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]:
13041295
log.warning("unsupported dtype in onnxruntime, onehot-9 can't be used directly")
13051296
onehot_op(ctx, node, name, args)
13061297
return
@@ -1315,21 +1306,21 @@ def onehot_op9(ctx, node, name, args):
13151306
off_on_value = ctx.make_node("Concat", [off_value, on_value], attr={"axis": 0}).output[0]
13161307

13171308
indices = node.input[0]
1318-
if ctx.is_target(TARGET_RS6) and ctx.get_dtype(indices) != onnx_pb.TensorProto.INT64:
1309+
if ctx.is_target(constants.TARGET_RS6) and ctx.get_dtype(indices) != onnx_pb.TensorProto.INT64:
13191310
indices = ctx.make_node("Cast", [indices], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
13201311
node.input[0] = indices
13211312

1322-
if ctx.is_target(TARGET_RS6) and ctx.get_dtype(depth) != onnx_pb.TensorProto.INT64:
1313+
if ctx.is_target(constants.TARGET_RS6) and ctx.get_dtype(depth) != onnx_pb.TensorProto.INT64:
13231314
depth = ctx.make_node("Cast", [depth], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
13241315
node.input[1] = depth
13251316

1326-
if ctx.is_target(TARGET_RS6) and output_dtype != onnx_pb.TensorProto.INT64:
1317+
if ctx.is_target(constants.TARGET_RS6) and output_dtype != onnx_pb.TensorProto.INT64:
13271318
off_on_value = ctx.make_node("Cast", [off_on_value], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
13281319
node.input[2] = off_on_value
13291320

13301321
del node.input[3]
13311322

1332-
if ctx.is_target(TARGET_RS6) and output_dtype != onnx_pb.TensorProto.INT64:
1323+
if ctx.is_target(constants.TARGET_RS6) and output_dtype != onnx_pb.TensorProto.INT64:
13331324
new_node_name = utils.make_name("onehot_output")
13341325
new_node = ctx.insert_new_node_on_output("Cast", node.output[0], new_node_name, to=output_dtype)
13351326
ctx.set_dtype(new_node.output[0], output_dtype)
@@ -2505,7 +2496,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
25052496
if inputs_as_nchw is None:
25062497
inputs_as_nchw = []
25072498
if target is None:
2508-
target = DEFAULT_TARGET
2499+
target = constants.DEFAULT_TARGET
25092500

25102501
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tensorflow_to_onnx(tf_graph, shape_override)
25112502

@@ -2558,9 +2549,9 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
25582549

25592550
# post-processing rewriters
25602551
late_rewriters = []
2561-
if TARGET_RS5 in target:
2552+
if constants.TARGET_RS5 in target:
25622553
late_rewriters.append(rewrite_incomplete_type_support_rs5)
2563-
if TARGET_RS6 in target:
2554+
if constants.TARGET_RS6 in target:
25642555
late_rewriters.append(rewrite_incomplete_type_support_rs6)
25652556
if late_rewriters:
25662557
run_rewriters(g, late_rewriters, continue_on_error)

tf2onnx/utils.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from google.protobuf import text_format
2121
import onnx
2222
from onnx import helper, onnx_pb, defs, numpy_helper
23+
from . import constants
2324

2425
#
2526
# mapping dtypes from tensorflow to onnx
@@ -174,12 +175,7 @@ def get_tf_tensor_data(tensor):
174175
data = tensor.bool_val
175176
elif tensor.string_val:
176177
data = tensor.string_val
177-
elif tensor.dtype in [
178-
tf.int32,
179-
tf.int64,
180-
tf.float32,
181-
tf.float16
182-
]:
178+
elif tensor.dtype in [tf.int32, tf.int64, tf.float32, tf.float16]:
183179
data = None
184180
else:
185181
raise ValueError('tensor data not supported')
@@ -249,16 +245,13 @@ def make_onnx_inputs_outputs(name, elem_type, shape, **kwargs):
249245
return helper.make_tensor_value_info(name, elem_type, make_onnx_shape(shape), **kwargs)
250246

251247

252-
PREFERRED_OPSET = 7
253-
254-
255248
def find_opset(opset):
256249
"""Find opset."""
257250
if opset is None or opset == 0:
258251
opset = defs.onnx_opset_version()
259-
if opset > PREFERRED_OPSET:
252+
if opset > constants.PREFERRED_OPSET:
260253
# if we use a newer onnx opset than most runtimes support, default to the one most supported
261-
opset = PREFERRED_OPSET
254+
opset = constants.PREFERRED_OPSET
262255
return opset
263256

264257

0 commit comments

Comments
 (0)