Skip to content

Commit 411c9f6

Browse files
authored
Merge pull request #477 from mindest/implement_isinf
implement IsInf for opset 10
2 parents d1d0802 + 119e3da commit 411c9f6

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

tests/test_backend.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2103,6 +2103,19 @@ def test_space_to_batchnd(self):
21032103
_ = tf.space_to_batch_nd(input_x, block_size, pad, name=_TFOUTPUT)
21042104
self._run_test_case([_OUTPUT], {_INPUT: input_val})
21052105

2106+
@check_opset_min_version(10, "is_inf")
2107+
def test_isinf(self):
2108+
x_types = [np.float32, np.float64]
2109+
for x_type in x_types:
2110+
x_val1 = np.array([1.0, -2.0, 3.0, -4.0], dtype=x_type)
2111+
x_val2 = np.array([np.inf, np.inf, np.inf, np.inf], dtype=x_type).reshape((2, 2))
2112+
x_val3 = np.array([1.0, np.inf, -3.0, np.inf, 5.0, np.inf, -7.0, np.inf], dtype=x_type).reshape((2, 2, 2))
2113+
for x_val in [x_val1, x_val2, x_val3]:
2114+
x = tf.placeholder(x_type, x_val.shape, name=_TFINPUT)
2115+
x_ = tf.is_inf(x)
2116+
_ = tf.identity(x_, name=_TFOUTPUT)
2117+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
2118+
tf.reset_default_graph()
21062119

21072120
if __name__ == '__main__':
21082121
unittest_main()

tf2onnx/onnx_opset/tensor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,3 +922,13 @@ def version_4(cls, ctx, node, **kwargs):
922922
reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
923923
ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]}, name=node.name, outputs=node.output,
924924
shapes=shapes, dtypes=dtypes)
925+
926+
927+
@tf_op("IsInf", onnx_op="IsInf")
928+
class IsInf:
929+
@classmethod
930+
def version_10(cls, ctx, node, **kwargs):
931+
node_dtype = ctx.get_dtype(node.input[0])
932+
utils.make_sure(node_dtype, "Dtype of {} is None".format(node.name))
933+
if node_dtype not in [onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.DOUBLE]:
934+
raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")

0 commit comments

Comments
 (0)