Skip to content

Commit 3d70c12

Browse files
Hash tensor values in merge_duplicated_nodes to increase conversion speed. Decreases total time from 16 min to 6 min for bit/m-r152x4. For small models, time is unchanged.
1 parent 07c51be commit 3d70c12

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,14 @@
77
then b and c can be merged into one node to avoid duplicated computation
88
"""
99

10-
from collections import defaultdict, namedtuple
10+
from collections import defaultdict
1111

1212
import numpy as np
1313

1414
from .optimizer_base import GraphOptimizerBase
1515

1616
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
1717

18-
_KeyToGroupNodes = namedtuple("key", "type input")
19-
2018

2119
class MergeDuplicatedNodesOptimizer(GraphOptimizerBase):
2220
"""Remove duplicate nodes.
@@ -41,6 +39,7 @@ def _optimize_at_current_graph_level(self, graph):
4139
def _merge_duplicated_nodes(self, graph):
4240
# "duplicated" means: op_type, input and attribute are same
4341
# while attr is un-hashable so doesn't include it when grouping nodes
42+
# we do hash the tensor data of const values
4443
nodes_groups = self._group_nodes_by_type_inputs(graph)
4544
for _, nodes_group in nodes_groups.items():
4645
if self._skip_node_type(nodes_group[0]):
@@ -54,7 +53,11 @@ def _group_nodes_by_type_inputs(graph):
5453
# default const of graph input cannot be merged
5554
if node.is_graph_input_default_const():
5655
continue
57-
res[_KeyToGroupNodes(node.type, tuple(node.input))].append(node)
56+
tensor_data_hash = None
57+
if node.is_const():
58+
# Many constants have the same size so this is helpful
59+
tensor_data_hash = hash(node.attr['value'].t.raw_data)
60+
res[(node.type, tuple(node.input), tensor_data_hash)].append(node)
5861
return res
5962

6063
def _del_nodes_if_duplicated(self, nodes_group, graph):
@@ -75,7 +78,7 @@ def _del_nodes_if_duplicated(self, nodes_group, graph):
7578
def _have_equal_attr(self, node_1, node_2, graph):
7679
if node_1.attr == node_2.attr:
7780
return True
78-
# above check guarantees consts here are able to be merged
81+
# consts have a name attr that can differ among equal consts so they must be handled separately
7982
if node_1.is_const() and node_2.is_const():
8083
# get_tensor_value is costly so that we check their shape first
8184
shape_1 = graph.get_shape(node_1.output[0])

0 commit comments

Comments
 (0)