Skip to content

Commit 759b242

Browse files
zerollzengguschmue
andauthored
Add support for register custom ops (#1518)
* Add support for load tf op libraries Signed-off-by: zerollzeng <[email protected]> * Fix missing code Signed-off-by: zerollzeng <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent 98c3567 commit 759b242

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

tf2onnx/convert.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def get_args():
7171
parser.add_argument("--custom-ops", help="comma-separated map of custom ops to domains in format OpName:domain")
7272
parser.add_argument("--extra_opset", default=None,
7373
help="extra opset with format like domain:version, e.g. com.microsoft:1")
74+
parser.add_argument("--load_op_libraries",
75+
help="comma-separated list of tf op library paths to register before loading model")
7476
parser.add_argument("--target", default=",".join(constants.DEFAULT_TARGET), choices=constants.POSSIBLE_TARGETS,
7577
help="target platform")
7678
parser.add_argument("--continue_on_error", help="continue_on_error", action="store_true")
@@ -119,7 +121,8 @@ def get_args():
119121
if len(tokens) != 2:
120122
parser.error("invalid extra_opset argument")
121123
args.extra_opset = [utils.make_opsetid(tokens[0], int(tokens[1]))]
122-
124+
if args.load_op_libraries:
125+
args.load_op_libraries = args.load_op_libraries.split(",")
123126
return args
124127

125128

@@ -197,6 +200,9 @@ def main():
197200
outputs = None
198201
model_path = None
199202

203+
if args.load_op_libraries:
204+
for op_path in args.load_op_libraries:
205+
tf.load_op_library(op_path)
200206
if args.graphdef:
201207
graph_def, inputs, outputs = tf_loader.from_graphdef(args.graphdef, args.inputs, args.outputs)
202208
model_path = args.graphdef

0 commit comments

Comments
 (0)