@@ -345,6 +345,9 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
345
345
346
346
self ._output_shapes = output_shapes
347
347
self ._opset = find_opset (opset )
348
+
349
+ if extra_opset is not None :
350
+ utils .make_sure (isinstance (extra_opset , list ), "invalid extra_opset" )
348
351
self ._extra_opset = extra_opset
349
352
350
353
self ._order_sensitive_inputs = []
@@ -384,12 +387,16 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
384
387
def create_new_graph_with_same_config (self ):
385
388
"""Create a clean graph inheriting current graph's configuration."""
386
389
return Graph ([], output_shapes = {}, dtypes = {}, target = self ._target , opset = self ._opset ,
387
- extra_opset = self ._extra_opset , output_names = [])
390
+ extra_opset = self .extra_opset , output_names = [])
388
391
389
392
@property
390
393
def opset (self ):
391
394
return self ._opset
392
395
396
+ @property
397
+ def extra_opset (self ):
398
+ return self ._extra_opset
399
+
393
400
def is_target (self , * names ):
394
401
"""Return True if target platform contains any name."""
395
402
return any (name in self ._target for name in names )
@@ -801,8 +808,8 @@ def make_model(self, graph_doc, optimize=False, graph_name="tf2onnx", **kwargs):
801
808
imp = OperatorSetIdProto ()
802
809
imp .version = self ._opset
803
810
opsets .append (imp )
804
- if self ._extra_opset is not None :
805
- opsets .extend (self ._extra_opset )
811
+ if self .extra_opset is not None :
812
+ opsets .extend (self .extra_opset )
806
813
kwargs ["opset_imports" ] = opsets
807
814
model_proto = helper .make_model (graph , ** kwargs )
808
815
0 commit comments