Skip to content

Commit bf99f71

Browse files
author
wayuanho
committed
merge const
1 parent 1f33f1f commit bf99f71

File tree

3 files changed

+68
-8
lines changed

3 files changed

+68
-8
lines changed

tests/test_optimizers.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from tf2onnx import utils
1313
from tf2onnx.graph import GraphUtil
1414
from backend_test_base import Tf2OnnxBackendTestBase
15-
from common import unittest_main, group_nodes_by_type, check_opset_min_version
15+
from common import unittest_main, group_nodes_by_type
1616

1717

1818
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -758,9 +758,12 @@ def test_identity_in_subgraph_non_graph_output(self):
758758
# Merge Duplicated Nodes Optimizer Tests Start
759759

760760
def run_merge_duplicated_nodes_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
761-
op_type=None, remaining_op_num=None, debug=False, rtol=1e-07):
762-
self.run_and_compare(output_names_with_port, onnx_feed_dict, origin_proto, op_type=op_type,
763-
remaining_op_num=remaining_op_num, debug=debug, rtol=rtol)
761+
op_type=None, remaining_op_num=None, debug=False, rtol=1e-07,
762+
graph_validator=None):
763+
new_proto = self.run_and_compare(output_names_with_port, onnx_feed_dict, origin_proto, op_type=op_type,
764+
remaining_op_num=remaining_op_num, debug=debug, rtol=rtol)
765+
if graph_validator:
766+
self.assertTrue(graph_validator(new_proto.graph))
764767

765768
def test_duplicated_duplicated_input(self):
766769
# same input or not
@@ -800,7 +803,10 @@ def test_duplicated_duplicated_attributes(self):
800803
self.run_merge_duplicated_nodes_compare(["OUT"], {"X": np.random.randn(5, 5).astype(np.float32)}, model_proto,
801804
op_type="ReduceSum", remaining_op_num=2)
802805

803-
@check_opset_min_version(9, "Constant")
806+
def _check_initializer_num(self, graph_proto, num):
807+
print(len(graph_proto.initializer))
808+
return num == len(graph_proto.initializer)
809+
804810
def test_duplicated_duplicated_constant(self):
805811
const_val = np.array([1, 2, 3], dtype=np.float32)
806812
tensor_1 = helper.make_tensor("tensor_1", TensorProto.FLOAT, const_val.shape, const_val)
@@ -826,8 +832,35 @@ def test_duplicated_duplicated_constant(self):
826832
imp.version = self.config.opset
827833

828834
model_proto = helper.make_model(graph, producer_name="onnx-tests", opset_imports=[imp])
829-
self.run_merge_duplicated_nodes_compare(["OUT"], {}, model_proto,
830-
op_type="Constant", remaining_op_num=1)
835+
self.run_merge_duplicated_nodes_compare(["OUT"], {}, model_proto, op_type="Constant", remaining_op_num=0,
836+
graph_validator=lambda g: self._check_initializer_num(g, 1))
837+
838+
def test_duplicated_duplicated_constant_and_initializer(self):
839+
const_val = np.array([1, 2, 3], dtype=np.float32)
840+
tensor_1 = helper.make_tensor("value0", TensorProto.FLOAT, const_val.shape, const_val)
841+
tensor_2 = helper.make_tensor("value1", TensorProto.FLOAT, const_val.shape, const_val)
842+
tensor_3 = helper.make_tensor("value2", TensorProto.FLOAT, const_val.shape, const_val.tobytes(), raw=True)
843+
tensor_4 = helper.make_tensor("value3", TensorProto.FLOAT, const_val.shape, const_val.tobytes(), raw=True)
844+
node0 = helper.make_node('Constant', inputs=[], outputs=["value0"], value=tensor_1)
845+
node1 = helper.make_node('Constant', inputs=[], outputs=["value1"], value=tensor_2)
846+
node4 = helper.make_node("Mul", ["value0", "value1"], ["output1"])
847+
node5 = helper.make_node("Mul", ["value2", "output1"], ["output2"])
848+
node6 = helper.make_node("Mul", ["value3", "output2"], ["OUT"])
849+
850+
graph = helper.make_graph(
851+
[node0, node1, node4, node5, node6],
852+
"test_duplicated_duplicated_constant",
853+
[helper.make_tensor_value_info("value2", TensorProto.FLOAT, (3,))],
854+
[helper.make_tensor_value_info("OUT", TensorProto.FLOAT, (3,))],
855+
[tensor_3, tensor_4]
856+
)
857+
858+
imp = OperatorSetIdProto()
859+
imp.version = self.config.opset
860+
861+
model_proto = helper.make_model(graph, producer_name="onnx-tests", opset_imports=[imp])
862+
self.run_merge_duplicated_nodes_compare(["OUT"], {}, model_proto, op_type="Constant", remaining_op_num=0,
863+
graph_validator=lambda g: self._check_initializer_num(g, 2))
831864

