@@ -378,7 +378,8 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
378
378
379
379
new_outputs = [o if o != output else new_output_name for output in n .output ]
380
380
new_node = self .make_node (n .type , n .input , outputs = new_outputs , attr = n .attr , name = n .name ,
381
- skip_conversion = n ._skip_conversion , dtypes = n_dtypes , shapes = n_shapes )
381
+ skip_conversion = n ._skip_conversion , dtypes = n_dtypes , shapes = n_shapes ,
382
+ domain = n .domain )
382
383
383
384
if body_graphs :
384
385
for attr_name , body_graph in body_graphs .items ():
@@ -423,7 +424,7 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
423
424
return node
424
425
425
426
def make_node (self , op_type , inputs , attr = None , output_count = 1 , outputs = None , skip_conversion = True ,
426
- op_name_scope = None , name = None , shapes = None , dtypes = None ):
427
+ op_name_scope = None , name = None , shapes = None , dtypes = None , domain = None ):
427
428
"""Make a new onnx node in the graph"""
428
429
if attr is None :
429
430
attr = {}
@@ -456,7 +457,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
456
457
n = self .get_node_by_output_in_current_graph (o )
457
458
utils .make_sure (n is None , "output tensor named %s already exists in node: \n %s" , o , n )
458
459
459
- onnx_node = helper .make_node (op_type , inputs , outputs , name = name , ** raw_attr )
460
+ onnx_node = helper .make_node (op_type , inputs , outputs , name = name , domain = domain , ** raw_attr )
460
461
461
462
if op_type in ["If" , "Loop" , "Scan" ]:
462
463
# we force the op containing inner graphs not skipped during conversion.
@@ -883,7 +884,7 @@ def remove_input(node, to_be_removed):
883
884
# don't remove output from parent since others might depend on it
884
885
return True
885
886
886
- def insert_new_node_on_input (self , node , op_type , input_name , name = None , ** kwargs ):
887
+ def insert_new_node_on_input (self , node , op_type , input_name , name = None , domain = None , ** kwargs ):
887
888
"""Create and insert a new node into the graph.
888
889
Args:
889
890
node: we want to replace the input for this node
@@ -898,14 +899,14 @@ def insert_new_node_on_input(self, node, op_type, input_name, name=None, **kwarg
898
899
if name is None :
899
900
name = utils .make_name (node .name )
900
901
new_output = port_name (name )
901
- new_node = self .make_node (op_type , [input_name ], attr = kwargs , outputs = [new_output ], name = name )
902
+ new_node = self .make_node (op_type , [input_name ], attr = kwargs , outputs = [new_output ], name = name , domain = domain )
902
903
for i , n in enumerate (node .input ):
903
904
if n == input_name :
904
905
node .input [i ] = new_output
905
906
break
906
907
return new_node
907
908
908
- def insert_new_node_on_output (self , op_type , output_name , name = None , ** kwargs ):
909
+ def insert_new_node_on_output (self , op_type , output_name , name = None , domain = None , ** kwargs ):
909
910
"""Create and insert a new node into the graph.
910
911
Args:
911
912
op_type: type for new operation
@@ -922,7 +923,7 @@ def insert_new_node_on_output(self, op_type, output_name, name=None, **kwargs):
922
923
type (op_type ))
923
924
924
925
new_output = port_name (name )
925
- new_node = self .make_node (op_type , [output_name ], attr = kwargs , outputs = [new_output ], name = name )
926
+ new_node = self .make_node (op_type , [output_name ], attr = kwargs , outputs = [new_output ], name = name , domain = domain )
926
927
927
928
to_replace = [n for n in self .get_nodes () if n != new_node ]
928
929
self .replace_all_inputs (to_replace , output_name , new_output )
0 commit comments