@@ -147,14 +147,6 @@ def _convert_shapenode_to_int64(ctx, node, input_number):
147
147
"""cast int32 shape into int64 shape."""
148
148
shape_node = node .inputs [input_number ]
149
149
name = node .input [input_number ]
150
- if shape_node .is_const ():
151
- # if it is a const, change the const to be int64
152
- shape = shape_node .get_tensor_value ()
153
- shape = np .array (list (shape ), dtype = np .int64 )
154
- shape_node .set_tensor_value (shape )
155
- ctx .set_dtype (shape_node .output [0 ], onnx_pb .TensorProto .INT64 )
156
- ctx .copy_shape (name , shape_node .output [0 ])
157
- return [node ]
158
150
159
151
cast_node = ctx .insert_new_node_on_input (node , "Cast" , name )
160
152
cast_node .set_attr ("to" , onnx_pb .TensorProto .INT64 )
@@ -902,6 +894,7 @@ def pad_op(ctx, node, name, args):
902
894
# or PadV2(T input, int32 paddings, T constant_value, @type Tpaddings), CONST mode - default value specified
903
895
# or MirrorPad(T input, int32 paddings, @type Tpaddings, @STRING mode), other mode.
904
896
# T output = Pad(T data, @STRING mode, @INTS pads, @FLOAT value)
897
+ nodes = [node ]
905
898
paddings = np .array (node .inputs [1 ].get_tensor_value ()).transpose ().flatten ()
906
899
mode = node .get_attr ("mode" )
907
900
if mode :
@@ -917,7 +910,24 @@ def pad_op(ctx, node, name, args):
917
910
918
911
ctx .remove_input (node , node .input [1 ])
919
912
node .set_attr ("pads" , paddings )
920
- return node
913
+
914
+ origin_dtype = ctx .get_dtype (node .output [0 ])
915
+ if origin_dtype not in [onnx_pb .TensorProto .FLOAT16 , onnx_pb .TensorProto .FLOAT ,
916
+ onnx_pb .TensorProto .DOUBLE ]:
917
+ cast_node = ctx .insert_new_node_on_input (node , "Cast" , node .input [0 ])
918
+ cast_node .set_attr ("to" , onnx_pb .TensorProto .FLOAT )
919
+ ctx .set_dtype (cast_node .output [0 ], onnx_pb .TensorProto .FLOAT )
920
+ ctx .copy_shape (name , cast_node .output [0 ])
921
+ nodes .append (cast_node )
922
+
923
+ cast_back_node = ctx .insert_new_node_on_output ("Cast" , node .output [0 ],
924
+ name = utils .make_name (node .name ) + "_castback" )
925
+ cast_back_node .set_attr ("to" , origin_dtype )
926
+ ctx .set_dtype (cast_back_node .output [0 ], origin_dtype )
927
+ ctx .copy_shape (name , cast_back_node .output [0 ])
928
+ nodes .append (cast_back_node )
929
+
930
+ return nodes
921
931
922
932
923
933
def rsqrt_op (ctx , node , name , args ):
@@ -1222,11 +1232,6 @@ def minmax_op(ctx, node, name, args):
1222
1232
1223
1233
1224
1234
def pack_op (ctx , node , name , args ):
1225
- # in tf, "pack" can accept one input tensor which means doing nothing,
1226
- # so remove the node in ONNX
1227
- if len (node .inputs ) == 1 :
1228
- ctx .replace_all_inputs (ctx .get_nodes (), node .output [0 ], node .input [0 ])
1229
- return None
1230
1235
1231
1236
# hack to make up for the missing onnx pack op
1232
1237
axis = node .get_attr ("axis" ).i
@@ -1650,13 +1655,15 @@ def reduce_logic_op(ctx, node, name, args):
1650
1655
1651
1656
utils .make_sure (all (i >= 0 for i in reduce_dim ), "negative reduce axis is not supported in onnx for now" )
1652
1657
1653
- cast = ctx .make_node (op_type = "Cast" , inputs = [node .input [0 ]], attr = {"to" : onnx_pb .TensorProto .INT32 })
1658
+ cast = ctx .make_node (op_type = "Cast" , inputs = [node .input [0 ]], attr = {"to" : onnx_pb .TensorProto .FLOAT })
1654
1659
keepdims = helper .get_attribute_value (node .get_attr ("keep_dims" ))
1655
1660
op_type = "ReduceMin" if node .type == "All" else "ReduceSum"
1656
1661
reduce_node = ctx .make_node (op_type = op_type , inputs = cast .output , attr = {"axes" : reduce_dim , "keepdims" : keepdims })
1657
- res = ctx .make_node (op_type = "Cast" , inputs = reduce_node .output , attr = {"to" : onnx_pb .TensorProto .BOOL },
1662
+
1663
+ zero_node = ctx .make_const (utils .make_name ("zero_reduce" ), np .array (0 , dtype = np .float32 ))
1664
+ res = ctx .make_node (op_type = "Greater" , inputs = [reduce_node .output [0 ], zero_node .output [0 ]],
1658
1665
name = node .name , outputs = node .output )
1659
- return [cast , reduce_node , res ]
1666
+ return [cast , reduce_node , zero_node , res ]
1660
1667
1661
1668
1662
1669
def zeroslike_op (ctx , node , name , args ):
0 commit comments