Skip to content

Commit 451149e

Browse files
merge constant with same value
1 parent aa19744 commit 451149e

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

tests/test_optimizers.py

Lines changed: 31 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,11 +8,11 @@
88
from __future__ import unicode_literals
99

1010
import numpy as np
11-
from onnx import helper, TensorProto
11+
from onnx import helper, TensorProto, OperatorSetIdProto
1212
from tf2onnx import utils
1313
from tf2onnx.graph import GraphUtil
1414
from backend_test_base import Tf2OnnxBackendTestBase
15-
from common import unittest_main, group_nodes_by_type
15+
from common import unittest_main, group_nodes_by_type, check_opset_min_version
1616

1717

1818
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -629,6 +629,35 @@ def test_duplicated_duplicated_attributes(self):
629629
self.run_merge_duplicated_nodes_compare(["OUT"], {"X": np.random.randn(5, 5).astype(np.float32)}, model_proto,
630630
op_type="ReduceSum", remaining_op_num=2)
631631

632+
@check_opset_min_version(9, "Constant")
633+
def test_duplicated_duplicated_constant(self):
634+
const_val = np.array([1, 2, 3], dtype=np.float32)
635+
tensor_1 = helper.make_tensor("tensor_1", TensorProto.FLOAT, const_val.shape, const_val)
636+
tensor_2 = helper.make_tensor("tensor_2", TensorProto.FLOAT, const_val.shape, const_val)
637+
tensor_3 = helper.make_tensor("tensor_3", TensorProto.FLOAT, const_val.shape, const_val.tobytes(), raw=True)
638+
tensor_4 = helper.make_tensor("tensor_4", TensorProto.FLOAT, const_val.shape, const_val.tobytes(), raw=True)
639+
node0 = helper.make_node('Constant', inputs=[], outputs=["value0"], value=tensor_1)
640+
node1 = helper.make_node('Constant', inputs=[], outputs=["value1"], value=tensor_2)
641+
node2 = helper.make_node('Constant', inputs=[], outputs=["value2"], value=tensor_3)
642+
node3 = helper.make_node('Constant', inputs=[], outputs=["value3"], value=tensor_4)
643+
node4 = helper.make_node("Mul", ["value0", "value1"], ["output1"])
644+
node5 = helper.make_node("Mul", ["value2", "output1"], ["output2"])
645+
node6 = helper.make_node("Mul", ["value3", "output2"], ["OUT"])
646+
647+
graph = helper.make_graph(
648+
[node0, node1, node2, node3, node4, node5, node6],
649+
"test_duplicated_duplicated_constant",
650+
[],
651+
[helper.make_tensor_value_info("OUT", TensorProto.FLOAT, (3,))],
652+
)
653+
654+
imp = OperatorSetIdProto()
655+
imp.version = self.config.opset
656+
657+
model_proto = helper.make_model(graph, producer_name="onnx-tests", opset_imports=[imp])
658+
self.run_merge_duplicated_nodes_compare(["OUT"], {}, model_proto,
659+
op_type="Constant", remaining_op_num=1)
660+
632661
def test_duplicated_node_is_graph_output(self):
633662
node0 = helper.make_node('Add', inputs=["X", "X"], outputs=["value0"])
634663
node1 = helper.make_node('Add', inputs=["X", "X"], outputs=["value1"])

tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,8 @@
99

1010
from collections import defaultdict, namedtuple
1111

12+
import numpy as np
13+
1214
from .optimizer_base import GraphOptimizerBase
1315

1416
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
@@ -59,14 +61,25 @@ def _del_nodes_if_duplicated(self, nodes_group, graph):
5961
unprocessed_node = []
6062
nodes_to_process = [nodes_group[0]]
6163
for node in nodes_group[1:]:
62-
if node.attr == nodes_to_process[0].attr:
64+
if self._have_equal_attr(node, nodes_to_process[0]):
6365
nodes_to_process.append(node)
6466
else:
6567
unprocessed_node.append(node)
6668

6769
self._merge_nodes_that_are_duplicated(nodes_to_process, graph)
6870
nodes_group = unprocessed_node
6971

72+
def _have_equal_attr(self, node_1, node_2):
73+
if node_1.attr == node_2.attr:
74+
return True
75+
if node_1.is_const() and node_2.is_const():
76+
const_1 = node_1.get_tensor_value(as_list=False)
77+
const_2 = node_2.get_tensor_value(as_list=False)
78+
if const_1.dtype == const_2.dtype and \
79+
np.array_equal(const_1, const_2):
80+
return True
81+
return False
82+
7083
def _merge_nodes_that_are_duplicated(self, nodes_to_process, graph):
7184
# node's output may not all be used, so have to select the one that uses most of node's outputs
7285
nodes_to_process.sort(key=self._len_of_node_output, reverse=True)

0 commit comments

Comments (0)