Skip to content

Commit 874e5c5

Browse files
committed
make ops optional in replace_all_inputs
1 parent 71b0ddb commit 874e5c5

29 files changed

+72
-72
lines changed

tests/test_internals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def test_rewrite_subgraph(self):
139139
op_name = utils.make_name("ReplacedOp")
140140
out_name = utils.port_name(op_name)
141141
new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
142-
g.replace_all_inputs(None, output_node.output[0], new_node.output[0]) # ops
142+
g.replace_all_inputs(output_node.output[0], new_node.output[0]) # ops=ops
143143
for n in set(match.get_nodes()):
144144
g.remove_node(n.name)
145145
g.topological_sort(ops)

tf2onnx/graph.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
473473
body_graph.parent_graph = self
474474
new_node.set_body_graph_as_attr(attr_name, body_graph)
475475

476-
self.replace_all_inputs(self.get_nodes(), o, new_output_name)
476+
self.replace_all_inputs(o, new_output_name, ops=self.get_nodes())
477477
self.make_node("Identity", [new_output_name], outputs=[o], op_name_scope=n.name + "_" + "graph_outputs")
478478
self.copy_shape(new_output_name, o)
479479
self.copy_dtype(new_output_name, o)
@@ -839,7 +839,7 @@ def change_node_name(self, node, new_name):
839839
if k == old_output:
840840
self.outputs[j] = new_output
841841
break
842-
self.replace_all_inputs(self.get_nodes(), old_output, new_output)
842+
self.replace_all_inputs(old_output, new_output, ops=self.get_nodes())
843843
return new_node
844844

