Skip to content

Commit 19d3f97

Browse files
authored
Merge pull request #1060 from xadupre/input2
Use the same syntax to replace an node input (2) + optimize replace_all_inputs
2 parents d7fe8bb + 89ad32a commit 19d3f97

29 files changed

+190
-97
lines changed

tests/test_internals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def test_rewrite_subgraph(self):
139139
op_name = utils.make_name("ReplacedOp")
140140
out_name = utils.port_name(op_name)
141141
new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
142-
g.replace_all_inputs(ops, output_node.output[0], new_node.output[0])
142+
g.replace_all_inputs(output_node.output[0], new_node.output[0]) # ops=ops
143143
for n in set(match.get_nodes()):
144144
g.remove_node(n.name)
145145
g.topological_sort(ops)

tf2onnx/graph.py

Lines changed: 123 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,8 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
456456
self._nodes = []
457457
self._nodes_by_name = {}
458458
self._output_to_node_name = {}
459+
self._output_to_consumers = {}
460+
self._input_to_graph = {}
459461
self.shapes = {}
460462
self.graph_name = graph_name or "tf2onnx"
461463
self._is_subgraph = is_subgraph
@@ -502,7 +504,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
502504
body_graph.parent_graph = self
503505
new_node.set_body_graph_as_attr(attr_name, body_graph)
504506

505-
self.replace_all_inputs(self.get_nodes(), o, new_output_name)
507+
self.replace_all_inputs(o, new_output_name, ops=self.get_nodes())
506508
self.make_node("Identity", [new_output_name], outputs=[o], op_name_scope=n.name + "_" + "graph_outputs")
507509
self.copy_shape(new_output_name, o)
508510
self.copy_dtype(new_output_name, o)
@@ -607,6 +609,9 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
607609

608610
onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **raw_attr)
609611

612+
for name2 in onnx_node.input:
613+
self._register_input_name(name2, onnx_node)
614+
610615
if op_type in ["If", "Loop", "Scan"]:
611616
# we force the op containing inner graphs not skipped during conversion.
612617
skip_conversion = False
@@ -635,6 +640,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
635640
return node
636641

637642
def append_node(self, node):
643+
"Add a node to the graph."
638644
output_shapes = node.output_shapes
639645
output_dtypes = node.output_dtypes
640646
node.graph = self
@@ -644,6 +650,8 @@ def append_node(self, node):
644650
self._output_to_node_name[name] = node.name
645651
self.set_dtype(name, output_dtypes[i])
646652
self.set_shape(name, output_shapes[i])
653+
for name in node.input:
654+
self._register_input_name(name, node)
647655

648656
def remove_node(self, node_name):
649657
"""Remove node in current graph."""
@@ -664,6 +672,12 @@ def remove_node(self, node_name):
664672
if op_output in self._dtypes:
665673
del self._dtypes[op_output]
666674

675+
for op_input in node.input:
676+
utils.make_sure(
677+
op_input in self._output_to_consumers,
678+
"Input %r of node %r not found.", op_input, node_name)
679+
self._unregister_input_name(op_input, node)
680+
667681
self._nodes.remove(node)
668682
node.graph = None
669683

@@ -687,9 +701,13 @@ def reset_nodes(self, ops):
687701
self.contained_graphs = remained_sub_graphs
688702
self._nodes_by_name = {op.name: op for op in ops}
689703
self._output_to_node_name = {}
704+
self._output_to_consumers = {}
690705
for op in ops:
691706
for op_output in op.output:
692707
self._output_to_node_name[op_output] = op.name
708+
inps = op.input
709+
for op_input in inps:
710+
self._register_input_name(op_input, op)
693711

694712
for n in self._order_sensitive_inputs:
695713
if n not in ops:
@@ -823,6 +841,8 @@ def set_node_by_name(self, node):
823841
self._nodes_by_name[node.name] = node
824842
for op_output in node.output:
825843
self._output_to_node_name[op_output] = node.name
844+
for name in node.input:
845+
self._register_input_name(name, node)
826846

827847
def change_node_name(self, node, new_name):
828848
"""Remove node in current graph."""
@@ -838,7 +858,7 @@ def change_node_name(self, node, new_name):
838858
if k == old_output:
839859
self.outputs[j] = new_output
840860
break
841-
self.replace_all_inputs(self.get_nodes(), old_output, new_output)
861+
self.replace_all_inputs(old_output, new_output, ops=self.get_nodes())
842862
return new_node
843863

844864
def add_graph_input(self, name, dtype=None, shape=None):
@@ -1164,13 +1184,12 @@ def dump_node_statistics(self):
11641184
op_cnt[n.type] += 1
11651185
body_graphs = n.get_body_graphs()
11661186
if body_graphs:
1167-
for _, b_g in body_graphs.items():
1187+
for b_g in body_graphs.values():
11681188
op_cnt += b_g.dump_node_statistics()
11691189

