Skip to content

Commit 469b4c4

Browse files
committed
Use the same syntax to replace an node input (2) + optimisation
1 parent 102d8f8 commit 469b4c4

16 files changed

+169
-61
lines changed

tests/test_internals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def test_rewrite_subgraph(self):
139139
op_name = utils.make_name("ReplacedOp")
140140
out_name = utils.port_name(op_name)
141141
new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
142-
g.replace_all_inputs(ops, output_node.output[0], new_node.output[0])
142+
g.replace_all_inputs(None, output_node.output[0], new_node.output[0]) # ops
143143
for n in set(match.get_nodes()):
144144
g.remove_node(n.name)
145145
g.topological_sort(ops)

tf2onnx/graph.py

Lines changed: 128 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,8 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
425425
self._nodes = []
426426
self._nodes_by_name = {}
427427
self._output_to_node_name = {}
428+
self._input_to_node_name = {}
429+
self._input_to_graph = {}
428430
self.shapes = {}
429431
self.graph_name = graph_name or "tf2onnx"
430432
self._is_subgraph = is_subgraph
@@ -576,6 +578,9 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
576578

577579
onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **raw_attr)
578580

581+
for name2 in onnx_node.input:
582+
self._register_input_name(name2, onnx_node)
583+
579584
if op_type in ["If", "Loop", "Scan"]:
580585
# we force the op containing inner graphs not skipped during conversion.
581586
skip_conversion = False
@@ -613,6 +618,8 @@ def append_node(self, node):
613618
self._output_to_node_name[name] = node.name
614619
self.set_dtype(name, output_dtypes[i])
615620
self.set_shape(name, output_shapes[i])
621+
for name in node.input:
622+
self._register_input_name(name, node)
616623

617624
def remove_node(self, node_name):
618625
"""Remove node in current graph."""
@@ -633,6 +640,12 @@ def remove_node(self, node_name):
633640
if op_output in self._dtypes:
634641
del self._dtypes[op_output]
635642

643+
for op_input in node.input:
644+
if op_input not in self._input_to_node_name:
645+
raise RuntimeError(
646+
"Input %r of node %r not found." % (op_input, node_name))
647+
self._unregister_input_name(op_input, node)
648+
636649
self._nodes.remove(node)
637650
node.graph = None
638651

@@ -656,16 +669,32 @@ def reset_nodes(self, ops):
656669
self.contained_graphs = remained_sub_graphs
657670
self._nodes_by_name = {op.name: op for op in ops}
658671
self._output_to_node_name = {}
672+
self._input_to_node_name = {}
659673
for op in ops:
660674
for op_output in op.output:
661675
self._output_to_node_name[op_output] = op.name
676+
if op.type == 'Placeholder':
677+
inps = [op.name]
678+
elif op.type == 'Const':
679+
inps = [op.name]
680+
else:
681+
inps = op.input
682+
for op_input in inps:
683+
self._register_input_name(op_input, op)
662684

663685
for n in self._order_sensitive_inputs:
664686
if n not in ops:
665687
self._order_sensitive_inputs.remove(n)
666688
for o in self.outputs:
667689
if o not in self._output_to_node_name:
668690
raise ValueError("graph output " + o + " not exist")
691+
for i in self.inputs:
692+
if i.name.startswith('Placeholder'):
693+
continue
694+
if i.name.startswith('keras_learning_phase'):
695+
continue
696+
if i.name not in self._input_to_node_name:
697+
raise ValueError("graph input %r not exist in graph." % i.name)
669698

670699
self._dtypes = remained_dtypes
671700
self._output_shapes = remained_shapes
@@ -782,6 +811,14 @@ def get_node_by_output_in_current_graph(self, output):
782811
ret = self._nodes_by_name.get(name)
783812
return ret
784813

814+
def get_node_by_input_in_current_graph(self, input_name):
815+
"""Get nodes by node input id."""
816+
names = self._output_to_node_name.get(input_name)
817+
ret = None
818+
if name:
819+
ret = [self._nodes_by_name.get(name) for name in names]
820+
return ret
821+
785822
def get_node_by_name(self, name):
786823
"""Get node by name."""
787824
ret = self._nodes_by_name.get(name)
@@ -792,6 +829,8 @@ def set_node_by_name(self, node):
792829
self._nodes_by_name[node.name] = node
793830
for op_output in node.output:
794831
self._output_to_node_name[op_output] = node.name
832+
for name in node.input:
833+
self._register_input_name(name, node)
795834

