Skip to content

Commit 919ab80

Browse files
committed
modify for node_dtype
1 parent 9e0429b commit 919ab80

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ def sign_op4(ctx, node, name, args):
651651
raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")
652652
zero_name = utils.make_name("{}_zero".format(node.name))
653653
ctx.make_const(zero_name, np.array(0, dtype=np.float32))
654-
if node_dtype in [onnx_pb.TensorProto.INT32, onnx_pb.TensorProto.INT64]:
654+
if node_dtype not in [onnx_pb.TensorProto.FLOAT16, onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.DOUBLE]:
655655
cast_node_0 = ctx.make_node("Cast", [node.input[0]], {"to": onnx_pb.TensorProto.FLOAT})
656656
greater_node = ctx.make_node("Greater", [cast_node_0.output[0], zero_name])
657657
less_node = ctx.make_node("Less", [cast_node_0.output[0], zero_name])

0 commit comments

Comments
 (0)