@@ -425,7 +425,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
425
425
self ._nodes = []
426
426
self ._nodes_by_name = {}
427
427
self ._output_to_node_name = {}
428
- self ._input_to_node_name = {}
428
+ self ._output_to_consumers = {}
429
429
self ._input_to_graph = {}
430
430
self .shapes = {}
431
431
self .graph_name = graph_name or "tf2onnx"
@@ -642,7 +642,7 @@ def remove_node(self, node_name):
642
642
del self ._dtypes [op_output ]
643
643
644
644
for op_input in node .input :
645
- if op_input not in self ._input_to_node_name :
645
+ if op_input not in self ._output_to_consumers :
646
646
raise RuntimeError (
647
647
"Input %r of node %r not found." % (op_input , node_name ))
648
648
self ._unregister_input_name (op_input , node )
@@ -670,7 +670,7 @@ def reset_nodes(self, ops):
670
670
self .contained_graphs = remained_sub_graphs
671
671
self ._nodes_by_name = {op .name : op for op in ops }
672
672
self ._output_to_node_name = {}
673
- self ._input_to_node_name = {}
673
+ self ._output_to_consumers = {}
674
674
for op in ops :
675
675
for op_output in op .output :
676
676
self ._output_to_node_name [op_output ] = op .name
@@ -694,7 +694,7 @@ def reset_nodes(self, ops):
694
694
continue
695
695
if i .name .startswith ('keras_learning_phase' ):
696
696
continue
697
- if i .name not in self ._input_to_node_name :
697
+ if i .name not in self ._output_to_consumers :
698
698
raise ValueError ("graph input %r not exist in graph." % i .name )
699
699
700
700
self ._dtypes = remained_dtypes
@@ -1182,8 +1182,8 @@ def remove_input(self, node, to_be_removed, input_index=None):
1182
1182
assert isinstance (node , Node ) and isinstance (to_be_removed , six .text_type )
1183
1183
if input_index is not None :
1184
1184
assert node .input [input_index ] == to_be_removed
1185
- if node .input [input_index ] in self ._input_to_node_name :
1186
- to_ops = self ._input_to_node_name [node .input [input_index ]]
1185
+ if node .input [input_index ] in self ._output_to_consumers :
1186
+ to_ops = self ._output_to_consumers [node .input [input_index ]]
1187
1187
if node .name in to_ops :
1188
1188
to_ops .remove (node .name )
1189
1189
del node .input [input_index ]
@@ -1248,15 +1248,15 @@ def insert_new_node_on_output(self, op_type, output_name, name, domain=None, **k
1248
1248
new_output = port_name (name )
1249
1249
new_node = self .make_node (op_type , [output_name ], attr = kwargs , outputs = [new_output ], name = name , domain = domain )
1250
1250
1251
- to_replace = [self .get_node_by_name (n ) for n in self ._input_to_node_name [output_name ]]
1251
+ to_replace = [self .get_node_by_name (n ) for n in self ._output_to_consumers [output_name ]]
1252
1252
to_replace = [n for n in to_replace if n != new_node ]
1253
1253
self .replace_all_inputs (output_name , new_output , ops = to_replace )
1254
1254
return new_node
1255
1255
1256
1256
def find_output_consumers (self , output_name ):
1257
1257
"""Find all nodes consuming a given output."""
1258
- if output_name in self ._input_to_node_name :
1259
- ops = self ._input_to_node_name [output_name ]
1258
+ if output_name in self ._output_to_consumers :
1259
+ ops = self ._output_to_consumers [output_name ]
1260
1260
ops = [self .get_node_by_name (n ) for n in ops ]
1261
1261
else :
1262
1262
ops = [] # self.get_nodes()
@@ -1276,9 +1276,9 @@ def find_output_consumers(self, output_name):
1276
1276
def _register_input_name (self , input_name , node , only_graph = False ):
1277
1277
"Register node taking a specific input."
1278
1278
if not only_graph :
1279
- if input_name not in self ._input_to_node_name :
1280
- self ._input_to_node_name [input_name ] = set ()
1281
- self ._input_to_node_name [input_name ].add (node .name )
1279
+ if input_name not in self ._output_to_consumers :
1280
+ self ._output_to_consumers [input_name ] = set ()
1281
+ self ._output_to_consumers [input_name ].add (node .name )
1282
1282
if self .parent_graph is not None :
1283
1283
if input_name not in self .parent_graph ._input_to_graph :
1284
1284
self .parent_graph ._input_to_graph [input_name ] = {}
@@ -1289,9 +1289,9 @@ def _unregister_input_name(self, input_name, node, only_graph=False):
1289
1289
"Unregister node taking a specific input."
1290
1290
node_name = node .name
1291
1291
if not only_graph :
1292
- if input_name in self ._input_to_node_name [input_name ]:
1293
- if node_name in self ._input_to_node_name [input_name ]:
1294
- self ._input_to_node_name [input_name ].remove (node_name )
1292
+ if input_name in self ._output_to_consumers [input_name ]:
1293
+ if node_name in self ._output_to_consumers [input_name ]:
1294
+ self ._output_to_consumers [input_name ].remove (node_name )
1295
1295
if (self .parent_graph is not None and
1296
1296
input_name in self .parent_graph ._input_to_graph and
1297
1297
id (self ) in self .parent_graph ._input_to_graph [input_name ]):
@@ -1301,20 +1301,20 @@ def _unregister_input_name(self, input_name, node, only_graph=False):
1301
1301
def replace_all_inputs (self , old_input , new_input , ops = None ):
1302
1302
"""
1303
1303
Replace all inputs pointing to old_input with new_input.
1304
- *ops* is used if defined, otherwise `_input_to_node_name `
1304
+ *ops* is used if defined, otherwise `_output_to_consumers `
1305
1305
is used to determine the impacted nodes.
1306
1306
"""
1307
1307
if old_input == new_input :
1308
1308
return
1309
- if new_input not in self ._input_to_node_name :
1310
- self ._input_to_node_name [new_input ] = set ()
1309
+ if new_input not in self ._output_to_consumers :
1310
+ self ._output_to_consumers [new_input ] = set ()
1311
1311
1312
1312
if ops is not None :
1313
1313
keep_ops = True
1314
- elif old_input in self ._input_to_node_name :
1314
+ elif old_input in self ._output_to_consumers :
1315
1315
ops = list (
1316
1316
filter (lambda a : a is not None ,
1317
- map (self .get_node_by_name , self ._input_to_node_name [old_input ])))
1317
+ map (self .get_node_by_name , self ._output_to_consumers [old_input ])))
1318
1318
keep_ops = False
1319
1319
else :
1320
1320
ops = []
@@ -1355,7 +1355,7 @@ def replace_input(self, node, old_input, new_input, input_index=None):
1355
1355
else :
1356
1356
raise RuntimeError ("Unable to replace input %r into %r for node %r." % (old_input , new_input , node .name ))
1357
1357
1358
- to_ops = self ._input_to_node_name .get (old_input , None )
1358
+ to_ops = self ._output_to_consumers .get (old_input , None )
1359
1359
if to_ops is not None :
1360
1360
if node .name in to_ops :
1361
1361
# A node may take twice the same entry.
@@ -1369,7 +1369,7 @@ def replace_inputs(self, node, new_inputs):
1369
1369
assert isinstance (node , Node ) and isinstance (new_inputs , list )
1370
1370
1371
1371
for old_input in node .input :
1372
- to_ops = self ._input_to_node_name .get (old_input , None )
1372
+ to_ops = self ._output_to_consumers .get (old_input , None )
1373
1373
if to_ops is not None and old_input in to_ops :
1374
1374
# To avoid issues when a node
1375
1375
# takes twice the same entry.
0 commit comments