@@ -71,7 +71,8 @@ def get_args():
71
71
parser .add_argument ("--opset" , type = int , default = None , help = "opset version to use for onnx domain" )
72
72
parser .add_argument ("--dequantize" , help = "Remove quantization from model. Only supported for tflite currently." ,
73
73
action = "store_true" )
74
- parser .add_argument ("--custom-ops" , help = "comma-separated map of custom ops to domains in format OpName:domain" )
74
+ parser .add_argument ("--custom-ops" , help = "Comma-separated map of custom ops to domains in format OpName:domain. "
75
+ "Domain 'ai.onnx.converters.tensorflow' is used by default." )
75
76
parser .add_argument ("--extra_opset" , default = None ,
76
77
help = "extra opset with format like domain:version, e.g. com.microsoft:1" )
77
78
parser .add_argument ("--load_op_libraries" ,
@@ -137,13 +138,19 @@ def default_custom_op_handler(ctx, node, name, args):
137
138
138
139
139
140
def _convert_common (frozen_graph , name = "unknown" , large_model = False , output_path = None ,
140
- output_frozen_graph = None , ** kwargs ):
141
+ output_frozen_graph = None , custom_ops = None , custom_op_handlers = None , ** kwargs ):
141
142
"""Common processing for conversion."""
142
143
143
144
model_proto = None
144
145
external_tensor_storage = None
145
146
const_node_values = None
146
147
148
+ if custom_ops is not None :
149
+ if custom_op_handlers is None :
150
+ custom_op_handlers = {}
151
+ custom_op_handlers .update (
152
+ {op : (make_default_custom_op_handler (domain ), []) for op , domain in custom_ops .items ()})
153
+
147
154
with tf .Graph ().as_default () as tf_graph :
148
155
if large_model :
149
156
const_node_values = compress_graph_def (frozen_graph )
@@ -152,7 +159,8 @@ def _convert_common(frozen_graph, name="unknown", large_model=False, output_path
152
159
utils .save_protobuf (output_frozen_graph , frozen_graph )
153
160
if not kwargs .get ("tflite_path" ) and not kwargs .get ("tfjs_path" ):
154
161
tf .import_graph_def (frozen_graph , name = '' )
155
- g = process_tf_graph (tf_graph , const_node_values = const_node_values , ** kwargs )
162
+ g = process_tf_graph (tf_graph , const_node_values = const_node_values ,
163
+ custom_op_handlers = custom_op_handlers , ** kwargs )
156
164
if constants .ENV_TF2ONNX_CATCH_ERRORS in os .environ :
157
165
catch_errors = constants .ENV_TF2ONNX_CATCH_ERRORS .upper () == "TRUE"
158
166
else :
@@ -180,7 +188,7 @@ def main():
180
188
extra_opset = args .extra_opset or []
181
189
tflite_path = None
182
190
tfjs_path = None
183
- custom_ops = {}
191
+ custom_op_handlers = {}
184
192
initialized_tables = None
185
193
tensors_to_rename = {}
186
194
if args .custom_ops :
@@ -192,7 +200,7 @@ def main():
192
200
# default custom ops for tensorflow-onnx are in the "tf" namespace
193
201
using_tf_opset = True
194
202
domain = constants .TENSORFLOW_OPSET .domain
195
- custom_ops [op ] = (make_default_custom_op_handler (domain ), [])
203
+ custom_op_handlers [op ] = (make_default_custom_op_handler (domain ), [])
196
204
if using_tf_opset :
197
205
extra_opset .append (constants .TENSORFLOW_OPSET )
198
206
@@ -259,7 +267,7 @@ def main():
259
267
continue_on_error = args .continue_on_error ,
260
268
target = args .target ,
261
269
opset = args .opset ,
262
- custom_op_handlers = custom_ops ,
270
+ custom_op_handlers = custom_op_handlers ,
263
271
extra_opset = extra_opset ,
264
272
shape_override = args .shape_override ,
265
273
input_names = inputs ,
@@ -371,7 +379,9 @@ def _from_keras_tf1(model, input_signature=None, opset=None, custom_ops=None, cu
371
379
continue_on_error = True ,
372
380
target = target ,
373
381
opset = opset ,
374
- custom_op_handlers = custom_ops ,
382
+ custom_ops = custom_ops ,
383
+ custom_op_handlers = custom_op_handlers ,
384
+ custom_rewriter = custom_rewriter ,
375
385
extra_opset = extra_opset ,
376
386
shape_override = shape_override ,
377
387
input_names = input_names ,
@@ -475,7 +485,9 @@ def wrap_call(*args, training=False, **kwargs):
475
485
continue_on_error = True ,
476
486
target = target ,
477
487
opset = opset ,
478
- custom_op_handlers = custom_ops ,
488
+ custom_ops = custom_ops ,
489
+ custom_op_handlers = custom_op_handlers ,
490
+ custom_rewriter = custom_rewriter ,
479
491
extra_opset = extra_opset ,
480
492
shape_override = shape_override ,
481
493
input_names = input_names ,
@@ -537,7 +549,9 @@ def from_function(function, input_signature=None, opset=None, custom_ops=None, c
537
549
continue_on_error = True ,
538
550
target = target ,
539
551
opset = opset ,
540
- custom_op_handlers = custom_ops ,
552
+ custom_ops = custom_ops ,
553
+ custom_op_handlers = custom_op_handlers ,
554
+ custom_rewriter = custom_rewriter ,
541
555
extra_opset = extra_opset ,
542
556
shape_override = shape_override ,
543
557
input_names = input_names ,
@@ -599,7 +613,9 @@ def from_graph_def(graph_def, name=None, input_names=None, output_names=None, op
599
613
continue_on_error = True ,
600
614
target = target ,
601
615
opset = opset ,
602
- custom_op_handlers = custom_ops ,
616
+ custom_ops = custom_ops ,
617
+ custom_op_handlers = custom_op_handlers ,
618
+ custom_rewriter = custom_rewriter ,
603
619
extra_opset = extra_opset ,
604
620
shape_override = shape_override ,
605
621
input_names = input_names ,
0 commit comments