11701190
return op_cnt
11711191

1172-
@staticmethod
1173-
def remove_input(node, to_be_removed, input_index=None):
1192+
def remove_input(self, node, to_be_removed, input_index=None):
11741193
"""Remove input from Node.
11751194
Args:
11761195
node: the node we expect the input on
@@ -1182,15 +1201,24 @@ def remove_input(node, to_be_removed, input_index=None):
11821201
assert isinstance(node, Node) and isinstance(to_be_removed, six.text_type)
11831202
if input_index is not None:
11841203
assert node.input[input_index] == to_be_removed
1204+
if node.input[input_index] in self._output_to_consumers:
1205+
to_ops = self._output_to_consumers[node.input[input_index]]
1206+
if node.name in to_ops:
1207+
to_ops.remove(node.name)
11851208
del node.input[input_index]
1186-
return True
1209+
return
11871210

11881211
for i, name in enumerate(node.input):
11891212
if name == to_be_removed:
1213+
utils.make_sure(
1214+
node.input.count(node.input[i]) <= 1,
1215+
"Node %r takes multiple times the same input %r. This case is not handled.",
1216+
node.name, node.input[i])
1217+
self._unregister_input_name(node.input[i], node)
11901218
del node.input[i]
11911219
break
1220+
11921221
# don't remove output from parent since others might depend on it
1193-
return True
11941222

11951223
def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=None, **kwargs):
11961224
"""Create and insert a new node into the graph.
@@ -1238,43 +1266,93 @@ def insert_new_node_on_output(self, op_type, output_name, name, domain=None, **k
12381266
new_output = port_name(name)
12391267
new_node = self.make_node(op_type, [output_name], attr=kwargs, outputs=[new_output], name=name, domain=domain)
12401268

1241-
to_replace = [n for n in self.get_nodes() if n != new_node]
1242-
self.replace_all_inputs(to_replace, output_name, new_output)
1269+
to_replace = [self.get_node_by_name(n) for n in self._output_to_consumers[output_name]]
1270+
to_replace = [n for n in to_replace if n != new_node]
1271+
self.replace_all_inputs(output_name, new_output, ops=to_replace)
12431272
return new_node
12441273

12451274
def find_output_consumers(self, output_name):
12461275
"""Find all nodes consuming a given output."""
1276+
if output_name in self._output_to_consumers:
1277+
ops = self._output_to_consumers[output_name]
1278+
ops = [self.get_node_by_name(n) for n in ops]
1279+
else:
1280+
ops = [] # self.get_nodes()
12471281
nodes = []
1248-
for node in self.get_nodes():
1282+
for node in ops:
1283+
if node is None:
1284+
continue
12491285
if output_name in node.input:
12501286
nodes.append(node)
12511287

1252-
# find consumers in sub graphs
1253-
body_graphs = node.get_body_graphs()
1254-
if body_graphs:
1255-
for g in body_graphs.values():
1256-
nodes.extend(g.find_output_consumers(output_name))
1288+
# find consumers in sub graphs
1289+
if output_name in self._input_to_graph:
1290+
for g in self._input_to_graph[output_name].values():
1291+
nodes.extend(g.find_output_consumers(output_name))
12571292
return nodes
12581293

1259-
@staticmethod
1260-
def replace_all_inputs(ops, old_input, new_input):
1261-
"""Replace all inputs pointing to old_input with new_input."""
1294+
def _register_input_name(self, input_name, node, only_graph=False):
1295+
"Register node taking a specific input."
1296+
if not only_graph:
1297+
if input_name not in self._output_to_consumers:
1298+
self._output_to_consumers[input_name] = set()
1299+
self._output_to_consumers[input_name].add(node.name)
1300+
if self.parent_graph is not None:
1301+
if input_name not in self.parent_graph._input_to_graph:
1302+
self.parent_graph._input_to_graph[input_name] = {}
1303+
self.parent_graph._input_to_graph[input_name][id(self)] = self
1304+
self.parent_graph._register_input_name(input_name, node, only_graph=True)
1305+
1306+
def _unregister_input_name(self, input_name, node, only_graph=False):
1307+
"Unregister node taking a specific input."
1308+
node_name = node.name
1309+
if not only_graph:
1310+
if input_name in self._output_to_consumers[input_name]:
1311+
if node_name in self._output_to_consumers[input_name]:
1312+
self._output_to_consumers[input_name].remove(node_name)
1313+
if (self.parent_graph is not None and
1314+
input_name in self.parent_graph._input_to_graph and
1315+
id(self) in self.parent_graph._input_to_graph[input_name]):
1316+
del self.parent_graph._input_to_graph[input_name][id(self)]
1317+
self.parent_graph._unregister_input_name(input_name, node, only_graph=True)
1318+
1319+
def replace_all_inputs(self, old_input, new_input, ops=None):
1320+
"""
1321+
Replace all inputs pointing to old_input with new_input.
1322+
*ops* is used if defined, otherwise `_output_to_consumers`
1323+
is used to determine the impacted nodes.
1324+
"""
12621325
if old_input == new_input:
12631326
return
1327+
if new_input not in self._output_to_consumers:
1328+
self._output_to_consumers[new_input] = set()
1329+
1330+
if ops is not None:
1331+
keep_ops = True
1332+
elif old_input in self._output_to_consumers:
1333+
ops = list(
1334+
filter(lambda a: a is not None,
1335+
map(self.get_node_by_name, self._output_to_consumers[old_input])))
1336+
keep_ops = False
1337+
else:
1338+
ops = []
1339+
keep_ops = False
12641340

12651341
for node in ops:
1342+
assert node is not None
12661343
if old_input in node.input and new_input in node.output:
12671344
raise RuntimeError("creating a circle in the graph is not allowed: " + node.name)
1345+
self._register_input_name(new_input, node)
12681346

12691347
for i, input_name in enumerate(node.input):
12701348
if input_name == old_input:
1271-
node.input[i] = new_input
1349+
self.replace_input(node, node.input[i], new_input, i)
12721350

1273-
# modify references in sub graphs
1274-
body_graphs = node.get_body_graphs()
1275-
if body_graphs:
1276-
for g in body_graphs.values():
1277-
g.replace_all_inputs(g.get_nodes(), old_input, new_input)
1351+
# modify references in sub graphs
1352+
if old_input in self._input_to_graph:
1353+
for g in self._input_to_graph[old_input].values():
1354+
g.replace_all_inputs(old_input, new_input,
1355+
ops=g.get_nodes() if keep_ops else None)
12781356

12791357
def replace_input(self, node, old_input, new_input, input_index=None):
12801358
"""
@@ -1294,11 +1372,31 @@ def replace_input(self, node, old_input, new_input, input_index=None):
12941372
is_replaced = True
12951373
else:
12961374
raise RuntimeError("Unable to replace input %r into %r for node %r." % (old_input, new_input, node.name))
1375+
1376+
to_ops = self._output_to_consumers.get(old_input, None)
1377+
if to_ops is not None:
1378+
if node.name in to_ops:
1379+
# A node may take twice the same entry.
1380+
to_ops.remove(node.name)
1381+
1382+
self._register_input_name(new_input, node)
12971383
return is_replaced
12981384

12991385
def replace_inputs(self, node, new_inputs):
13001386
"""Replace node inputs."""
13011387
assert isinstance(node, Node) and isinstance(new_inputs, list)
1388+
1389+
for old_input in node.input:
1390+
to_ops = self._output_to_consumers.get(old_input, None)
1391+
if to_ops is not None and old_input in to_ops:
1392+
# To avoid issues when a node
1393+
# takes twice the same entry.
1394+
to_ops.remove(old_input)
1395+
1396+
for input_name in new_inputs:
1397+
assert isinstance(input_name, six.text_type)
1398+
self._register_input_name(input_name, node)
1399+
13021400
node.input = new_inputs
13031401
return True
13041402

@@ -1374,7 +1472,7 @@ def delete_unused_nodes(self, outputs_name):
13741472
for node in related_nodes:
13751473
attr_body_graphs = node.get_body_graphs()
13761474
if attr_body_graphs:
1377-
for _, body_graph in attr_body_graphs.items():
1475+
for body_graph in attr_body_graphs.values():
13781476
body_graph.delete_unused_nodes(body_graph.outputs)
13791477
self.reset_nodes(related_nodes)
13801478

tf2onnx/onnx_opset/controlflow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ class TensorListStack:
492492
def version_7(cls, ctx, node, **kwargs):
493493
if node.inputs[0].is_while():
494494
ctx.remove_node(node.name)
495-
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], node.input[0])
495+
ctx.replace_all_inputs(node.output[0], node.input[0]) # ops=ctx.get_nodes()
496496

497497

498498
@tf_op(["While", "StatelessWhile"])
@@ -582,7 +582,7 @@ def version_7(cls, ctx, node, **kwargs):
582582
for idx, n in reversed(to_remove):
583583
ctx.remove_node(n.name)
584584
# make the node output bad
585-
ctx.replace_all_inputs(ctx.get_nodes(), n.output[0], "@@ALLOC")
585+
ctx.replace_all_inputs(n.output[0], "@@ALLOC") # ops=ctx.get_nodes()
586586
del body.func_inputs[idx]
587587
del cond_graph.func_inputs[idx]
588588
del tf_while_inputs[idx]
@@ -618,7 +618,7 @@ def version_7(cls, ctx, node, **kwargs):
618618

619619
# shift output consumers
620620
for k, v in output_map.items():
621-
ctx.replace_all_inputs(ctx.get_nodes(), k, v)
621+
ctx.replace_all_inputs(k, v) # ops=ctx.get_nodes()
622622

623623
wire_while_body(ctx, body, loop_node.inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
624624
output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, removed_scan_outputs)
@@ -813,7 +813,7 @@ def prefix_graph(g, scope):
813813
if old_output == oname:
814814
g.outputs[i] = new_output
815815
break
816-
g.replace_all_inputs(ops, old_output, new_output)
816+
g.replace_all_inputs(old_output, new_output, ops=ops)
817817
to_remove.append(node)
818818
for node in to_remove:
819819
g.remove_node(node.name)

tf2onnx/onnx_opset/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,4 +695,4 @@ def atan2(y, x):
695695
"Add", inputs=[atan_node.output[0], pi_part.output[0]],
696696
op_name_scope=node.name + 'all',
697697
shapes=[shape], dtypes=[onnx_dtype])
698-
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], last_node.output[0])
698+
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()

tf2onnx/onnx_opset/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def version_1(cls, ctx, node, **kwargs):
3030
# if identity has a const as input, remove it
3131
input_name = node.input[0]
3232
output_name = node.output[0]
33-
ctx.replace_all_inputs(ctx.get_nodes(), output_name, input_name)
33+
ctx.replace_all_inputs(output_name, input_name) # ops=ctx.get_nodes()
3434
ctx.remove_node(node.name)
3535

3636

tf2onnx/onnx_opset/nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ def version_1(cls, ctx, node, **kwargs):
451451
downstream_nodes = ctx.find_output_consumers(node.output[0])
452452
downstream_nodes.remove(output_shape)
453453
downstream_nodes.remove(slice_node)
454-
ctx.replace_all_inputs(downstream_nodes, node.output[0], slice_node.output[0])
454+
ctx.replace_all_inputs(node.output[0], slice_node.output[0], ops=downstream_nodes)
455455

456456
conv_dims_attr(node, "strides", spatial=spatial)
457457
conv_dims_attr(node, "dilations", spatial=spatial)

tf2onnx/onnx_opset/quantize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,4 @@ def version_10(cls, ctx, node, **kwargs):
7878
"DequantizeLinear", [new_node.output[0], pb_scale.name, zero_point.name],
7979
op_name_scope=node.name, attr={"axis": axis},
8080
shapes=[shape], dtypes=[dtype])
81-
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], last_node.output[0])
81+
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()

tf2onnx/onnx_opset/rnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def make_sigmoid(i, w, b):
153153
h_node = ctx.make_node("Mul", [co_node.output[0], o])
154154

155155
def replace_output(old_output, new_output):
156-
ctx.replace_all_inputs(ctx.get_nodes(), old_output, new_output)
156+
ctx.replace_all_inputs(old_output, new_output) # ops=ctx.get_nodes()
157157
ctx.copy_dtype(old_output, new_output)
158158
ctx.copy_shape(old_output, new_output)
159159

tf2onnx/onnx_opset/tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def version_1(cls, ctx, node, **kwargs):
115115
# if identity has a const as input, remove it
116116
input_name = node.input[0]
117117
output_name = node.output[0]
118-
ctx.replace_all_inputs(ctx.get_nodes(), output_name, input_name)
118+
ctx.replace_all_inputs(output_name, input_name) # ops=ctx.get_nodes()
119119
ctx.remove_node(node.name)
120120

121121

@@ -125,7 +125,7 @@ class IdentityN:
125125
def version_1(cls, ctx, node, **kwargs):
126126
ctx.remove_node(node.name)
127127
for input_name, output_name in zip(node.input, node.output):
128-
ctx.replace_all_inputs(ctx.get_nodes(), output_name, input_name)
128+
ctx.replace_all_inputs(output_name, input_name) # ops=ctx.get_nodes()
129129

130130

131131
@tf_op("Reshape")
@@ -1050,7 +1050,7 @@ def version_1(cls, ctx, node, **kwargs):
10501050
# concat all unqueezes
10511051
concat = ctx.make_node("Concat", inputs, op_name_scope=node.name, attr={"axis": axis},
10521052
shapes=shapes, dtypes=dtypes)
1053-
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], concat.output[0])
1053+
ctx.replace_all_inputs(node.output[0], concat.output[0]) # ops=ctx.get_nodes()
10541054

10551055

10561056
@tf_op("Unpack")

0 commit comments

Comments
 (0)