@@ -648,12 +648,16 @@ def sign_op4(ctx, node, name, args):
648
648
node_dtype = ctx .get_dtype (node .output [0 ])
649
649
utils .make_sure (node_dtype , "Dtype of {} is None" .format (node .name ))
650
650
if node_dtype in [onnx_pb .TensorProto .COMPLEX64 , onnx_pb .TensorProto .COMPLEX128 ]:
651
- raise ValueError ("dtype " + node_dtype + " is not supported in onnx for now" )
652
- input_tensor_type = utils .map_onnx_to_numpy_type (node_dtype )
651
+ raise ValueError ("dtype " + str (node_dtype ) + " is not supported in onnx for now" )
653
652
zero_name = utils .make_name ("{}_zero" .format (node .name ))
654
- ctx .make_const (zero_name , np .array (0 , dtype = input_tensor_type ))
655
- greater_node = ctx .make_node ("Greater" , [node .input [0 ], zero_name ])
656
- less_node = ctx .make_node ("Less" , [node .input [0 ], zero_name ])
653
+ ctx .make_const (zero_name , np .array (0 , dtype = utils .ONNX_TO_NUMPY_DTYPE [1 ]))
654
+ if node_dtype in [onnx_pb .TensorProto .INT32 , onnx_pb .TensorProto .INT64 ]:
655
+ cast_node_0 = ctx .make_node ("Cast" , [node .input [0 ]], {"to" : 1 })
656
+ greater_node = ctx .make_node ("Greater" , [cast_node_0 .output [0 ], zero_name ])
657
+ less_node = ctx .make_node ("Less" , [cast_node_0 .output [0 ], zero_name ])
658
+ else :
659
+ greater_node = ctx .make_node ("Greater" , [node .input [0 ], zero_name ])
660
+ less_node = ctx .make_node ("Less" , [node .input [0 ], zero_name ])
657
661
cast_node_1 = ctx .make_node ("Cast" , [greater_node .output [0 ]], {"to" : node_dtype })
658
662
cast_node_2 = ctx .make_node ("Cast" , [less_node .output [0 ]], {"to" : node_dtype })
659
663
0 commit comments