Skip to content

Commit f5c9406

Browse files
committed
Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into input2
2 parents 7b244e0 + ee2b202 commit f5c9406

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

tf2onnx/graph.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,27 +1178,29 @@ def dump_node_statistics(self):
11781178

11791179
return op_cnt
11801180

1181-
def remove_input(self, node, to_be_removed, i=None):
1181+
def remove_input(self, node, to_be_removed, input_index=None):
11821182
"""Remove input from Node.
11831183
Args:
11841184
node: the node we expect the input on
11851185
to_be_removed: the node name we want to remove
1186-
i: if not None, index of the input to be removed
1186+
input_index: if not None, index of the input to be removed,
1187+
the method is more efficient if *input_index* is specified,
1188+
otherwise, it has to look for every input named *old_input*.
11871189
"""
11881190
assert isinstance(node, Node) and isinstance(to_be_removed, six.text_type)
1189-
if i is not None:
1190-
assert node.input[i] == to_be_removed
1191-
if node.input[i] in self._input_to_node_name:
1192-
to_ops = self._input_to_node_name[node.input[i]]
1191+
if input_index is not None:
1192+
assert node.input[input_index] == to_be_removed
1193+
if node.input[input_index] in self._input_to_node_name:
1194+
to_ops = self._input_to_node_name[node.input[input_index]]
11931195
if node.name in to_ops:
11941196
to_ops.remove(node.name)
1195-
del node.input[i]
1197+
del node.input[input_index]
11961198
return True
11971199

1198-
for i2, name in enumerate(node.input):
1200+
for i, name in enumerate(node.input):
11991201
if name == to_be_removed:
1200-
self._unregister_input_name(node.input[i2], node)
1201-
del node.input[i2]
1202+
self._unregister_input_name(node.input[i], node)
1203+
del node.input[i]
12021204
break
12031205
# don't remove output from parent since others might depend on it
12041206
return True
@@ -1336,17 +1338,21 @@ def replace_all_inputs(self, ops, old_input, new_input):
13361338
for _, g in self._input_to_graph[old_input].items():
13371339
g.replace_all_inputs(g.get_nodes() if keep_ops else None, old_input, new_input)
13381340

1339-
def replace_input(self, node, old_input, new_input, i=None):
1340-
"""Replace one input in a node."""
1341+
def replace_input(self, node, old_input, new_input, input_index=None):
1342+
"""
1343+
Replace one input in a node.
1344+
The method is more efficient if *input_index* is specified.
1345+
Otherwise, it renames every output named *old_input*.
1346+
"""
13411347
assert isinstance(node, Node) and isinstance(old_input, six.text_type) and isinstance(new_input, six.text_type)
13421348
is_replaced = False
1343-
if i is None:
1344-
for i2, input_name in enumerate(node.input):
1349+
if input_index is None:
1350+
for i, input_name in enumerate(node.input):
13451351
if input_name == old_input:
1346-
node.input[i2] = new_input
1352+
node.input[i] = new_input
13471353
is_replaced = True
1348-
elif node.input[i] == old_input:
1349-
node.input[i] = new_input
1354+
elif node.input[input_index] == old_input:
1355+
node.input[input_index] = new_input
13501356
is_replaced = True
13511357
else:
13521358
raise RuntimeError("Unable to replace input %r into %r for node %r." % (old_input, new_input, node.name))

0 commit comments

Comments
 (0)