Skip to content

Commit f7d49c7

Browse files
authored
Allow --extra_opset to accept a list of extra opsets. (#2136)
Signed-off-by: Jay Zhang <[email protected]>
1 parent 2781a32 commit f7d49c7

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ python -m tf2onnx.convert
139139
[--signature_def SIGNATURE_DEF]
140140
[--concrete_function CONCRETE_FUNCTION]
141141
[--target TARGET]
142+
[--extra_opset list-of-extra-opset]
142143
[--custom-ops list-of-custom-ops]
143144
[--load_op_libraries tensorflow_library_path]
144145
[--large_model]
@@ -217,6 +218,11 @@ Only valid with parameter `--saved_model`. If a model contains a list of concret
217218

218219
Some models require special handling to run on some runtimes. In particular, the model may use unsupported data types. Workarounds are activated with ```--target TARGET```. Currently supported values are listed on this [wiki](https://github.com/onnx/tensorflow-onnx/wiki/target). If your model will be run on Windows ML, you should specify the appropriate target value.
219220

221+
#### --extra_opset
222+
223+
If you want to convert a TF model using an existing custom op, this can specify the correspongding domain and version.
224+
The format is a comma-separated map of domain and version, for example: `ai.onnx.contrib:1`.
225+
220226
#### --custom-ops
221227

222228
If a model contains ops not recognized by onnx runtime, you can tag these ops with a custom op domain so that the

tf2onnx/convert.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,14 @@ def get_args():
123123
if not args.tflite:
124124
parser.error("dequantize flag is currently only supported for tflite")
125125
if args.extra_opset:
126-
tokens = args.extra_opset.split(':')
127-
if len(tokens) != 2:
128-
parser.error("invalid extra_opset argument")
129-
args.extra_opset = [utils.make_opsetid(tokens[0], int(tokens[1]))]
126+
all_extra_opsets = args.extra_opset.split(',')
127+
extra_opset_list = []
128+
for extra_opset in all_extra_opsets:
129+
tokens = extra_opset.split(':')
130+
if len(tokens) != 2:
131+
parser.error("invalid extra_opset argument")
132+
extra_opset_list.append(utils.make_opsetid(tokens[0], int(tokens[1])))
133+
args.extra_opset = extra_opset_list
130134
if args.load_op_libraries:
131135
args.load_op_libraries = args.load_op_libraries.split(",")
132136
return args

0 commit comments

Comments
 (0)