7
7
then b and c can be merged into one node to avoid duplicated computation
8
8
"""
9
9
10
- from collections import defaultdict , namedtuple
10
+ from collections import defaultdict
11
11
12
12
import numpy as np
13
13
14
14
from .optimizer_base import GraphOptimizerBase
15
15
16
16
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
17
17
18
- _KeyToGroupNodes = namedtuple ("key" , "type input" )
19
-
20
18
21
19
class MergeDuplicatedNodesOptimizer (GraphOptimizerBase ):
22
20
"""Remove duplicate nodes.
@@ -41,6 +39,7 @@ def _optimize_at_current_graph_level(self, graph):
41
39
def _merge_duplicated_nodes (self , graph ):
42
40
# "duplicated" means: op_type, inputs and attributes are the same
43
41
# attrs are un-hashable, so they are not included in the key when grouping nodes
42
+ # we do hash the tensor data of const values
44
43
nodes_groups = self ._group_nodes_by_type_inputs (graph )
45
44
for _ , nodes_group in nodes_groups .items ():
46
45
if self ._skip_node_type (nodes_group [0 ]):
@@ -54,7 +53,11 @@ def _group_nodes_by_type_inputs(graph):
54
53
# default const of graph input cannot be merged
55
54
if node .is_graph_input_default_const ():
56
55
continue
57
- res [_KeyToGroupNodes (node .type , tuple (node .input ))].append (node )
56
+ tensor_data_hash = None
57
+ if node .is_const ():
58
+ # Many constants have the same size, so hashing the raw tensor data helps keep unequal consts in separate groups
59
+ tensor_data_hash = hash (node .attr ['value' ].t .raw_data )
60
+ res [(node .type , tuple (node .input ), tensor_data_hash )].append (node )
58
61
return res
59
62
60
63
def _del_nodes_if_duplicated (self , nodes_group , graph ):
@@ -75,7 +78,7 @@ def _del_nodes_if_duplicated(self, nodes_group, graph):
75
78
def _have_equal_attr (self , node_1 , node_2 , graph ):
76
79
if node_1 .attr == node_2 .attr :
77
80
return True
78
- # above check guarantees consts here are able to be merged
81
+ # consts have a name attr that can differ among equal consts so they must be handled separately
79
82
if node_1 .is_const () and node_2 .is_const ():
80
83
# get_tensor_value is costly, so we check their shapes first
81
84
shape_1 = graph .get_shape (node_1 .output [0 ])
0 commit comments