Skip to content

Commit 6137189

Browse files
Allow specifying domain of custom ops (#1206)
* Allow specifying domain of custom ops Signed-off-by: Tom Wildenhain <[email protected]> * Updated readme Signed-off-by: Tom Wildenhain <[email protected]>
1 parent d010b69 commit 6137189

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,13 @@ Only valid with parameter `--saved_model`. When set, creates a zip file containi
200200

201201
Saves the frozen tensorflow graph to file.
202202

203+
#### --custom-ops
204+
205+
If a model contains ops not recognized by onnx runtime, you can tag these ops with a custom op domain so that the
206+
runtime can still open the model. The format is a comma-separated map of tf op names to domains in the format
207+
OpName:domain. If only an op name is provided (no colon), the default domain of `ai.onnx.converters.tensorflow`
208+
will be used.
209+
203210
#### --target
204211

205212
Some models require special handling to run on some runtimes. In particular, the model may use unsupported data types. Workarounds are activated with ```--target TARGET```. Currently supported values are listed on this [wiki](https://github.com/onnx/tensorflow-onnx/wiki/target). If your model will be run on Windows ML, you should specify the appropriate target value.

tf2onnx/convert.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def get_args():
6060
parser.add_argument("--inputs", help="model input_names")
6161
parser.add_argument("--outputs", help="model output_names")
6262
parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain")
63-
parser.add_argument("--custom-ops", help="list of custom ops")
63+
parser.add_argument("--custom-ops", help="comma-separated map of custom ops to domains in format OpName:domain")
6464
parser.add_argument("--extra_opset", default=None,
6565
help="extra opset with format like domain:version, e.g. com.microsoft:1")
6666
parser.add_argument("--target", default=",".join(constants.DEFAULT_TARGET), choices=constants.POSSIBLE_TARGETS,
@@ -103,11 +103,11 @@ def get_args():
103103

104104
return args
105105

106-
107-
def default_custom_op_handler(ctx, node, name, args):
108-
node.domain = constants.TENSORFLOW_OPSET.domain
109-
return node
110-
106+
def make_default_custom_op_handler(domain):
107+
def default_custom_op_handler(ctx, node, name, args):
108+
node.domain = domain
109+
return node
110+
return default_custom_op_handler
111111

112112
def main():
113113
args = get_args()
@@ -121,9 +121,17 @@ def main():
121121
custom_ops = {}
122122
initialized_tables = None
123123
if args.custom_ops:
124-
# default custom ops for tensorflow-onnx are in the "tf" namespace
125-
custom_ops = {op: (default_custom_op_handler, []) for op in args.custom_ops.split(",")}
126-
extra_opset.append(constants.TENSORFLOW_OPSET)
124+
using_tf_opset = False
125+
for op in args.custom_ops.split(","):
126+
if ":" in op:
127+
op, domain = op.split(":")
128+
else:
129+
# default custom ops for tensorflow-onnx are in the "tf" namespace
130+
using_tf_opset = True
131+
domain = constants.TENSORFLOW_OPSET.domain
132+
custom_ops[op] = (make_default_custom_op_handler(domain), [])
133+
if using_tf_opset:
134+
extra_opset.append(constants.TENSORFLOW_OPSET)
127135

128136
# get the frozen tensorflow model from graphdef, checkpoint or saved_model.
129137
if args.graphdef:

0 commit comments

Comments
 (0)