Skip to content

Commit 1f33f1f

Browse files
author
wayuanho
committed
Revert "map Const to Constant instead of initializer"
This reverts commit 01a4343.
1 parent 1da54e1 commit 1f33f1f

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

tf2onnx/graph.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def is_nhwc(self):
141141

142142
def is_const(self):
143143
"""Return True if node is a constant."""
144-
return self.type in ["Const", "ConstV2", "Constant"]
144+
return self.type in ["Const", "ConstV2"]
145145

146146
def is_graph_input(self):
147147
return self.type in ["Placeholder", "PlaceholderWithDefault", "PlaceholderV2"]
@@ -862,6 +862,7 @@ def make_graph(self, doc, graph_name="tf2onnx"):
862862
for op in self.get_nodes():
863863
if op.is_const():
864864
const_ops.append(op)
865+
continue
865866
elif op.is_graph_input():
866867
if op not in self._order_sensitive_inputs:
867868
order_non_sensitive_placeholders.append(op)
@@ -871,6 +872,7 @@ def make_graph(self, doc, graph_name="tf2onnx"):
871872

872873
# create initializers for placeholder with default nodes
873874
initializers = []
875+
placeholder_default_const_ops = []
874876
for op in placeholder_ops:
875877
if op.type == "PlaceholderWithDefault":
876878
utils.make_sure(op.inputs[0] is not None, "Cannot find node with output {}".format(op.input[0]))
@@ -880,24 +882,18 @@ def make_graph(self, doc, graph_name="tf2onnx"):
880882
value = op.inputs[0].get_tensor_value(as_list=False)
881883
tensor = numpy_helper.from_array(value, op.output[0])
882884
initializers.append(tensor)
883-
const_ops.remove(op.inputs[0])
884-
ops.remove(op.inputs[0])
885+
placeholder_default_const_ops.append(op.inputs[0])
885886

886887
# create initializers for constant nodes
888+
const_ops = [op for op in const_ops if op not in placeholder_default_const_ops]
887889
for op in const_ops:
888-
# Constant support more dtypes after opset 9
889-
if self.opset < 9:
890-
# not to use numpy_helper.from_array to create a new tensor
891-
# because sometimes onnx will have a bug that only check the tensor data in specific field
892-
# such as at upsample it only checks the float_data field.
893-
t = op.get_attr("value")
894-
tensor = helper.get_attribute_value(t)
895-
tensor.name = op.output[0]
896-
initializers.append(tensor)
897-
ops.remove(op)
898-
else:
899-
op.type = "Constant"
900-
op.update_proto()
890+
# not to use numpy_helper.from_array to create a new tensor
891+
# because sometimes onnx will have a bug that only check the tensor data in specific field
892+
# such as at upsample it only checks the float_data field.
893+
t = op.get_attr("value")
894+
tensor = helper.get_attribute_value(t)
895+
tensor.name = op.output[0]
896+
initializers.append(tensor)
901897

902898
# create input_tensor_values
903899
input_ids = [op.output[0] for op in placeholder_ops]

0 commit comments

Comments
 (0)