Skip to content

Commit 33f3514

Browse files
committed
support extra opset mapping
1 parent cd77071 commit 33f3514

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

tf2onnx/convert.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ def get_args():
3131
parser.add_argument("--output", help="output model file")
3232
parser.add_argument("--inputs", help="model input_names")
3333
parser.add_argument("--outputs", help="model output_names")
34-
parser.add_argument("--opset", type=int, default=None, help="onnx opset to use")
34+
parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain")
3535
parser.add_argument("--custom-ops", help="list of custom ops")
36+
parser.add_argument("--extra_opset", default=None,
37+
help="extra opset with format like domain:version, e.g. com.microsoft:1")
3638
parser.add_argument("--target", default=",".join(constants.DEFAULT_TARGET), choices=constants.POSSIBLE_TARGETS,
3739
help="target platform")
3840
parser.add_argument("--continue_on_error", help="continue_on_error", action="store_true")
@@ -63,6 +65,11 @@ def get_args():
6365
if args.target:
6466
args.target = args.target.split(",")
6567

68+
if args.extra_opset:
69+
tokens = args.extra_opset.split(':')
70+
if len(tokens) != 2:
71+
raise ValueError("invalid extra_opset argument")
72+
args.extra_opset = [helper.make_opsetid(tokens[0], int(tokens[1]))]
6673
return args
6774

6875

@@ -78,13 +85,12 @@ def main():
7885
# support unknown dimensions.
7986
utils.ONNX_UNKNOWN_DIMENSION = args.unknown_dim
8087

88+
extra_opset = args.extra_opset or []
89+
custom_ops = {}
8190
if args.custom_ops:
8291
# default custom ops for tensorflow-onnx are in the "tf" namespace
8392
custom_ops = {op: (default_custom_op_handler, []) for op in args.custom_ops.split(",")}
84-
extra_opset = [constants.DEFAULT_CUSTOM_OP_OPSET]
85-
else:
86-
custom_ops = {}
87-
extra_opset = None
93+
extra_opset.append(constants.DEFAULT_CUSTOM_OP_OPSET)
8894

8995
# get the frozen tensorflow model from graphdef, checkpoint or saved_model.
9096
if args.graphdef:

tf2onnx/tfonnx.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from tensorflow.tools.graph_transforms import TransformGraph
2222

2323
import tf2onnx
24-
from tf2onnx import constants, schemas, utils
24+
from tf2onnx import constants, custom, schemas, utils
2525
from tf2onnx.function import * # pylint: disable=wildcard-import
2626
from tf2onnx.graph import Graph
2727
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
@@ -2322,13 +2322,28 @@ def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
23222322
mapped_op = collections.Counter()
23232323
unmapped_op = collections.Counter()
23242324

2325-
# create ops mapping for the desired opset
2325+
# create ops mapping for the desired opsets
23262326
ops_mapping = {}
2327+
2328+
# load mapping for onnx domain
23272329
for target_opset, op_map in _OPSETS:
23282330
if target_opset <= g.opset:
23292331
ops_mapping.update(op_map)
23302332

2331-
# apply custom ops on top of the assembled opset. We can either completment the opset
2333+
# load mapping for known extra opsets
2334+
# order matters, later mapping overrides earlier's
2335+
if g.extra_opset is not None:
2336+
for extra_opset in g.extra_opset:
2337+
if extra_opset.domain == constants.MICROSOFT_DOMAIN:
2338+
# microsoft domain
2339+
for target_opset, op_map in custom.ms.OPSETS:
2340+
if target_opset <= extra_opset.version:
2341+
ops_mapping.update(op_map)
2342+
else:
2343+
# unknown opset, assume used in custom_op_handlers, skip it
2344+
pass
2345+
2346+
# apply custom ops on top of the assembled opset. We can either complement the opset
23322347
# or override existing ops with a custom op.
23332348
if custom_op_handlers is not None:
23342349
custom_opset = {k: v for k, v in custom_op_handlers.items()}
@@ -2337,7 +2352,7 @@ def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
23372352
ops = [n for n in g.get_nodes()]
23382353
for node in ops:
23392354
if node.need_skip():
2340-
log.debug("explictly skip node " + node.name)
2355+
log.debug("explicitly skip node " + node.name)
23412356
continue
23422357

23432358
op = node.type

0 commit comments

Comments
 (0)