Skip to content

Commit b86e4a3

Browse files
committed
implement isinf for opset 10
1 parent 6ae0320 commit b86e4a3

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

tests/test_backend.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2078,5 +2078,17 @@ def test_space_to_batchnd(self):
20782078
_ = tf.space_to_batch_nd(input_x, block_size, pad, name=_TFOUTPUT)
20792079
self._run_test_case([_OUTPUT], {_INPUT: input_val})
20802080

2081+
@check_opset_min_version(10, "is_inf")
2082+
def test_isinf(self):
2083+
x_val1 = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
2084+
x_val2 = np.array([np.inf, np.inf, np.inf, np.inf], dtype=np.float32).reshape((2, 2))
2085+
x_val3 = np.array([1.0, np.inf, -3.0, np.inf], dtype=np.float32).reshape((2, 2))
2086+
for x_val in [x_val1, x_val2, x_val3]:
2087+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
2088+
x_ = tf.is_inf(x)
2089+
_ = tf.identity(x_, name=_TFOUTPUT)
2090+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
2091+
tf.reset_default_graph()
2092+
20812093
if __name__ == '__main__':
20822094
unittest_main()

tf2onnx/onnx_opset/tensor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,3 +881,10 @@ def version_4(cls, ctx, node, **kwargs):
881881
reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
882882
ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]}, name=node.name, outputs=node.output,
883883
shapes=shapes, dtypes=dtypes)
884+
885+
886+
@tf_op("IsInf", onnx_op="IsInf")
887+
class IsInf:
888+
@classmethod
889+
def version_10(cls, ctx, node, **kwargs):
890+
pass

0 commit comments

Comments
 (0)