@@ -1258,7 +1258,7 @@ def remove_input(self, node, to_be_removed, input_index=None):
1258
1258
1259
1259
# don't remove output from parent since others might depend on it
1260
1260
1261
- def insert_new_node_on_input (self , node , op_type , input_name , name = None , domain = None , ** kwargs ):
1261
+ def insert_new_node_on_input (self , node , op_type , input_name , name = None , domain = None , input_index = None , ** kwargs ):
1262
1262
"""Create and insert a new node into the graph.
1263
1263
Args:
1264
1264
node: we want to replace the input for this node
@@ -1279,10 +1279,13 @@ def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=
1279
1279
input_name = [input_name ]
1280
1280
1281
1281
new_node = self .make_node (op_type , input_name , attr = kwargs , outputs = [new_output ], name = name , domain = domain )
1282
- for i , n in enumerate (node .input ):
1283
- if n == input_name [0 ]:
1284
- self .replace_input (node , node .input [i ], new_output , i )
1285
- break
1282
+ if input_index is None :
1283
+ for i , n in enumerate (node .input ):
1284
+ if n == input_name [0 ]:
1285
+ self .replace_input (node , node .input [i ], new_output , i )
1286
+ break
1287
+ else :
1288
+ self .replace_input (node , node .input [input_index ], new_output , input_index )
1286
1289
return new_node
1287
1290
1288
1291
def insert_node_on_output (self , node , output_name = None ):
0 commit comments