845845
def add_graph_input(self, name, dtype=None, shape=None):
@@ -1250,7 +1250,7 @@ def insert_new_node_on_output(self, op_type, output_name, name, domain=None, **k
12501250

12511251
to_replace = [self.get_node_by_name(n) for n in self._input_to_node_name[output_name]]
12521252
to_replace = [n for n in to_replace if n != new_node]
1253-
self.replace_all_inputs(to_replace, output_name, new_output)
1253+
self.replace_all_inputs(output_name, new_output, ops=to_replace)
12541254
return new_node
12551255

12561256
def find_output_consumers(self, output_name):
@@ -1298,11 +1298,11 @@ def _unregister_input_name(self, input_name, node, only_graph=False):
12981298
del self.parent_graph._input_to_graph[input_name][id(self)]
12991299
self.parent_graph._unregister_input_name(input_name, node, only_graph=True)
13001300

1301-
def replace_all_inputs(self, ops, old_input, new_input):
1301+
def replace_all_inputs(self, old_input, new_input, ops=None):
13021302
"""
13031303
Replace all inputs pointing to old_input with new_input.
1304-
*ops* is used if defined, otherwise _input_to_node_name
1305-
is used to determine the impacted nodes.
1304+
*ops* is used if defined, otherwise `_input_to_node_name`
1305+
is used to determine the impacted nodes.
13061306
"""
13071307
if old_input == new_input:
13081308
return
@@ -1333,7 +1333,8 @@ def replace_all_inputs(self, ops, old_input, new_input):
13331333
# modify references in sub graphs
13341334
if old_input in self._input_to_graph:
13351335
for g in self._input_to_graph[old_input].values():
1336-
g.replace_all_inputs(g.get_nodes() if keep_ops else None, old_input, new_input)
1336+
g.replace_all_inputs(old_input, new_input,
1337+
ops=g.get_nodes() if keep_ops else None)
13371338

13381339
def replace_input(self, node, old_input, new_input, input_index=None):
13391340
"""

tf2onnx/onnx_opset/controlflow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ class TensorListStack:
492492
def version_7(cls, ctx, node, **kwargs):
493493
if node.inputs[0].is_while():
494494
ctx.remove_node(node.name)
495-
ctx.replace_all_inputs(None, node.output[0], node.input[0]) # ctx.get_nodes()
495+
ctx.replace_all_inputs(node.output[0], node.input[0]) # ops=ctx.get_nodes()
496496

497497

498498
@tf_op(["While", "StatelessWhile"])
@@ -582,7 +582,7 @@ def version_7(cls, ctx, node, **kwargs):
582582
for idx, n in reversed(to_remove):
583583
ctx.remove_node(n.name)
584584
# make the node output bad
585-
ctx.replace_all_inputs(None, n.output[0], "@@ALLOC") # ctx.get_nodes()
585+
ctx.replace_all_inputs(n.output[0], "@@ALLOC") # ops=ctx.get_nodes()
586586
del body.func_inputs[idx]
587587
del cond_graph.func_inputs[idx]
588588
del tf_while_inputs[idx]
@@ -618,7 +618,7 @@ def version_7(cls, ctx, node, **kwargs):
618618

619619
# shift output consumers
620620
for k, v in output_map.items():
621-
ctx.replace_all_inputs(None, k, v) # ctx.get_nodes()
621+
ctx.replace_all_inputs(k, v) # ops=ctx.get_nodes()
622622

623623
wire_while_body(ctx, body, loop_node.inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
624624
output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, removed_scan_outputs)
@@ -813,7 +813,7 @@ def prefix_graph(g, scope):
813813
if old_output == oname:
814814
g.outputs[i] = new_output
815815
break
816-
g.replace_all_inputs(ops, old_output, new_output)
816+
g.replace_all_inputs(old_output, new_output, ops=ops)
817817
to_remove.append(node)
818818
for node in to_remove:
819819
g.remove_node(node.name)

tf2onnx/onnx_opset/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,4 +695,4 @@ def atan2(y, x):
695695
"Add", inputs=[atan_node.output[0], pi_part.output[0]],
696696
op_name_scope=node.name + 'all',
697697
shapes=[shape], dtypes=[onnx_dtype])
698-
ctx.replace_all_inputs(None, node.output[0], last_node.output[0]) # ctx.get_nodes()
698+
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()

tf2onnx/onnx_opset/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def version_1(cls, ctx, node, **kwargs):
3030
# if identity has a const as input, remove it
3131
input_name = node.input[0]
3232
output_name = node.output[0]
33-
ctx.replace_all_inputs(None, output_name, input_name) # ctx.get_nodes()
33+
ctx.replace_all_inputs(output_name, input_name) # ops=ctx.get_nodes()
3434
ctx.remove_node(node.name)
3535

3636

tf2onnx/onnx_opset/nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ def version_1(cls, ctx, node, **kwargs):
451451
downstream_nodes = ctx.find_output_consumers(node.output[0])
452452
downstream_nodes.remove(output_shape)
453453
downstream_nodes.remove(slice_node)
454-
ctx.replace_all_inputs(downstream_nodes, node.output[0], slice_node.output[0])
454+
ctx.replace_all_inputs(node.output[0], slice_node.output[0], ops=downstream_nodes)
455455

456456
conv_dims_attr(node, "strides", spatial=spatial)
457457
conv_dims_attr(node, "dilations", spatial=spatial)

tf2onnx/onnx_opset/quantize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,4 @@ def version_10(cls, ctx, node, **kwargs):
7878
"DequantizeLinear", [new_node.output[0], pb_scale.name, zero_point.name],
7979
op_name_scope=node.name, attr={"axis": axis},
8080
shapes=[shape], dtypes=[dtype])
81-
ctx.replace_all_inputs(None, node.output[0], last_node.output[0]) # ctx.get_nodes()
81+
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()

tf2onnx/onnx_opset/rnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def make_sigmoid(i, w, b):
153153
h_node = ctx.make_node("Mul", [co_node.output[0], o])
154154

155155
def replace_output(old_output, new_output):
156-
ctx.replace_all_inputs(None, old_output, new_output) # ctx.get_nodes()
156+
ctx.replace_all_inputs(old_output, new_output) # ops=ctx.get_nodes()
157157
ctx.copy_dtype(old_output, new_output)
158158
ctx.copy_shape(old_output, new_output)
159159

tf2onnx/onnx_opset/tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def version_1(cls, ctx, node, **kwargs):
115115
# if identity has a const as input, remove it
116116
input_name = node.input[0]
117117
output_name = node.output[0]
118-
ctx.replace_all_inputs(None, output_name, input_name) # ctx.get_nodes()
118+
ctx.replace_all_inputs(output_name, input_name) # ops=ctx.get_nodes()
119119
ctx.remove_node(node.name)
120120

121121

@@ -125,7 +125,7 @@ class IdentityN:
125125
def version_1(cls, ctx, node, **kwargs):
126126
ctx.remove_node(node.name)
127127
for input_name, output_name in zip(node.input, node.output):
128-
ctx.replace_all_inputs(None, output_name, input_name) # ctx.get_nodes()
128+
ctx.replace_all_inputs(output_name, input_name) # ops=ctx.get_nodes()
129129

130130

131131
@tf_op("Reshape")
@@ -1050,7 +1050,7 @@ def version_1(cls, ctx, node, **kwargs):
10501050
# concat all unqueezes
10511051
concat = ctx.make_node("Concat", inputs, op_name_scope=node.name, attr={"axis": axis},
10521052
shapes=shapes, dtypes=dtypes)
1053-
ctx.replace_all_inputs(None, node.output[0], concat.output[0]) # ctx.get_nodes()
1053+
ctx.replace_all_inputs(node.output[0], concat.output[0]) # ops=ctx.get_nodes()
10541054

10551055

10561056
@tf_op("Unpack")

tf2onnx/optimizer/back_to_back_optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def _optimize_transpose(g, node, consumer_nodes):
138138
shape = g.get_shape(node2.output[0])
139139
dtype = g.get_dtype(node2.output[0])
140140
node2_consumers = g.find_output_consumers(node2.output[0])
141-
g.replace_all_inputs(node2_consumers, node2.output[0], node.input[0])
141+
g.replace_all_inputs(node2.output[0], node.input[0], ops=node2_consumers)
142142
g.remove_node(node2.name)
143143
if set(node2.output) & set(g.outputs):
144144
g.make_node("Identity", [node.input[0]],
@@ -173,7 +173,7 @@ def _optimize_squeeze_unsqueeze(g, node, consumer_nodes):
173173
return []
174174

175175
node2_consumers = g.find_output_consumers(node2.output[0])
176-
g.replace_all_inputs(node2_consumers, node2.output[0], node.input[0])
176+
g.replace_all_inputs(node2.output[0], node.input[0], ops=node2_consumers)
177177
g.remove_node(node.name)
178178
g.remove_node(node2.name)
179179
return []

0 commit comments

Comments
 (0)