Skip to content

Commit 9960de5

Browse files
author
wayuanho
authored
Merge pull request #647 from lucienwang1009/revert_constant
Revert constant
2 parents fd16d28 + bf99f71 commit 9960de5

File tree

3 files changed

+80
-24
lines changed

3 files changed

+80
-24
lines changed

tests/test_optimizers.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from tf2onnx import utils
1313
from tf2onnx.graph import GraphUtil
1414
from backend_test_base import Tf2OnnxBackendTestBase
15-
from common import unittest_main, group_nodes_by_type, check_opset_min_version
15+
from common import unittest_main, group_nodes_by_type
1616

1717

1818
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -758,9 +758,12 @@ def test_identity_in_subgraph_non_graph_output(self):
758758
# Merge Duplicated Nodes Optimizer Tests Start
759759

760760
def run_merge_duplicated_nodes_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
761-
op_type=None, remaining_op_num=None, debug=False, rtol=1e-07):
762-
self.run_and_compare(output_names_with_port, onnx_feed_dict, origin_proto, op_type=op_type,
763-
remaining_op_num=remaining_op_num, debug=debug, rtol=rtol)
761+
op_type=None, remaining_op_num=None, debug=False, rtol=1e-07,
762+
graph_validator=None):
763+
new_proto = self.run_and_compare(output_names_with_port, onnx_feed_dict, origin_proto, op_type=op_type,
764+
remaining_op_num=remaining_op_num, debug=debug, rtol=rtol)
765+
if graph_validator:
766+
self.assertTrue(graph_validator(new_proto.graph))
764767

765768
def test_duplicated_duplicated_input(self):
766769
# same input or not
@@ -800,7 +803,10 @@ def test_duplicated_duplicated_attributes(self):
800803
self.run_merge_duplicated_nodes_compare(["OUT"], {"X": np.random.randn(5, 5).astype(np.float32)}, model_proto,
801804
op_type="ReduceSum", remaining_op_num=2)
802805

803-
@check_opset_min_version(9, "Constant")
806+
def _check_initializer_num(self, graph_proto, num):
807+
print(len(graph_proto.initializer))
808+
return num == len(graph_proto.initializer)
809+
804810
def test_duplicated_duplicated_constant(self):
805811
const_val = np.array([1, 2, 3], dtype=np.float32)
806812
tensor_1 = helper.make_tensor("tensor_1", TensorProto.FLOAT, const_val.shape, const_val)
@@ -826,8 +832,35 @@ def test_duplicated_duplicated_constant(self):
826832
imp.version = self.config.opset
827833

828834
model_proto = helper.make_model(graph, producer_name="onnx-tests", opset_imports=[imp])
829-
self.run_merge_duplicated_nodes_compare(["OUT"], {}, model_proto,
830-
op_type="Constant", remaining_op_num=1)
835+
self.run_merge_duplicated_nodes_compare(["OUT"], {}, model_proto, op_type="Constant", remaining_op_num=0,
836+
graph_validator=lambda g: self._check_initializer_num(g, 1))
837+
838+
def test_duplicated_duplicated_constant_and_initializer(self):
839+
const_val = np.array([1, 2, 3], dtype=np.float32)
840+
tensor_1 = helper.make_tensor("value0", TensorProto.FLOAT, const_val.shape, const_val)
841+
tensor_2 = helper.make_tensor("value1", TensorProto.FLOAT, const_val.shape, const_val)
842+
tensor_3 = helper.make_tensor("value2", TensorProto.FLOAT, const_val.shape, const_val.tobytes(), raw=True)
843+
tensor_4 = helper.make_tensor("value3", TensorProto.FLOAT, const_val.shape, const_val.tobytes(), raw=True)
844+
node0 = helper.make_node('Constant', inputs=[], outputs=["value0"], value=tensor_1)
845+
node1 = helper.make_node('Constant', inputs=[], outputs=["value1"], value=tensor_2)
846+
node4 = helper.make_node("Mul", ["value0", "value1"], ["output1"])
847+
node5 = helper.make_node("Mul", ["value2", "output1"], ["output2"])
848+
node6 = helper.make_node("Mul", ["value3", "output2"], ["OUT"])
849+
850+
graph = helper.make_graph(
851+
[node0, node1, node4, node5, node6],
852+
"test_duplicated_duplicated_constant",
853+
[helper.make_tensor_value_info("value2", TensorProto.FLOAT, (3,))],
854+
[helper.make_tensor_value_info("OUT", TensorProto.FLOAT, (3,))],
855+
[tensor_3, tensor_4]
856+
)
857+
858+
imp = OperatorSetIdProto()
859+
imp.version = self.config.opset
860+
861+
model_proto = helper.make_model(graph, producer_name="onnx-tests", opset_imports=[imp])
862+
self.run_merge_duplicated_nodes_compare(["OUT"], {}, model_proto, op_type="Constant", remaining_op_num=0,
863+
graph_validator=lambda g: self._check_initializer_num(g, 2))
831864

832865
def test_duplicated_node_is_graph_output(self):
833866
node0 = helper.make_node('Add', inputs=["X", "X"], outputs=["value0"])

tf2onnx/graph.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from onnx import helper, numpy_helper, shape_inference, OperatorSetIdProto, AttributeProto, TensorProto
1919
from tf2onnx import utils, __version__
20-
from tf2onnx.utils import port_name, find_opset
20+
from tf2onnx.utils import make_name, port_name, find_opset
2121
from tf2onnx import optimizer
2222
from tf2onnx.schemas import get_schema, infer_onnx_shape_dtype
2323
from tf2onnx import constants
@@ -141,11 +141,16 @@ def is_nhwc(self):
141141

