Skip to content

Commit 9e0429b

Browse files
committed
minor modification and remove test_sign_int
1 parent 2a3ef1b commit 9e0429b

File tree

2 files changed

+2
-10
lines changed

2 files changed

+2
-10
lines changed

tests/test_backend.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,14 +1059,6 @@ def test_cast(self):
10591059
_ = tf.identity(x_, name=_TFOUTPUT)
10601060
self._run_test_case([_OUTPUT], {_INPUT: x_val})
10611061

1062-
@check_opset_min_version(9)
1063-
def test_sign_int(self):
1064-
x_val = np.array([1, 2, 0, -1, 0, -2], dtype=np.int).reshape((2, 3))
1065-
x = tf.placeholder(tf.int32, [2, 3], name=_TFINPUT)
1066-
x_ = tf.sign(x)
1067-
_ = tf.identity(x_, name=_TFOUTPUT)
1068-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1069-
10701062
def test_sign(self):
10711063
x_val1 = np.array([1, 2, 0, -1, 0, -2], dtype=np.int32).reshape((2, 3))
10721064
x_val2 = np.array([1, 2, 0, -1, 0, -2], dtype=np.int64).reshape((2, 3))

tf2onnx/tfonnx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -650,9 +650,9 @@ def sign_op4(ctx, node, name, args):
650650
if node_dtype in [onnx_pb.TensorProto.COMPLEX64, onnx_pb.TensorProto.COMPLEX128]:
651651
raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")
652652
zero_name = utils.make_name("{}_zero".format(node.name))
653-
ctx.make_const(zero_name, np.array(0, dtype=utils.ONNX_TO_NUMPY_DTYPE[1]))
653+
ctx.make_const(zero_name, np.array(0, dtype=np.float32))
654654
if node_dtype in [onnx_pb.TensorProto.INT32, onnx_pb.TensorProto.INT64]:
655-
cast_node_0 = ctx.make_node("Cast", [node.input[0]], {"to": 1})
655+
cast_node_0 = ctx.make_node("Cast", [node.input[0]], {"to": onnx_pb.TensorProto.FLOAT})
656656
greater_node = ctx.make_node("Greater", [cast_node_0.output[0], zero_name])
657657
less_node = ctx.make_node("Less", [cast_node_0.output[0], zero_name])
658658
else:

0 commit comments

Comments
 (0)