Skip to content

Commit f5f8b2d

Browse files
Merge pull request #1102 from onnx/tom/ImproveOptDupNodePerf
Hash tensor tensor values in merge_duplicated_nodes to increase conversion speed
2 parents 07c51be + ec9d775 commit f5f8b2d

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

tests/test_optimizers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -918,8 +918,8 @@ def test_duplicated_duplicated_constant(self):
918918
const_val = np.array([1, 2, 3], dtype=np.float32)
919919
tensor_1 = helper.make_tensor("tensor_1", TensorProto.FLOAT, const_val.shape, const_val)
920920
tensor_2 = helper.make_tensor("tensor_2", TensorProto.FLOAT, const_val.shape, const_val)
921-
tensor_3 = helper.make_tensor("tensor_3", TensorProto.FLOAT, const_val.shape, const_val.tobytes(), raw=True)
922-
tensor_4 = helper.make_tensor("tensor_4", TensorProto.FLOAT, const_val.shape, const_val.tobytes(), raw=True)
921+
tensor_3 = helper.make_tensor("tensor_3", TensorProto.FLOAT, const_val.shape, const_val)
922+
tensor_4 = helper.make_tensor("tensor_4", TensorProto.FLOAT, const_val.shape, const_val)
923923
node0 = helper.make_node('Constant', inputs=[], outputs=["value0"], value=tensor_1)
924924
node1 = helper.make_node('Constant', inputs=[], outputs=["value1"], value=tensor_2)
925925
node2 = helper.make_node('Constant', inputs=[], outputs=["value2"], value=tensor_3)
@@ -941,8 +941,8 @@ def test_duplicated_duplicated_constant(self):
941941

942942
def test_duplicated_duplicated_constant_and_initializer(self):
943943
const_val = np.array([1, 2, 3], dtype=np.float32)
944-
tensor_1 = helper.make_tensor("value0", TensorProto.FLOAT, const_val.shape, const_val)
945-
tensor_2 = helper.make_tensor("value1", TensorProto.FLOAT, const_val.shape, const_val)
944+
tensor_1 = helper.make_tensor("value0", TensorProto.FLOAT, const_val.shape, const_val.tobytes(), raw=True)
945+
tensor_2 = helper.make_tensor("value1", TensorProto.FLOAT, const_val.shape, const_val.tobytes(), raw=True)
946946
tensor_3 = helper.make_tensor("value2", TensorProto.FLOAT, const_val.shape, const_val.tobytes(), raw=True)
947947
tensor_4 = helper.make_tensor("value3", TensorProto.FLOAT, const_val.shape, const_val.tobytes(), raw=True)
948948
node0 = helper.make_node('Constant', inputs=[], outputs=["value0"], value=tensor_1)

tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,14 @@
77
then b and c can be merged into one node to avoid duplicated computation
88
"""
99

10-
from collections import defaultdict, namedtuple
10+
from collections import defaultdict
1111

1212
import numpy as np
1313

1414
from .optimizer_base import GraphOptimizerBase
1515

1616
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
1717

18-
_KeyToGroupNodes = namedtuple("key", "type input")
19-
2018

2119
class MergeDuplicatedNodesOptimizer(GraphOptimizerBase):
2220
"""Remove duplicate nodes.
@@ -41,6 +39,7 @@ def _optimize_at_current_graph_level(self, graph):
4139
def _merge_duplicated_nodes(self, graph):
4240
# "duplicated" means: op_type, input and attribute are same
4341
# while attr is un-hashable so doesn't include it when grouping nodes
42+
# we do hash the tensor data of const values
4443
nodes_groups = self._group_nodes_by_type_inputs(graph)
4544
for _, nodes_group in nodes_groups.items():
4645
if self._skip_node_type(nodes_group[0]):
@@ -54,7 +53,11 @@ def _group_nodes_by_type_inputs(graph):
5453
# default const of graph input cannot be merged
5554
if node.is_graph_input_default_const():
5655
continue
57-
res[_KeyToGroupNodes(node.type, tuple(node.input))].append(node)
56+
tensor_data_hash = None
57+
if node.is_const():
58+
# Many constants have the same size so this is helpful
59+
tensor_data_hash = hash(node.attr['value'].t.raw_data)
60+
res[(node.type, tuple(node.input), tensor_data_hash)].append(node)
5861
return res
5962

6063
def _del_nodes_if_duplicated(self, nodes_group, graph):
@@ -75,7 +78,7 @@ def _del_nodes_if_duplicated(self, nodes_group, graph):
7578
def _have_equal_attr(self, node_1, node_2, graph):
7679
if node_1.attr == node_2.attr:
7780
return True
78-
# above check guarantees consts here are able to be merged
81+
# consts have a name attr that can differ among equal consts so they must be handled separately
7982
if node_1.is_const() and node_2.is_const():
8083
# get_tensor_value is costly so that we check their shape first
8184
shape_1 = graph.get_shape(node_1.output[0])

0 commit comments

Comments
 (0)