Skip to content

Commit d064ff0

Browse files
committed
add extra_opset property for graph
1 parent 74b120e commit d064ff0

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

tf2onnx/graph.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
345345

346346
self._output_shapes = output_shapes
347347
self._opset = find_opset(opset)
348+
349+
if extra_opset is not None:
350+
utils.make_sure(isinstance(extra_opset, list), "invalid extra_opset")
348351
self._extra_opset = extra_opset
349352

350353
self._order_sensitive_inputs = []
@@ -384,12 +387,16 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
384387
def create_new_graph_with_same_config(self):
385388
"""Create a clean graph inheriting current graph's configuration."""
386389
return Graph([], output_shapes={}, dtypes={}, target=self._target, opset=self._opset,
387-
extra_opset=self._extra_opset, output_names=[])
390+
extra_opset=self.extra_opset, output_names=[])
388391

389392
@property
390393
def opset(self):
391394
return self._opset
392395

396+
@property
397+
def extra_opset(self):
398+
return self._extra_opset
399+
393400
def is_target(self, *names):
394401
"""Return True if target platform contains any name."""
395402
return any(name in self._target for name in names)
@@ -801,8 +808,8 @@ def make_model(self, graph_doc, optimize=False, graph_name="tf2onnx", **kwargs):
801808
imp = OperatorSetIdProto()
802809
imp.version = self._opset
803810
opsets.append(imp)
804-
if self._extra_opset is not None:
805-
opsets.extend(self._extra_opset)
811+
if self.extra_opset is not None:
812+
opsets.extend(self.extra_opset)
806813
kwargs["opset_imports"] = opsets
807814
model_proto = helper.make_model(graph, **kwargs)
808815

0 commit comments

Comments
 (0)