832865
def test_duplicated_node_is_graph_output(self):
833866
node0 = helper.make_node('Add', inputs=["X", "X"], outputs=["value0"])

tf2onnx/graph.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from onnx import helper, numpy_helper, shape_inference, OperatorSetIdProto, AttributeProto, TensorProto
1919
from tf2onnx import utils, __version__
20-
from tf2onnx.utils import port_name, find_opset
20+
from tf2onnx.utils import make_name, port_name, find_opset
2121
from tf2onnx import optimizer
2222
from tf2onnx.schemas import get_schema, infer_onnx_shape_dtype
2323
from tf2onnx import constants
@@ -146,6 +146,11 @@ def is_const(self):
146146
def is_graph_input(self):
147147
return self.type in ["Placeholder", "PlaceholderWithDefault", "PlaceholderV2"]
148148

149+
def is_graph_input_default_const(self):
150+
return self.is_const() and any(
151+
out.is_graph_input() for out in self.graph.find_output_consumers(self.output[0])
152+
)
153+
149154
def __str__(self):
150155
return str(self._op)
151156

@@ -711,6 +716,20 @@ def add_graph_input(self, name, dtype=None, shape=None):
711716
new_node = self.make_node("Placeholder", [], outputs=[name], dtypes=[dtype], shapes=[shape])
712717
self._order_sensitive_inputs.append(new_node)
713718

719+
def add_graph_input_with_default(self, name, default_const, dtype=None, shape=None):
720+
"""Add placeholderwithdefault."""
721+
if dtype is None:
722+
dtype = self.get_dtype(name)
723+
724+
if shape is None:
725+
shape = self.get_shape(name)
726+
727+
default_const_name = port_name(make_name("{}_default".format(name)))
728+
default_const.output = [default_const_name]
729+
new_node = self.make_node("PlaceholderWithDefault", [default_const_name], outputs=[name],
730+
dtypes=[dtype], shapes=[shape])
731+
self._order_sensitive_inputs.append(new_node)
732+
714733
def add_graph_output(self, name, dtype=None, shape=None):
715734
"""Add node output as graph's output."""
716735
utils.make_sure(name in self._output_to_node_name, "output %s not exist in the graph", name)
@@ -765,6 +784,8 @@ def set_shape(self, name, val):
765784
"""Set new shape of node."""
766785
if isinstance(val, np.ndarray):
767786
val = val.tolist()
787+
if isinstance(val, tuple):
788+
val = list(val)
768789
node = self.get_node_by_output(name, search_in_parent_graphs=True)
769790
utils.make_sure(node is not None, "cannot find node by output id %s", name)
770791
node.graph._output_shapes[name] = val
@@ -1358,3 +1379,5 @@ def _parse_graph_input(g, graph_proto, const_node_names):
13581379
dtype = dtypes[name]
13591380
if name not in const_node_names:
13601381
g.add_graph_input(name, dtype, shape)
1382+
else:
1383+
g.add_graph_input_with_default(name, g.get_node_by_name(name), dtype, shape)

tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ def _merge_duplicated_nodes(self, graph):
5151
def _group_nodes_by_type_inputs(graph):
5252
res = defaultdict(list)
5353
for node in graph.get_nodes():
54+
# default const of graph input cannot be merged
55+
if node.is_graph_input_default_const():
56+
continue
5457
res[_KeyToGroupNodes(node.type, tuple(node.input))].append(node)
5558
return res
5659

@@ -72,6 +75,7 @@ def _del_nodes_if_duplicated(self, nodes_group, graph):
7275
def _have_equal_attr(self, node_1, node_2, graph):
7376
if node_1.attr == node_2.attr:
7477
return True
78+
# above check guarantees consts here are able to be merged
7579
if node_1.is_const() and node_2.is_const():
7680
# get_tensor_value is costly so that we check their shape first
7781
shape_1 = graph.get_shape(node_1.output[0])

0 commit comments

Comments
 (0)