142142
def is_const(self):
143143
"""Return True if node is a constant."""
144-
return self.type in ["Const", "ConstV2", "Constant"]
144+
return self.type in ["Const", "ConstV2"]
145145

146146
def is_graph_input(self):
147147
return self.type in ["Placeholder", "PlaceholderWithDefault", "PlaceholderV2"]
148148

149+
def is_graph_input_default_const(self):
150+
return self.is_const() and any(
151+
out.is_graph_input() for out in self.graph.find_output_consumers(self.output[0])
152+
)
153+
149154
def __str__(self):
150155
return str(self._op)
151156

@@ -711,6 +716,20 @@ def add_graph_input(self, name, dtype=None, shape=None):
711716
new_node = self.make_node("Placeholder", [], outputs=[name], dtypes=[dtype], shapes=[shape])
712717
self._order_sensitive_inputs.append(new_node)
713718

719+
def add_graph_input_with_default(self, name, default_const, dtype=None, shape=None):
720+
"""Add placeholderwithdefault."""
721+
if dtype is None:
722+
dtype = self.get_dtype(name)
723+
724+
if shape is None:
725+
shape = self.get_shape(name)
726+
727+
default_const_name = port_name(make_name("{}_default".format(name)))
728+
default_const.output = [default_const_name]
729+
new_node = self.make_node("PlaceholderWithDefault", [default_const_name], outputs=[name],
730+
dtypes=[dtype], shapes=[shape])
731+
self._order_sensitive_inputs.append(new_node)
732+
714733
def add_graph_output(self, name, dtype=None, shape=None):
715734
"""Add node output as graph's output."""
716735
utils.make_sure(name in self._output_to_node_name, "output %s not exist in the graph", name)
@@ -765,6 +784,8 @@ def set_shape(self, name, val):
765784
"""Set new shape of node."""
766785
if isinstance(val, np.ndarray):
767786
val = val.tolist()
787+
if isinstance(val, tuple):
788+
val = list(val)
768789
node = self.get_node_by_output(name, search_in_parent_graphs=True)
769790
utils.make_sure(node is not None, "cannot find node by output id %s", name)
770791
node.graph._output_shapes[name] = val
@@ -862,6 +883,7 @@ def make_graph(self, doc, graph_name="tf2onnx"):
862883
for op in self.get_nodes():
863884
if op.is_const():
864885
const_ops.append(op)
886+
continue
865887
elif op.is_graph_input():
866888
if op not in self._order_sensitive_inputs:
867889
order_non_sensitive_placeholders.append(op)
@@ -871,6 +893,7 @@ def make_graph(self, doc, graph_name="tf2onnx"):
871893

872894
# create initializers for placeholder with default nodes
873895
initializers = []
896+
placeholder_default_const_ops = []
874897
for op in placeholder_ops:
875898
if op.type == "PlaceholderWithDefault":
876899
utils.make_sure(op.inputs[0] is not None, "Cannot find node with output {}".format(op.input[0]))
@@ -880,24 +903,18 @@ def make_graph(self, doc, graph_name="tf2onnx"):
880903
value = op.inputs[0].get_tensor_value(as_list=False)
881904
tensor = numpy_helper.from_array(value, op.output[0])
882905
initializers.append(tensor)
883-
const_ops.remove(op.inputs[0])
884-
ops.remove(op.inputs[0])
906+
placeholder_default_const_ops.append(op.inputs[0])
885907

886908
# create initializers for constant nodes
909+
const_ops = [op for op in const_ops if op not in placeholder_default_const_ops]
887910
for op in const_ops:
888-
# Constant support more dtypes after opset 9
889-
if self.opset < 9:
890-
# not to use numpy_helper.from_array to create a new tensor
891-
# because sometimes onnx will have a bug that only check the tensor data in specific field
892-
# such as at upsample it only checks the float_data field.
893-
t = op.get_attr("value")
894-
tensor = helper.get_attribute_value(t)
895-
tensor.name = op.output[0]
896-
initializers.append(tensor)
897-
ops.remove(op)
898-
else:
899-
op.type = "Constant"
900-
op.update_proto()
911+
# not to use numpy_helper.from_array to create a new tensor
912+
# because sometimes onnx will have a bug that only check the tensor data in specific field
913+
# such as at upsample it only checks the float_data field.
914+
t = op.get_attr("value")
915+
tensor = helper.get_attribute_value(t)
916+
tensor.name = op.output[0]
917+
initializers.append(tensor)
901918

902919
# create input_tensor_values
903920
input_ids = [op.output[0] for op in placeholder_ops]
@@ -1362,3 +1379,5 @@ def _parse_graph_input(g, graph_proto, const_node_names):
13621379
dtype = dtypes[name]
13631380
if name not in const_node_names:
13641381
g.add_graph_input(name, dtype, shape)
1382+
else:
1383+
g.add_graph_input_with_default(name, g.get_node_by_name(name), dtype, shape)

tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ def _merge_duplicated_nodes(self, graph):
5151
def _group_nodes_by_type_inputs(graph):
5252
res = defaultdict(list)
5353
for node in graph.get_nodes():
54+
# default const of graph input cannot be merged
55+
if node.is_graph_input_default_const():
56+
continue
5457
res[_KeyToGroupNodes(node.type, tuple(node.input))].append(node)
5558
return res
5659

@@ -72,6 +75,7 @@ def _del_nodes_if_duplicated(self, nodes_group, graph):
7275
def _have_equal_attr(self, node_1, node_2, graph):
7376
if node_1.attr == node_2.attr:
7477
return True
78+
# above check guarantees consts here are able to be merged
7579
if node_1.is_const() and node_2.is_const():
7680
# get_tensor_value is costly so that we check their shape first
7781
shape_1 = graph.get_shape(node_1.output[0])

0 commit comments

Comments
 (0)