@@ -670,7 +670,7 @@ def cast_op(ctx, node, name, args):
670
670
return node
671
671
672
672
673
- def sign_op (ctx , node , name , args ):
673
+ def sign_op4 (ctx , node , name , args ):
674
674
"""Sign op."""
675
675
# T sign = Sign(T Input)
676
676
nodes = []
@@ -691,13 +691,9 @@ def sign_op(ctx, node, name, args):
691
691
692
692
693
693
def sign_op9 (ctx , node , name , args ):
694
- # Currently supported: `float32`
695
- # Ignored: `bfloat16`
696
- # TODO: add support for `int32`, `int64`
697
694
node_dtype = ctx .get_dtype (node .output [0 ])
698
695
utils .make_sure (node_dtype , "Dtype of {} is None" .format (node .name ))
699
- if node_dtype in [onnx_pb .TensorProto .BOOL , onnx_pb .TensorProto .FLOAT16 ,
700
- onnx_pb .TensorProto .COMPLEX64 , onnx_pb .TensorProto .COMPLEX128 ]:
696
+ if node_dtype in [onnx_pb .TensorProto .BOOL , onnx_pb .TensorProto .COMPLEX64 , onnx_pb .TensorProto .COMPLEX128 ]:
701
697
raise ValueError ("dtype " + str (node_dtype ) + " is not supported in onnx for now" )
702
698
return node
703
699
@@ -1805,7 +1801,7 @@ def where_op(ctx, node, name, args):
1805
1801
"Pack" : (pack_op , []),
1806
1802
"Unpack" : (unpack_op , []),
1807
1803
"Erf" : (erf_op , []),
1808
- "Sign" : (sign_op , []),
1804
+ "Sign" : (sign_op4 , []),
1809
1805
"ZerosLike" : (zeroslike_op , []),
1810
1806
}
1811
1807
0 commit comments