Skip to content

Commit 1fcbc53

Browse files
authored
Merge pull request #364 from mindest/mod_sign_op4
modifiy sign_op4 and test_sign
2 parents 25e0f29 + 919ab80 commit 1fcbc53

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

tests/test_backend.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,20 +1059,16 @@ def test_cast(self):
10591059
_ = tf.identity(x_, name=_TFOUTPUT)
10601060
self._run_test_case([_OUTPUT], {_INPUT: x_val})
10611061

1062-
@check_opset_min_version(9)
1063-
def test_sign_int(self):
1064-
x_val = np.array([1, 2, 0, -1, 0, -2], dtype=np.int).reshape((2, 3))
1065-
x = tf.placeholder(tf.int32, [2, 3], name=_TFINPUT)
1066-
x_ = tf.sign(x)
1067-
_ = tf.identity(x_, name=_TFOUTPUT)
1068-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1069-
10701062
def test_sign(self):
1071-
x_val = np.array([1.0, 2.0, 0.0, -1.0, 0.0, -2.0], dtype=np.float32).reshape((2, 3))
1072-
x = tf.placeholder(tf.float32, [2, 3], name=_TFINPUT)
1073-
x_ = tf.sign(x)
1074-
_ = tf.identity(x_, name=_TFOUTPUT)
1075-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1063+
x_val1 = np.array([1, 2, 0, -1, 0, -2], dtype=np.int32).reshape((2, 3))
1064+
x_val2 = np.array([1, 2, 0, -1, 0, -2], dtype=np.int64).reshape((2, 3))
1065+
x_val3 = np.array([1.0, 2.0, 0.0, -1.0, 0.0, -2.0], dtype=np.float32).reshape((2, 3))
1066+
for x_val in [x_val1, x_val2, x_val3]:
1067+
x = tf.placeholder(x_val.dtype, x_val.shape, name=_TFINPUT)
1068+
x_ = tf.sign(x)
1069+
_ = tf.identity(x_, name=_TFOUTPUT)
1070+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1071+
tf.reset_default_graph()
10761072

10771073
def test_onehot0(self):
10781074
x_val = np.array([0, 1, 2], dtype=np.int32)

tf2onnx/tfonnx.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -648,12 +648,16 @@ def sign_op4(ctx, node, name, args):
648648
node_dtype = ctx.get_dtype(node.output[0])
649649
utils.make_sure(node_dtype, "Dtype of {} is None".format(node.name))
650650
if node_dtype in [onnx_pb.TensorProto.COMPLEX64, onnx_pb.TensorProto.COMPLEX128]:
651-
raise ValueError("dtype " + node_dtype + " is not supported in onnx for now")
652-
input_tensor_type = utils.map_onnx_to_numpy_type(node_dtype)
651+
raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")
653652
zero_name = utils.make_name("{}_zero".format(node.name))
654-
ctx.make_const(zero_name, np.array(0, dtype=input_tensor_type))
655-
greater_node = ctx.make_node("Greater", [node.input[0], zero_name])
656-
less_node = ctx.make_node("Less", [node.input[0], zero_name])
653+
ctx.make_const(zero_name, np.array(0, dtype=np.float32))
654+
if node_dtype not in [onnx_pb.TensorProto.FLOAT16, onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.DOUBLE]:
655+
cast_node_0 = ctx.make_node("Cast", [node.input[0]], {"to": onnx_pb.TensorProto.FLOAT})
656+
greater_node = ctx.make_node("Greater", [cast_node_0.output[0], zero_name])
657+
less_node = ctx.make_node("Less", [cast_node_0.output[0], zero_name])
658+
else:
659+
greater_node = ctx.make_node("Greater", [node.input[0], zero_name])
660+
less_node = ctx.make_node("Less", [node.input[0], zero_name])
657661
cast_node_1 = ctx.make_node("Cast", [greater_node.output[0]], {"to": node_dtype})
658662
cast_node_2 = ctx.make_node("Cast", [less_node.output[0]], {"to": node_dtype})
659663

0 commit comments

Comments
 (0)