@@ -690,6 +690,18 @@ def sign_op(ctx, node, name, args):
690
690
return nodes
691
691
692
692
693
+ def sign_op9 (ctx , node , name , args ):
694
+ # Currently supported: `float32`
695
+ # Ignored: `bfloat16`
696
+ # TODO: add support for `int32`, `int64`
697
+ node_dtype = ctx .get_dtype (node .output [0 ])
698
+ utils .make_sure (node_dtype , "Dtype of {} is None" .format (node .name ))
699
+ if node_dtype in [onnx_pb .TensorProto .BOOL , onnx_pb .TensorProto .FLOAT16 ,
700
+ onnx_pb .TensorProto .COMPLEX64 , onnx_pb .TensorProto .COMPLEX128 ]:
701
+ raise ValueError ("dtype " + str (node_dtype ) + " is not supported in onnx for now" )
702
+ return node
703
+
704
+
693
705
def biasadd_op (ctx , node , name , args ):
694
706
# T output = BiasAdd(T value, T bias, @string data_format)
695
707
# T output = BiasAddV1(T value, T bias)
@@ -1868,6 +1880,7 @@ def where_op(ctx, node, name, args):
1868
1880
"Less" : (logical_compare_op , []),
1869
1881
"ResizeBilinear" : (upsample_op9 , ["Upsample" , "linear" ]),
1870
1882
"ResizeNearestNeighbor" : (upsample_op9 , ["Upsample" , "nearest" ]),
1883
+ "Sign" : (sign_op9 , []),
1871
1884
"Sinh" : (direct_op , []),
1872
1885
"Where" : (where_op , []),
1873
1886
}
0 commit comments