Skip to content

Commit cf3d913

Browse files
committed
implement sign for opset9
1 parent 52bf5e4 commit cf3d913

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

tf2onnx/tfonnx.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,18 @@ def sign_op(ctx, node, name, args):
690690
return nodes
691691

692692

693+
def sign_op9(ctx, node, name, args):
694+
# Currently supported: `float32`
695+
# Ignored: `bfloat16`
696+
# TODO: add support for `int32`, `int64`
697+
node_dtype = ctx.get_dtype(node.output[0])
698+
utils.make_sure(node_dtype, "Dtype of {} is None".format(node.name))
699+
if node_dtype in [onnx_pb.TensorProto.BOOL, onnx_pb.TensorProto.FLOAT16,
700+
onnx_pb.TensorProto.COMPLEX64, onnx_pb.TensorProto.COMPLEX128]:
701+
raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")
702+
return node
703+
704+
693705
def biasadd_op(ctx, node, name, args):
694706
# T output = BiasAdd(T value, T bias, @string data_format)
695707
# T output = BiasAddV1(T value, T bias)
@@ -1868,6 +1880,7 @@ def where_op(ctx, node, name, args):
18681880
"Less": (logical_compare_op, []),
18691881
"ResizeBilinear": (upsample_op9, ["Upsample", "linear"]),
18701882
"ResizeNearestNeighbor": (upsample_op9, ["Upsample", "nearest"]),
1883+
"Sign": (sign_op9, []),
18711884
"Sinh": (direct_op, []),
18721885
"Where": (where_op, []),
18731886
}

0 commit comments

Comments
 (0)