Skip to content

Commit 695e159

Browse files
authored
Merge pull request #936 from onnx/gs/add-is_finite
add support for tf.math.is_finite
2 parents f4f2f04 + c891780 commit 695e159

File tree

5 files changed

+45
-11
lines changed

5 files changed

+45
-11
lines changed

tests/test_backend.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3190,12 +3190,13 @@ def func(X, K):
31903190
k_val = np.array(raw_k).astype(np.int32)
31913191
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
31923192

3193-
@check_opset_min_version(12)
3194-
def test_inverse(self):
3193+
@test_ms_domain()
3194+
def test_inverse(self, extra_opset):
3195+
# this depends on onnx Inverse which was removed from opset-12 but does exists in the ms-domain
31953196
x_val = np.random.random([5, 5]).astype(np.float32)
31963197
def func(x):
31973198
return tf.linalg.inv(x, name=_TFOUTPUT)
3198-
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3199+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, process_args={"extra_opset": [extra_opset]})
31993200

32003201
@check_opset_min_version(12)
32013202
def test_squared_distance(self):
@@ -3224,6 +3225,15 @@ def func(x, y):
32243225
tf.math.greater_equal(x, y, name=_TFOUTPUT1)
32253226
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: y_val})
32263227

3228+
@check_tf_min_version("1.14", "required for tf.math.is_finite")
3229+
@check_opset_min_version(10)
3230+
def test_is_finite(self):
3231+
x_val = np.array([5.0, 4.8, 6.8, np.inf, np.nan], dtype=np.float32)
3232+
def func(x):
3233+
y = tf.math.is_finite(x)
3234+
return tf.identity(y, name=_TFOUTPUT)
3235+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3236+
32273237

32283238
if __name__ == '__main__':
32293239
unittest_main()

tf2onnx/custom_opsets/ms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def version_1(cls, ctx, node, **kwargs):
9696
@tf_op("CropAndResize", domain=constants.MICROSOFT_DOMAIN)
9797
class CropAndResize:
9898
@classmethod
99-
def version_11(cls, ctx, node, **kwargs):
99+
def version_1(cls, ctx, node, **kwargs):
100100
""" utilize contrib cropandresize """
101101
node.attr['method'].name = 'mode'
102102
node.domain = constants.MICROSOFT_DOMAIN
@@ -107,7 +107,7 @@ def version_11(cls, ctx, node, **kwargs):
107107
@tf_op("MatrixInverse", domain=constants.MICROSOFT_DOMAIN, onnx_op="Inverse")
108108
class Inverse:
109109
@classmethod
110-
def version_12(cls, ctx, node, **kwargs):
110+
def version_1(cls, ctx, node, **kwargs):
111111
utils.make_sure(node.get_attr('adjoint').i == 0, "adjoint must be false")
112112
del node.attr["adjoint"]
113113
node.domain = constants.MICROSOFT_DOMAIN

tf2onnx/onnx_opset/math.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -566,10 +566,23 @@ def version_12(cls, ctx, node, **kwargs):
566566
del node.attr["N"]
567567

568568

569-
@tf_op("MatrixInverse", onnx_op="Inverse")
570-
class Inverse:
569+
@tf_op("IsFinite")
570+
class IsFinite:
571571
@classmethod
572-
def version_12(cls, ctx, node, **kwargs):
573-
utils.make_sure(node.get_attr('adjoint').i == 0, "adjoint must be false")
574-
del node.attr["adjoint"]
575-
node.domain = constants.MICROSOFT_DOMAIN
572+
def version_10(cls, ctx, node, **kwargs):
573+
# map to onnx as:
574+
# not (isinf(x) or isnan(x))
575+
576+
shapes = node.output_shapes
577+
dtypes = [onnx_pb.TensorProto.BOOL] * len(node.output_dtypes)
578+
579+
ctx.remove_node(node.name)
580+
581+
inf_node = ctx.make_node("IsInf", inputs=node.input, name=utils.make_name(node.name),
582+
shapes=shapes, dtypes=dtypes)
583+
nan_node = ctx.make_node("IsNaN", inputs=node.input, name=utils.make_name(node.name),
584+
shapes=shapes, dtypes=dtypes)
585+
or_node = ctx.make_node("Or", inputs=[inf_node.output[0], nan_node.output[0]], name=utils.make_name(node.name),
586+
shapes=shapes, dtypes=dtypes)
587+
_ = ctx.make_node("Not", inputs=or_node.output, name=node.name,
588+
shapes=shapes, dtypes=dtypes)

tf2onnx/onnx_opset/tensor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,16 @@ def version_7(cls, ctx, node, **kwargs):
625625
return
626626
raise ValueError("non-const dim is not supported")
627627

628+
@classmethod
629+
def version_11(cls, ctx, node, **kwargs):
630+
dim_node = node.inputs[1]
631+
if dim_node.is_const():
632+
node.type = "Unsqueeze"
633+
dim = dim_node.get_tensor_value()
634+
node.set_attr("axes", [dim])
635+
ctx.remove_input(node, node.input[1])
636+
return
637+
raise ValueError("non-const dim is not supported")
628638

629639
@tf_op("StridedSlice")
630640
class StridedSlice:

tools/dump-onnx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import collections
1515
import re
1616

17+
import onnx
1718
from onnx import ModelProto
1819
from onnx import helper, shape_inference
1920

0 commit comments

Comments
 (0)