@@ -670,7 +670,7 @@ def cast_op(ctx, node, name, args):
670
670
return node
671
671
672
672
673
- def sign_op (ctx , node , name , args ):
673
+ def sign_op4 (ctx , node , name , args ):
674
674
"""Sign op."""
675
675
# T sign = Sign(T Input)
676
676
nodes = []
@@ -690,6 +690,14 @@ def sign_op(ctx, node, name, args):
690
690
return nodes
691
691
692
692
693
+ def sign_op9 (ctx , node , name , args ):
694
+ node_dtype = ctx .get_dtype (node .output [0 ])
695
+ utils .make_sure (node_dtype , "Dtype of {} is None" .format (node .name ))
696
+ if node_dtype in [onnx_pb .TensorProto .BOOL , onnx_pb .TensorProto .COMPLEX64 , onnx_pb .TensorProto .COMPLEX128 ]:
697
+ raise ValueError ("dtype " + str (node_dtype ) + " is not supported in onnx for now" )
698
+ return node
699
+
700
+
693
701
def biasadd_op (ctx , node , name , args ):
694
702
# T output = BiasAdd(T value, T bias, @string data_format)
695
703
# T output = BiasAddV1(T value, T bias)
@@ -1793,7 +1801,7 @@ def where_op(ctx, node, name, args):
1793
1801
"Pack" : (pack_op , []),
1794
1802
"Unpack" : (unpack_op , []),
1795
1803
"Erf" : (erf_op , []),
1796
- "Sign" : (sign_op , []),
1804
+ "Sign" : (sign_op4 , []),
1797
1805
"ZerosLike" : (zeroslike_op , []),
1798
1806
}
1799
1807
@@ -1868,6 +1876,7 @@ def where_op(ctx, node, name, args):
1868
1876
"Less" : (logical_compare_op , []),
1869
1877
"ResizeBilinear" : (upsample_op9 , ["Upsample" , "linear" ]),
1870
1878
"ResizeNearestNeighbor" : (upsample_op9 , ["Upsample" , "nearest" ]),
1879
+ "Sign" : (sign_op9 , []),
1871
1880
"Sinh" : (direct_op , []),
1872
1881
"Where" : (where_op , []),
1873
1882
}
0 commit comments