796835
def change_node_name(self, node, new_name):
797836
"""Remove node in current graph."""
@@ -1138,8 +1177,7 @@ def dump_node_statistics(self):
11381177

11391178
return op_cnt
11401179

1141-
@staticmethod
1142-
def remove_input(node, to_be_removed, i=None):
1180+
def remove_input(self, node, to_be_removed, i=None):
11431181
"""Remove input from Node.
11441182
Args:
11451183
node: the node we expect the input on
@@ -1149,11 +1187,16 @@ def remove_input(node, to_be_removed, i=None):
11491187
assert isinstance(node, Node) and isinstance(to_be_removed, six.text_type)
11501188
if i is not None:
11511189
assert node.input[i] == to_be_removed
1190+
if node.input[i] in self._input_to_node_name:
1191+
to_ops = self._input_to_node_name[node.input[i]]
1192+
if node.name in to_ops:
1193+
to_ops.remove(node.name)
11521194
del node.input[i]
11531195
return True
11541196

11551197
for i2, name in enumerate(node.input):
11561198
if name == to_be_removed:
1199+
self._unregister_input_name(node.input[i2], node)
11571200
del node.input[i2]
11581201
break
11591202
# don't remove output from parent since others might depend on it
@@ -1205,43 +1248,90 @@ def insert_new_node_on_output(self, op_type, output_name, name, domain=None, **k
12051248
new_output = port_name(name)
12061249
new_node = self.make_node(op_type, [output_name], attr=kwargs, outputs=[new_output], name=name, domain=domain)
12071250

1208-
to_replace = [n for n in self.get_nodes() if n != new_node]
1251+
to_replace = [self.get_node_by_name(n) for n in self._input_to_node_name[output_name]]
1252+
to_replace = [n for n in to_replace if n != new_node]
12091253
self.replace_all_inputs(to_replace, output_name, new_output)
12101254
return new_node
12111255

12121256
def find_output_consumers(self, output_name):
12131257
"""Find all nodes consuming a given output."""
1258+
if output_name in self._input_to_node_name:
1259+
ops = self._input_to_node_name[output_name]
1260+
ops = [self.get_node_by_name(n) for n in ops]
1261+
else:
1262+
ops = self.get_nodes()
12141263
nodes = []
1215-
for node in self.get_nodes():
1264+
for node in ops:
1265+
if node is None:
1266+
continue
12161267
if output_name in node.input:
12171268
nodes.append(node)
12181269

1219-
# find consumers in sub graphs
1220-
body_graphs = node.get_body_graphs()
1221-
if body_graphs:
1222-
for g in body_graphs.values():
1223-
nodes.extend(g.find_output_consumers(output_name))
1270+
# find consumers in sub graphs
1271+
if output_name in self._input_to_graph:
1272+
for _, g in self._input_to_graph[output_name].items():
1273+
nodes.extend(g.find_output_consumers(output_name))
12241274
return nodes
12251275

1226-
@staticmethod
1227-
def replace_all_inputs(ops, old_input, new_input):
1228-
"""Replace all inputs pointing to old_input with new_input."""
1276+
def _register_input_name(self, input_name, node, only_graph=False):
1277+
"Register node taking a specific input."
1278+
if not only_graph:
1279+
if input_name not in self._input_to_node_name:
1280+
self._input_to_node_name[input_name] = set()
1281+
self._input_to_node_name[input_name].add(node.name)
1282+
if self.parent_graph is not None:
1283+
if input_name not in self.parent_graph._input_to_graph:
1284+
self.parent_graph._input_to_graph[input_name] = {}
1285+
self.parent_graph._input_to_graph[input_name][id(self)] = self
1286+
self.parent_graph._register_input_name(input_name, node, only_graph=True)
1287+
1288+
def _unregister_input_name(self, input_name, node, only_graph=False):
1289+
"Unregister node taking a specific input."
1290+
node_name = node.name
1291+
if not only_graph:
1292+
if node_name in self._input_to_node_name[input_name]:
1293+
if node_name in self._input_to_node_name[input_name]:
1294+
self._input_to_node_name[input_name].remove(node_name)
1295+
if (self.parent_graph is not None and
1296+
input_name in self.parent_graph._input_to_graph and
1297+
id(self) in self.parent_graph._input_to_graph[input_name]):
1298+
del self.parent_graph._input_to_graph[input_name][id(self)]
1299+
self.parent_graph._unregister_input_name(input_name, node, only_graph=True)
1300+
1301+
def replace_all_inputs(self, ops, old_input, new_input):
1302+
"""
1303+
Replace all inputs pointing to old_input with new_input.
1304+
*ops* is used if defined, otherwise _input_to_node_name
1305+
is used to determine the impacted nodes.
1306+
"""
12291307
if old_input == new_input:
12301308
return
1309+
if new_input not in self._input_to_node_name:
1310+
self._input_to_node_name[new_input] = set()
1311+
1312+
if ops is not None:
1313+
keep_ops = True
1314+
elif old_input in self._input_to_node_name:
1315+
ops = [self.get_node_by_name(n) for n in self._input_to_node_name[old_input]]
1316+
keep_ops = False
1317+
else:
1318+
ops = []
1319+
keep_ops = False
12311320

12321321
for node in ops:
1322+
assert node is not None
12331323
if old_input in node.input and new_input in node.output:
12341324
raise RuntimeError("creating a circle in the graph is not allowed: " + node.name)
1325+
self._register_input_name(new_input, node)
12351326

12361327
for i, input_name in enumerate(node.input):
12371328
if input_name == old_input:
1238-
node.input[i] = new_input
1329+
self.replace_input(node, node.input[i], new_input, i)
12391330

1240-
# modify references in sub graphs
1241-
body_graphs = node.get_body_graphs()
1242-
if body_graphs:
1243-
for g in body_graphs.values():
1244-
g.replace_all_inputs(g.get_nodes(), old_input, new_input)
1331+
# modify references in sub graphs
1332+
if old_input in self._input_to_graph:
1333+
for _, g in self._input_to_graph[old_input].items():
1334+
g.replace_all_inputs(g.get_nodes() if keep_ops else None, old_input, new_input)
12451335

12461336
def replace_input(self, node, old_input, new_input, i=None):
12471337
"""Replace one input in a node."""
@@ -1257,11 +1347,31 @@ def replace_input(self, node, old_input, new_input, i=None):
12571347
is_replaced = True
12581348
else:
12591349
raise RuntimeError("Unable to replace input %r into %r for node %r." % (old_input, new_input, node.name))
1350+
1351+
to_ops = self._input_to_node_name.get(old_input, None)
1352+
if to_ops is not None:
1353+
if node.name in to_ops:
1354+
# A node may take twice the same entry.
1355+
to_ops.remove(node.name)
1356+
1357+
self._register_input_name(new_input, node)
12601358
return is_replaced
12611359

12621360
def replace_inputs(self, node, new_inputs):
12631361
"""Replace node inputs."""
12641362
assert isinstance(node, Node) and isinstance(new_inputs, list)
1363+
1364+
for old_input in node.input:
1365+
to_ops = self._input_to_node_name.get(old_input, None)
1366+
if to_ops is not None and old_input in to_ops:
1367+
# To avoid issues when a node
1368+
# takes twice the same entry.
1369+
to_ops.remove(old_input)
1370+
1371+
for input_name in new_inputs:
1372+
assert isinstance(input_name, six.text_type)
1373+
self._register_input_name(input_name, node)
1374+
12651375
node.input = new_inputs
12661376
return True
12671377

tf2onnx/onnx_opset/controlflow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ class TensorListStack:
493493
def version_7(cls, ctx, node, **kwargs):
494494
if node.inputs[0].is_while():
495495
ctx.remove_node(node.name)
496-
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], node.input[0])
496+
ctx.replace_all_inputs(None, node.output[0], node.input[0]) # ctx.get_nodes()
497497

498498

499499
@tf_op(["While", "StatelessWhile"])
@@ -583,7 +583,7 @@ def version_7(cls, ctx, node, **kwargs):
583583
for idx, n in reversed(to_remove):
584584
ctx.remove_node(n.name)
585585
# make the node output bad
586-
ctx.replace_all_inputs(ctx.get_nodes(), n.output[0], "@@ALLOC")
586+
ctx.replace_all_inputs(None, n.output[0], "@@ALLOC") # ctx.get_nodes()
587587
del body.func_inputs[idx]
588588
del cond_graph.func_inputs[idx]
589589
del tf_while_inputs[idx]
@@ -619,7 +619,7 @@ def version_7(cls, ctx, node, **kwargs):
619619

620620
# shift output consumers
621621
for k, v in output_map.items():
622-
ctx.replace_all_inputs(ctx.get_nodes(), k, v)
622+
ctx.replace_all_inputs(None, k, v) # ctx.get_nodes()
623623

624624
wire_while_body(ctx, body, loop_node.inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
625625
output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, removed_scan_outputs)

tf2onnx/onnx_opset/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,4 +695,4 @@ def atan2(y, x):
695695
"Add", inputs=[atan_node.output[0], pi_part.output[0]],
696696
op_name_scope=node.name + 'all',
697697
shapes=[shape], dtypes=[onnx_dtype])
698-
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], last_node.output[0])
698+
ctx.replace_all_inputs(None, node.output[0], last_node.output[0]) # ctx.get_nodes()

tf2onnx/onnx_opset/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def version_1(cls, ctx, node, **kwargs):
3030
# if identity has a const as input, remove it
3131
input_name = node.input[0]
3232
output_name = node.output[0]
33-
ctx.replace_all_inputs(ctx.get_nodes(), output_name, input_name)
33+
ctx.replace_all_inputs(None, output_name, input_name) # ctx.get_nodes()
3434
ctx.remove_node(node.name)
3535

3636

tf2onnx/onnx_opset/quantize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,4 @@ def version_10(cls, ctx, node, **kwargs):
7878
"DequantizeLinear", [new_node.output[0], pb_scale.name, zero_point.name],
7979
op_name_scope=node.name, attr={"axis": axis},
8080
shapes=[shape], dtypes=[dtype])
81-
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], last_node.output[0])
81+
ctx.replace_all_inputs(None, node.output[0], last_node.output[0]) # ctx.get_nodes()

tf2onnx/onnx_opset/rnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def make_sigmoid(i, w, b):
153153
h_node = ctx.make_node("Mul", [co_node.output[0], o])
154154

155155
def replace_output(old_output, new_output):
156-
ctx.replace_all_inputs(ctx.get_nodes(), old_output, new_output)
156+
ctx.replace_all_inputs(None, old_output, new_output) # ctx.get_nodes()
157157
ctx.copy_dtype(old_output, new_output)
158158
ctx.copy_shape(old_output, new_output)
159159

tf2onnx/onnx_opset/tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def version_1(cls, ctx, node, **kwargs):
116116
# if identity has a const as input, remove it
117117
input_name = node.input[0]
118118
output_name = node.output[0]
119-
ctx.replace_all_inputs(ctx.get_nodes(), output_name, input_name)
119+
ctx.replace_all_inputs(None, output_name, input_name) # ctx.get_nodes()
120120
ctx.remove_node(node.name)
121121

122122

@@ -126,7 +126,7 @@ class IdentityN:
126126
def version_1(cls, ctx, node, **kwargs):
127127
ctx.remove_node(node.name)
128128
for input_name, output_name in zip(node.input, node.output):
129-
ctx.replace_all_inputs(ctx.get_nodes(), output_name, input_name)
129+
ctx.replace_all_inputs(None, output_name, input_name) # ctx.get_nodes()
130130

131131

132132
@tf_op("Reshape")
@@ -1051,7 +1051,7 @@ def version_1(cls, ctx, node, **kwargs):
10511051
# concat all unqueezes
10521052
concat = ctx.make_node("Concat", inputs, op_name_scope=node.name, attr={"axis": axis},
10531053
shapes=shapes, dtypes=dtypes)
1054-
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], concat.output[0])
1054+
ctx.replace_all_inputs(None, node.output[0], concat.output[0]) # ctx.get_nodes()
10551055

10561056

10571057
@tf_op("Unpack")

0 commit comments

Comments
 (0)