@@ -1178,27 +1178,29 @@ def dump_node_statistics(self):
1178
1178
1179
1179
return op_cnt
1180
1180
1181
- def remove_input (self , node , to_be_removed , i = None ):
1181
+ def remove_input (self , node , to_be_removed , input_index = None ):
1182
1182
"""Remove input from Node.
1183
1183
Args:
1184
1184
node: the node we expect the input on
1185
1185
to_be_removed: the node name we want to remove
1186
- i: if not None, index of the input to be removed
1186
+ input_index: if not None, index of the input to be removed,
1187
+ the method is more efficient if *input_index* is specified,
1188
+ otherwise, it has to look for every input named *old_input*.
1187
1189
"""
1188
1190
assert isinstance (node , Node ) and isinstance (to_be_removed , six .text_type )
1189
- if i is not None :
1190
- assert node .input [i ] == to_be_removed
1191
- if node .input [i ] in self ._input_to_node_name :
1192
- to_ops = self ._input_to_node_name [node .input [i ]]
1191
+ if input_index is not None :
1192
+ assert node .input [input_index ] == to_be_removed
1193
+ if node .input [input_index ] in self ._input_to_node_name :
1194
+ to_ops = self ._input_to_node_name [node .input [input_index ]]
1193
1195
if node .name in to_ops :
1194
1196
to_ops .remove (node .name )
1195
- del node .input [i ]
1197
+ del node .input [input_index ]
1196
1198
return True
1197
1199
1198
- for i2 , name in enumerate (node .input ):
1200
+ for i , name in enumerate (node .input ):
1199
1201
if name == to_be_removed :
1200
- self ._unregister_input_name (node .input [i2 ], node )
1201
- del node .input [i2 ]
1202
+ self ._unregister_input_name (node .input [i ], node )
1203
+ del node .input [i ]
1202
1204
break
1203
1205
# don't remove output from parent since others might depend on it
1204
1206
return True
@@ -1336,17 +1338,21 @@ def replace_all_inputs(self, ops, old_input, new_input):
1336
1338
for _ , g in self ._input_to_graph [old_input ].items ():
1337
1339
g .replace_all_inputs (g .get_nodes () if keep_ops else None , old_input , new_input )
1338
1340
1339
- def replace_input (self , node , old_input , new_input , i = None ):
1340
- """Replace one input in a node."""
1341
+ def replace_input (self , node , old_input , new_input , input_index = None ):
1342
+ """
1343
+ Replace one input in a node.
1344
+ The method is more efficient if *input_index* is specified.
1345
+ Otherwise, it renames every output named *old_input*.
1346
+ """
1341
1347
assert isinstance (node , Node ) and isinstance (old_input , six .text_type ) and isinstance (new_input , six .text_type )
1342
1348
is_replaced = False
1343
- if i is None :
1344
- for i2 , input_name in enumerate (node .input ):
1349
+ if input_index is None :
1350
+ for i , input_name in enumerate (node .input ):
1345
1351
if input_name == old_input :
1346
- node .input [i2 ] = new_input
1352
+ node .input [i ] = new_input
1347
1353
is_replaced = True
1348
- elif node .input [i ] == old_input :
1349
- node .input [i ] = new_input
1354
+ elif node .input [input_index ] == old_input :
1355
+ node .input [input_index ] = new_input
1350
1356
is_replaced = True
1351
1357
else :
1352
1358
raise RuntimeError ("Unable to replace input %r into %r for node %r." % (old_input , new_input , node .name ))
0 commit comments