Skip to content

Commit 6975b75

Browse files
committed
add double type in test, add type check in def
1 parent b86e4a3 commit 6975b75

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

tests/test_backend.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2080,15 +2080,17 @@ def test_space_to_batchnd(self):
20802080

20812081
@check_opset_min_version(10, "is_inf")
20822082
def test_isinf(self):
2083-
x_val1 = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
2084-
x_val2 = np.array([np.inf, np.inf, np.inf, np.inf], dtype=np.float32).reshape((2, 2))
2085-
x_val3 = np.array([1.0, np.inf, -3.0, np.inf], dtype=np.float32).reshape((2, 2))
2086-
for x_val in [x_val1, x_val2, x_val3]:
2087-
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
2088-
x_ = tf.is_inf(x)
2089-
_ = tf.identity(x_, name=_TFOUTPUT)
2090-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
2091-
tf.reset_default_graph()
2083+
x_types = [np.float32, np.float64]
2084+
for x_type in x_types:
2085+
x_val1 = np.array([1.0, -2.0, 3.0, -4.0], dtype=x_type).reshape((2, 2))
2086+
x_val2 = np.array([np.inf, np.inf, np.inf, np.inf], dtype=x_type).reshape((1, 4))
2087+
x_val3 = np.array([1.0, np.inf, -3.0, np.inf, 5.0, np.inf, -7.0, np.inf, 9.0], dtype=x_type).reshape((3, 3))
2088+
for x_val in [x_val1, x_val2, x_val3]:
2089+
x = tf.placeholder(x_type, x_val.shape, name=_TFINPUT)
2090+
x_ = tf.is_inf(x)
2091+
_ = tf.identity(x_, name=_TFOUTPUT)
2092+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
2093+
tf.reset_default_graph()
20922094

20932095
if __name__ == '__main__':
20942096
unittest_main()

tf2onnx/onnx_opset/tensor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -887,4 +887,7 @@ def version_4(cls, ctx, node, **kwargs):
887887
class IsInf:
888888
@classmethod
889889
def version_10(cls, ctx, node, **kwargs):
890-
pass
890+
node_dtype = ctx.get_dtype(node.input[0])
891+
utils.make_sure(node_dtype, "Dtype of {} is None".format(node.name))
892+
if node_dtype not in [onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.DOUBLE]:
893+
raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")

0 commit comments

Comments
 (0)