Skip to content

Commit 11b2b8f

Browse files
authored
Merge pull request #639 from lucienwang1009/refine_merge_const
refine merge const
2 parents 831d42e + e0bc75d commit 11b2b8f

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,18 +61,24 @@ def _del_nodes_if_duplicated(self, nodes_group, graph):
6161
unprocessed_node = []
6262
nodes_to_process = [nodes_group[0]]
6363
for node in nodes_group[1:]:
64-
if self._have_equal_attr(node, nodes_to_process[0]):
64+
if self._have_equal_attr(node, nodes_to_process[0], graph):
6565
nodes_to_process.append(node)
6666
else:
6767
unprocessed_node.append(node)
6868

6969
self._merge_nodes_that_are_duplicated(nodes_to_process, graph)
7070
nodes_group = unprocessed_node
7171

72-
def _have_equal_attr(self, node_1, node_2):
72+
def _have_equal_attr(self, node_1, node_2, graph):
7373
if node_1.attr == node_2.attr:
7474
return True
7575
if node_1.is_const() and node_2.is_const():
76+
# get_tensor_value is costly so that we check their shape first
77+
shape_1 = graph.get_shape(node_1.output[0])
78+
shape_2 = graph.get_shape(node_2.output[0])
79+
if shape_1 is not None and shape_2 is not None and \
80+
shape_1 != shape_2:
81+
return False
7682
const_1 = node_1.get_tensor_value(as_list=False)
7783
const_2 = node_2.get_tensor_value(as_list=False)
7884
if const_1.dtype == const_2.dtype and \

0 commit comments

Comments
 (0)