Skip to content

Commit c9aa7ab

Browse files
committed
rename into _output_to_consumers
1 parent 1250847 commit c9aa7ab

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

tf2onnx/graph.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
425425
self._nodes = []
426426
self._nodes_by_name = {}
427427
self._output_to_node_name = {}
428-
self._input_to_node_name = {}
428+
self._output_to_consumers = {}
429429
self._input_to_graph = {}
430430
self.shapes = {}
431431
self.graph_name = graph_name or "tf2onnx"
@@ -642,7 +642,7 @@ def remove_node(self, node_name):
642642
del self._dtypes[op_output]
643643

644644
for op_input in node.input:
645-
if op_input not in self._input_to_node_name:
645+
if op_input not in self._output_to_consumers:
646646
raise RuntimeError(
647647
"Input %r of node %r not found." % (op_input, node_name))
648648
self._unregister_input_name(op_input, node)
@@ -670,7 +670,7 @@ def reset_nodes(self, ops):
670670
self.contained_graphs = remained_sub_graphs
671671
self._nodes_by_name = {op.name: op for op in ops}
672672
self._output_to_node_name = {}
673-
self._input_to_node_name = {}
673+
self._output_to_consumers = {}
674674
for op in ops:
675675
for op_output in op.output:
676676
self._output_to_node_name[op_output] = op.name
@@ -694,7 +694,7 @@ def reset_nodes(self, ops):
694694
continue
695695
if i.name.startswith('keras_learning_phase'):
696696
continue
697-
if i.name not in self._input_to_node_name:
697+
if i.name not in self._output_to_consumers:
698698
raise ValueError("graph input %r not exist in graph." % i.name)
699699

700700
self._dtypes = remained_dtypes
@@ -1182,8 +1182,8 @@ def remove_input(self, node, to_be_removed, input_index=None):
11821182
assert isinstance(node, Node) and isinstance(to_be_removed, six.text_type)
11831183
if input_index is not None:
11841184
assert node.input[input_index] == to_be_removed
1185-
if node.input[input_index] in self._input_to_node_name:
1186-
to_ops = self._input_to_node_name[node.input[input_index]]
1185+
if node.input[input_index] in self._output_to_consumers:
1186+
to_ops = self._output_to_consumers[node.input[input_index]]
11871187
if node.name in to_ops:
11881188
to_ops.remove(node.name)
11891189
del node.input[input_index]
@@ -1248,15 +1248,15 @@ def insert_new_node_on_output(self, op_type, output_name, name, domain=None, **k
12481248
new_output = port_name(name)
12491249
new_node = self.make_node(op_type, [output_name], attr=kwargs, outputs=[new_output], name=name, domain=domain)
12501250

1251-
to_replace = [self.get_node_by_name(n) for n in self._input_to_node_name[output_name]]
1251+
to_replace = [self.get_node_by_name(n) for n in self._output_to_consumers[output_name]]
12521252
to_replace = [n for n in to_replace if n != new_node]
12531253
self.replace_all_inputs(output_name, new_output, ops=to_replace)
12541254
return new_node
12551255

12561256
def find_output_consumers(self, output_name):
12571257
"""Find all nodes consuming a given output."""
1258-
if output_name in self._input_to_node_name:
1259-
ops = self._input_to_node_name[output_name]
1258+
if output_name in self._output_to_consumers:
1259+
ops = self._output_to_consumers[output_name]
12601260
ops = [self.get_node_by_name(n) for n in ops]
12611261
else:
12621262
ops = [] # self.get_nodes()
@@ -1276,9 +1276,9 @@ def find_output_consumers(self, output_name):
12761276
def _register_input_name(self, input_name, node, only_graph=False):
12771277
"Register node taking a specific input."
12781278
if not only_graph:
1279-
if input_name not in self._input_to_node_name:
1280-
self._input_to_node_name[input_name] = set()
1281-
self._input_to_node_name[input_name].add(node.name)
1279+
if input_name not in self._output_to_consumers:
1280+
self._output_to_consumers[input_name] = set()
1281+
self._output_to_consumers[input_name].add(node.name)
12821282
if self.parent_graph is not None:
12831283
if input_name not in self.parent_graph._input_to_graph:
12841284
self.parent_graph._input_to_graph[input_name] = {}
@@ -1289,9 +1289,9 @@ def _unregister_input_name(self, input_name, node, only_graph=False):
12891289
"Unregister node taking a specific input."
12901290
node_name = node.name
12911291
if not only_graph:
1292-
if input_name in self._input_to_node_name[input_name]:
1293-
if node_name in self._input_to_node_name[input_name]:
1294-
self._input_to_node_name[input_name].remove(node_name)
1292+
if input_name in self._output_to_consumers[input_name]:
1293+
if node_name in self._output_to_consumers[input_name]:
1294+
self._output_to_consumers[input_name].remove(node_name)
12951295
if (self.parent_graph is not None and
12961296
input_name in self.parent_graph._input_to_graph and
12971297
id(self) in self.parent_graph._input_to_graph[input_name]):
@@ -1301,20 +1301,20 @@ def _unregister_input_name(self, input_name, node, only_graph=False):
13011301
def replace_all_inputs(self, old_input, new_input, ops=None):
13021302
"""
13031303
Replace all inputs pointing to old_input with new_input.
1304-
*ops* is used if defined, otherwise `_input_to_node_name`
1304+
*ops* is used if defined, otherwise `_output_to_consumers`
13051305
is used to determine the impacted nodes.
13061306
"""
13071307
if old_input == new_input:
13081308
return
1309-
if new_input not in self._input_to_node_name:
1310-
self._input_to_node_name[new_input] = set()
1309+
if new_input not in self._output_to_consumers:
1310+
self._output_to_consumers[new_input] = set()
13111311

13121312
if ops is not None:
13131313
keep_ops = True
1314-
elif old_input in self._input_to_node_name:
1314+
elif old_input in self._output_to_consumers:
13151315
ops = list(
13161316
filter(lambda a: a is not None,
1317-
map(self.get_node_by_name, self._input_to_node_name[old_input])))
1317+
map(self.get_node_by_name, self._output_to_consumers[old_input])))
13181318
keep_ops = False
13191319
else:
13201320
ops = []
@@ -1355,7 +1355,7 @@ def replace_input(self, node, old_input, new_input, input_index=None):
13551355
else:
13561356
raise RuntimeError("Unable to replace input %r into %r for node %r." % (old_input, new_input, node.name))
13571357

1358-
to_ops = self._input_to_node_name.get(old_input, None)
1358+
to_ops = self._output_to_consumers.get(old_input, None)
13591359
if to_ops is not None:
13601360
if node.name in to_ops:
13611361
# A node may take twice the same entry.
@@ -1369,7 +1369,7 @@ def replace_inputs(self, node, new_inputs):
13691369
assert isinstance(node, Node) and isinstance(new_inputs, list)
13701370

13711371
for old_input in node.input:
1372-
to_ops = self._input_to_node_name.get(old_input, None)
1372+
to_ops = self._output_to_consumers.get(old_input, None)
13731373
if to_ops is not None and old_input in to_ops:
13741374
# To avoid issues when a node
13751375
# takes twice the same entry.

0 commit comments

Comments
 (0)