Commit 1c60588

Fix usage of custom_ops, custom_op_handlers, and custom_rewriter args (#1708)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 5e48449 commit 1c60588

File tree

1 file changed: +26 -10 lines changed


tf2onnx/convert.py

Lines changed: 26 additions & 10 deletions
@@ -71,7 +71,8 @@ def get_args():
     parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain")
     parser.add_argument("--dequantize", help="Remove quantization from model. Only supported for tflite currently.",
                         action="store_true")
-    parser.add_argument("--custom-ops", help="comma-separated map of custom ops to domains in format OpName:domain")
+    parser.add_argument("--custom-ops", help="Comma-separated map of custom ops to domains in format OpName:domain. "
+                                             "Domain 'ai.onnx.converters.tensorflow' is used by default.")
     parser.add_argument("--extra_opset", default=None,
                         help="extra opset with format like domain:version, e.g. com.microsoft:1")
     parser.add_argument("--load_op_libraries",
@@ -137,13 +138,19 @@ def default_custom_op_handler(ctx, node, name, args):
 
 
 def _convert_common(frozen_graph, name="unknown", large_model=False, output_path=None,
-                    output_frozen_graph=None, **kwargs):
+                    output_frozen_graph=None, custom_ops=None, custom_op_handlers=None, **kwargs):
     """Common processing for conversion."""
 
     model_proto = None
     external_tensor_storage = None
     const_node_values = None
 
+    if custom_ops is not None:
+        if custom_op_handlers is None:
+            custom_op_handlers = {}
+        custom_op_handlers.update(
+            {op: (make_default_custom_op_handler(domain), []) for op, domain in custom_ops.items()})
+
     with tf.Graph().as_default() as tf_graph:
         if large_model:
             const_node_values = compress_graph_def(frozen_graph)
@@ -152,7 +159,8 @@ def _convert_common(frozen_graph, name="unknown", large_model=False, output_path
             utils.save_protobuf(output_frozen_graph, frozen_graph)
         if not kwargs.get("tflite_path") and not kwargs.get("tfjs_path"):
             tf.import_graph_def(frozen_graph, name='')
-        g = process_tf_graph(tf_graph, const_node_values=const_node_values, **kwargs)
+        g = process_tf_graph(tf_graph, const_node_values=const_node_values,
+                             custom_op_handlers=custom_op_handlers, **kwargs)
         if constants.ENV_TF2ONNX_CATCH_ERRORS in os.environ:
             catch_errors = constants.ENV_TF2ONNX_CATCH_ERRORS.upper() == "TRUE"
         else:
@@ -180,7 +188,7 @@ def main():
     extra_opset = args.extra_opset or []
     tflite_path = None
     tfjs_path = None
-    custom_ops = {}
+    custom_op_handlers = {}
     initialized_tables = None
     tensors_to_rename = {}
     if args.custom_ops:
@@ -192,7 +200,7 @@ def main():
                 # default custom ops for tensorflow-onnx are in the "tf" namespace
                 using_tf_opset = True
                 domain = constants.TENSORFLOW_OPSET.domain
-            custom_ops[op] = (make_default_custom_op_handler(domain), [])
+            custom_op_handlers[op] = (make_default_custom_op_handler(domain), [])
         if using_tf_opset:
             extra_opset.append(constants.TENSORFLOW_OPSET)
 
@@ -259,7 +267,7 @@ def main():
            continue_on_error=args.continue_on_error,
            target=args.target,
            opset=args.opset,
-           custom_op_handlers=custom_ops,
+           custom_op_handlers=custom_op_handlers,
            extra_opset=extra_opset,
            shape_override=args.shape_override,
            input_names=inputs,
@@ -371,7 +379,9 @@ def _from_keras_tf1(model, input_signature=None, opset=None, custom_ops=None, cu
            continue_on_error=True,
            target=target,
            opset=opset,
-           custom_op_handlers=custom_ops,
+           custom_ops=custom_ops,
+           custom_op_handlers=custom_op_handlers,
+           custom_rewriter=custom_rewriter,
            extra_opset=extra_opset,
            shape_override=shape_override,
            input_names=input_names,
@@ -475,7 +485,9 @@ def wrap_call(*args, training=False, **kwargs):
            continue_on_error=True,
            target=target,
            opset=opset,
-           custom_op_handlers=custom_ops,
+           custom_ops=custom_ops,
+           custom_op_handlers=custom_op_handlers,
+           custom_rewriter=custom_rewriter,
            extra_opset=extra_opset,
            shape_override=shape_override,
            input_names=input_names,
@@ -537,7 +549,9 @@ def from_function(function, input_signature=None, opset=None, custom_ops=None, c
            continue_on_error=True,
            target=target,
            opset=opset,
-           custom_op_handlers=custom_ops,
+           custom_ops=custom_ops,
+           custom_op_handlers=custom_op_handlers,
+           custom_rewriter=custom_rewriter,
            extra_opset=extra_opset,
            shape_override=shape_override,
            input_names=input_names,
@@ -599,7 +613,9 @@ def from_graph_def(graph_def, name=None, input_names=None, output_names=None, op
            continue_on_error=True,
            target=target,
            opset=opset,
-           custom_op_handlers=custom_ops,
+           custom_ops=custom_ops,
+           custom_op_handlers=custom_op_handlers,
+           custom_rewriter=custom_rewriter,
            extra_opset=extra_opset,
            shape_override=shape_override,
            input_names=input_names,
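With this fix, the Python API entry points (from_keras, from_function, from_graph_def) forward custom_ops, custom_op_handlers, and custom_rewriter to _convert_common separately instead of passing the custom_ops dict through as if it were custom_op_handlers; _convert_common now expands each OpName-to-domain entry into a default handler. The sketch below is a minimal, hypothetical use of the fixed custom_ops argument: the op name, domain, opset, output path, and the toy tf.function are placeholders, not part of this commit. On the command line the same mapping is given as --custom-ops OpName:domain (omitting the domain falls back to 'ai.onnx.converters.tensorflow', per the updated help text).

import tensorflow as tf
import tf2onnx

# Toy graph standing in for a model that actually contains the custom op
# named below (e.g. one registered via tf.load_op_library / --load_op_libraries).
spec = [tf.TensorSpec([None, 3], tf.float32, name="x")]

@tf.function(input_signature=spec)
def f(x):
    return x * 2.0

# custom_ops maps each custom op name to the ONNX domain it should be emitted
# in; _convert_common turns every entry into a default custom-op handler.
# "MyCustomOp" and "ai.onnx.contrib" are placeholder names.
model_proto, _ = tf2onnx.convert.from_function(
    f,
    input_signature=spec,
    opset=13,
    custom_ops={"MyCustomOp": "ai.onnx.contrib"},
    output_path="model.onnx",
)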
