Skip to content

Commit d0a88a6

Browse files
committed
create_graph_from_onnx_model support extra_opset
1 parent 33f3514 commit d0a88a6

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

tf2onnx/graph.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,12 +1105,22 @@ def create_graph_from_onnx_model(onnx_model_proto):
11051105
# apply shape inference on the model
11061106
inferred_model = shape_inference.infer_shapes(onnx_model_proto)
11071107
graph_proto = inferred_model.graph
1108-
opset_version = onnx_model_proto.opset_import[0].version
1109-
main_graph = GraphUtil.create_graph_from_onnx_graph(graph_proto, opset_version)
1108+
1109+
opset_version = None
1110+
extra_opset = []
1111+
for opset in onnx_model_proto.opset_import:
1112+
if not opset.domain:
1113+
# domain field is None or empty means it is onnx domain
1114+
opset_version = opset.version
1115+
else:
1116+
extra_opset.append(opset)
1117+
1118+
utils.make_sure(opset_version is not None, "opset version is not specified for onnx domain")
1119+
main_graph = GraphUtil.create_graph_from_onnx_graph(graph_proto, opset_version, extra_opset)
11101120
return main_graph
11111121

11121122
@staticmethod
1113-
def create_graph_from_onnx_graph(graph_proto, opset_version=None):
1123+
def create_graph_from_onnx_graph(graph_proto, opset_version=None, extra_opset=None):
11141124
"""Create Graph loading onnx graph proto."""
11151125
output_shapes = {}
11161126
output_dtypes = {}
@@ -1137,15 +1147,15 @@ def create_graph_from_onnx_graph(graph_proto, opset_version=None):
11371147
for n in graph_proto.output:
11381148
output_names.append(n.name)
11391149

1140-
g = Graph(nodes_to_append, output_shapes, output_dtypes, None, opset_version, None, output_names)
1150+
g = Graph(nodes_to_append, output_shapes, output_dtypes, None, opset_version, extra_opset, output_names)
11411151
const_nodes = GraphUtil._parse_graph_initializer(g, graph_proto)
11421152
GraphUtil._parse_graph_input(g, graph_proto, [n.name for n in const_nodes])
11431153

11441154
for n in g.get_nodes():
11451155
for attr_name, attr_val in n.attr.items():
11461156
if attr_val.HasField('g'):
11471157
# it was assumed that the a.g has inferred shapes/dtypes.
1148-
sub_g = GraphUtil.create_graph_from_onnx_graph(attr_val.g, opset_version)
1158+
sub_g = GraphUtil.create_graph_from_onnx_graph(attr_val.g, opset_version, extra_opset)
11491159
n.set_body_graph_as_attr(attr_name, sub_g)
11501160
return g
11511161

0 commit comments

Comments
 (0)