Skip to content

Commit 86ad416

Browse files
authored
Merge pull request #616 from lucienwang1009/constant
map Const to Constant instead of initializer for opset>=9
2 parents 41fad17 + 01a4343 commit 86ad416

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

tf2onnx/graph.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def is_nhwc(self):
141141

142142
def is_const(self):
143143
"""Return True if node is a constant."""
144-
return self.type in ["Const", "ConstV2"]
144+
return self.type in ["Const", "ConstV2", "Constant"]
145145

146146
def is_graph_input(self):
147147
return self.type in ["Placeholder", "PlaceholderWithDefault", "PlaceholderV2"]
@@ -844,7 +844,6 @@ def make_graph(self, doc, graph_name="tf2onnx"):
844844
for op in self.get_nodes():
845845
if op.is_const():
846846
const_ops.append(op)
847-
continue
848847
elif op.is_graph_input():
849848
if op not in self._order_sensitive_inputs:
850849
order_non_sensitive_placeholders.append(op)
@@ -854,7 +853,6 @@ def make_graph(self, doc, graph_name="tf2onnx"):
854853

855854
# create initializers for placeholder with default nodes
856855
initializers = []
857-
placeholder_default_const_ops = []
858856
for op in placeholder_ops:
859857
if op.type == "PlaceholderWithDefault":
860858
utils.make_sure(op.inputs[0] is not None, "Cannot find node with output {}".format(op.input[0]))
@@ -864,18 +862,24 @@ def make_graph(self, doc, graph_name="tf2onnx"):
864862
value = op.inputs[0].get_tensor_value(as_list=False)
865863
tensor = numpy_helper.from_array(value, op.output[0])
866864
initializers.append(tensor)
867-
placeholder_default_const_ops.append(op.inputs[0])
865+
const_ops.remove(op.inputs[0])
866+
ops.remove(op.inputs[0])
868867

869868
# create initializers for constant nodes
870-
const_ops = [op for op in const_ops if op not in placeholder_default_const_ops]
871869
for op in const_ops:
872-
# not to use numpy_helper.from_array to create a new tensor
873-
# because sometimes onnx will have a bug that only check the tensor data in specific field
874-
# such as at upsample it only checks the float_data field.
875-
t = op.get_attr("value")
876-
tensor = helper.get_attribute_value(t)
877-
tensor.name = op.output[0]
878-
initializers.append(tensor)
870+
# Constant support more dtypes after opset 9
871+
if self.opset < 9:
872+
# not to use numpy_helper.from_array to create a new tensor
873+
# because sometimes onnx will have a bug that only check the tensor data in specific field
874+
# such as at upsample it only checks the float_data field.
875+
t = op.get_attr("value")
876+
tensor = helper.get_attribute_value(t)
877+
tensor.name = op.output[0]
878+
initializers.append(tensor)
879+
ops.remove(op)
880+
else:
881+
op.type = "Constant"
882+
op.update_proto()
879883

880884
# create input_tensor_values
881885
input_ids = [op.output[0] for op in placeholder_ops]

0 commit comments

Comments
 (0)