Skip to content

Commit cbffde4

Browse files
authored
Merge pull request #1019 from xadupre/output
Remove unnecessary deepcopy when accessing output from class Node
2 parents 097e28c + d239316 commit cbffde4

File tree

5 files changed

+11
-5
lines changed

5 files changed

+11
-5
lines changed

tf2onnx/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def input(self, val):
6060

6161
@property
6262
def output(self):
63-
return copy.deepcopy(self._output)
63+
return self._output
6464

6565
@output.setter
6666
def output(self, val):
@@ -71,7 +71,7 @@ def output(self, val):
7171
for o in self._output:
7272
del self.graph._output_to_node_name[o]
7373

74-
self._output = val
74+
self._output = val.copy()
7575
for o in self._output:
7676
utils.make_sure(o not in self.graph._output_to_node_name, "output %s already in output mapping", o)
7777
self.graph._output_to_node_name[o] = self.name

tf2onnx/onnx_opset/controlflow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,9 @@ def version_7(cls, ctx, node, **kwargs):
515515

516516
output_shapes = node.output_shapes
517517
output_dtypes = node.output_dtypes
518-
output_names = node.output
518+
# node.output must be copied as some element
519+
# may be removed from output_names below
520+
output_names = node.output.copy()
519521

520522
# Make maximum_iterations int64 and replace -1(tf) with maxsize(onnx). If the const node has no other consumers,
521523
# modify it in place. Otherwise, make a new const node and leave the original unchanged.

tf2onnx/onnx_opset/generator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ def version_8(cls, ctx, node, **kwargs):
186186
class IteratorGetNext:
187187
@classmethod
188188
def version_8(cls, ctx, node, **kwargs):
189-
output_names = node.output
189+
output_names = node.output.copy() # to make sure remove_node
190+
# does not alter the list
190191
type_0 = ctx.get_dtype(output_names[0])
191192
type_1 = ctx.get_dtype(output_names[1])
192193
shape_0 = ctx.get_shape(output_names[0])
@@ -200,7 +201,8 @@ def version_8(cls, ctx, node, **kwargs):
200201
class QueueDequeueManyV2:
201202
@classmethod
202203
def version_8(cls, ctx, node, **kwargs):
203-
outputs = node.output
204+
outputs = node.output.copy() # copy to make remove_node
205+
# does not alter the list
204206
shapes = node.output_shapes
205207
dtypes = node.output_dtypes
206208
ctx.remove_node(node.name)

tf2onnx/onnx_opset/nn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,7 @@ def version_6(cls, ctx, node, **kwargs):
666666
consumers = [ctx.find_output_consumers(output_name) for output_name in node.output[1:]]
667667
if not any(consumers):
668668
new_output = [node.output[0]]
669+
# the setter makes a copy of new_output
669670
node.output = new_output
670671

671672
conv_convert_inputs(ctx, node, with_kernel=False)

tf2onnx/optimizer/back_to_back_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def _optimize_conv_batchnorm_fusion(g, node, consumer_nodes):
237237
node2_shape = g.get_shape(node2.output[0])
238238
node2_dtype = g.get_dtype(node2.output[0])
239239
g.remove_node(node2.name)
240+
# the setter makes a copy
240241
node.output = node2_output
241242
g.set_shape(node2_output[0], node2_shape)
242243
g.set_dtype(node2_output[0], node2_dtype)

0 commit comments

Comments
 (0)