Skip to content

Commit 3e5afc0

Browse files
committed
update based on review comments
1 parent cf3d913 commit 3e5afc0

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

tf2onnx/tfonnx.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ def cast_op(ctx, node, name, args):
670670
return node
671671

672672

673-
def sign_op(ctx, node, name, args):
673+
def sign_op4(ctx, node, name, args):
674674
"""Sign op."""
675675
# T sign = Sign(T Input)
676676
nodes = []
@@ -691,13 +691,9 @@ def sign_op(ctx, node, name, args):
691691

692692

693693
def sign_op9(ctx, node, name, args):
694-
# Currently supported: `float32`
695-
# Ignored: `bfloat16`
696-
# TODO: add support for `int32`, `int64`
697694
node_dtype = ctx.get_dtype(node.output[0])
698695
utils.make_sure(node_dtype, "Dtype of {} is None".format(node.name))
699-
if node_dtype in [onnx_pb.TensorProto.BOOL, onnx_pb.TensorProto.FLOAT16,
700-
onnx_pb.TensorProto.COMPLEX64, onnx_pb.TensorProto.COMPLEX128]:
696+
if node_dtype in [onnx_pb.TensorProto.BOOL, onnx_pb.TensorProto.COMPLEX64, onnx_pb.TensorProto.COMPLEX128]:
701697
raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")
702698
return node
703699

@@ -1805,7 +1801,7 @@ def where_op(ctx, node, name, args):
18051801
"Pack": (pack_op, []),
18061802
"Unpack": (unpack_op, []),
18071803
"Erf": (erf_op, []),
1808-
"Sign": (sign_op, []),
1804+
"Sign": (sign_op4, []),
18091805
"ZerosLike": (zeroslike_op, []),
18101806
}
18111807

0 commit comments

Comments
 (0)