@@ -141,7 +141,7 @@ def is_nhwc(self):
141
141
142
142
def is_const (self ):
143
143
"""Return True if node is a constant."""
144
- return self .type in ["Const" , "ConstV2" ]
144
+ return self .type in ["Const" , "ConstV2" , "Constant" ]
145
145
146
146
def is_graph_input (self ):
147
147
return self .type in ["Placeholder" , "PlaceholderWithDefault" , "PlaceholderV2" ]
@@ -844,7 +844,6 @@ def make_graph(self, doc, graph_name="tf2onnx"):
844
844
for op in self .get_nodes ():
845
845
if op .is_const ():
846
846
const_ops .append (op )
847
- continue
848
847
elif op .is_graph_input ():
849
848
if op not in self ._order_sensitive_inputs :
850
849
order_non_sensitive_placeholders .append (op )
@@ -854,7 +853,6 @@ def make_graph(self, doc, graph_name="tf2onnx"):
854
853
855
854
# create initializers for placeholder with default nodes
856
855
initializers = []
857
- placeholder_default_const_ops = []
858
856
for op in placeholder_ops :
859
857
if op .type == "PlaceholderWithDefault" :
860
858
utils .make_sure (op .inputs [0 ] is not None , "Cannot find node with output {}" .format (op .input [0 ]))
@@ -864,18 +862,24 @@ def make_graph(self, doc, graph_name="tf2onnx"):
864
862
value = op .inputs [0 ].get_tensor_value (as_list = False )
865
863
tensor = numpy_helper .from_array (value , op .output [0 ])
866
864
initializers .append (tensor )
867
- placeholder_default_const_ops .append (op .inputs [0 ])
865
+ const_ops .remove (op .inputs [0 ])
866
+ ops .remove (op .inputs [0 ])
868
867
869
868
# create initializers for constant nodes
870
- const_ops = [op for op in const_ops if op not in placeholder_default_const_ops ]
871
869
for op in const_ops :
872
- # not to use numpy_helper.from_array to create a new tensor
873
- # because sometimes onnx will have a bug that only check the tensor data in specific field
874
- # such as at upsample it only checks the float_data field.
875
- t = op .get_attr ("value" )
876
- tensor = helper .get_attribute_value (t )
877
- tensor .name = op .output [0 ]
878
- initializers .append (tensor )
870
+ # Constant support more dtypes after opset 9
871
+ if self .opset < 9 :
872
+ # not to use numpy_helper.from_array to create a new tensor
873
+ # because sometimes onnx will have a bug that only check the tensor data in specific field
874
+ # such as at upsample it only checks the float_data field.
875
+ t = op .get_attr ("value" )
876
+ tensor = helper .get_attribute_value (t )
877
+ tensor .name = op .output [0 ]
878
+ initializers .append (tensor )
879
+ ops .remove (op )
880
+ else :
881
+ op .type = "Constant"
882
+ op .update_proto ()
879
883
880
884
# create input_tensor_values
881
885
input_ids = [op .output [0 ] for op in placeholder_ops ]
0 commit comments