Skip to content

Commit 5108487

Browse files
authored
Merge pull request #336 from mindest/dev_isnan
add op IsNaN and test case test_isnan
2 parents d6f1411 + 3a7171f commit 5108487

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

tests/test_backend.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,6 +1702,19 @@ def test_zeros_like(self):
17021702
_ = tf.identity(res1, name=_TFOUTPUT1)
17031703
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: input_val})
17041704

1705+
@check_opset_min_version(9, "is_nan")
1706+
def test_isnan(self):
1707+
# only compatible with dtype `float32`
1708+
x_val1 = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
1709+
x_val2 = np.array([np.nan, np.nan, np.nan, np.nan], dtype=np.float32).reshape((2, 2))
1710+
x_val3 = np.array([1.0, np.nan, -3.0, np.nan], dtype=np.float32).reshape((2, 2))
1711+
for x_val in [x_val1, x_val2, x_val3]:
1712+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1713+
x_ = tf.is_nan(x)
1714+
_ = tf.identity(x_, name=_TFOUTPUT)
1715+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1716+
tf.reset_default_graph()
1717+
17051718

17061719
if __name__ == '__main__':
17071720
unittest_main()

tf2onnx/tfonnx.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1894,17 +1894,18 @@ def where_op(ctx, node, name, args):
18941894
}
18951895

18961896
_OPSET_9 = {
1897-
"Erf": (direct_op, []),
1898-
"Fill": (fill_op, []),
1899-
"Sinh": (direct_op, []),
1900-
"Cosh": (direct_op, []),
1901-
"Asinh": (direct_op, []),
19021897
"Acosh": (direct_op, []),
1898+
"Asinh": (direct_op, []),
19031899
"Atanh": (direct_op, []),
1900+
"Cosh": (direct_op, []),
1901+
"Erf": (direct_op, []),
1902+
"Fill": (fill_op, []),
19041903
"Greater": (logical_compare_op, []),
1904+
"IsNan": (direct_op, ["IsNaN"]),
19051905
"Less": (logical_compare_op, []),
19061906
"ResizeBilinear": (upsample_op9, ["Upsample", "linear"]),
19071907
"ResizeNearestNeighbor": (upsample_op9, ["Upsample", "nearest"]),
1908+
"Sinh": (direct_op, []),
19081909
"Where": (where_op, []),
19091910
}
19101911

@@ -2296,7 +2297,7 @@ def rewrite_incomplete_type_support_rs5(g, ops):
22962297

22972298

22982299
def rewrite_incomplete_type_support_rs6(g, ops):
2299-
return rewrite_incomplete_type_support(g, ops, ["Div", "ReduceSum", "Slice", "Split", "Tile", "Transpose"])
2300+
return rewrite_incomplete_type_support(g, ops, ["Div", "IsNaN", "ReduceSum", "Slice", "Split", "Tile", "Transpose"])
23002301

23012302

23022303
def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):

0 commit comments

Comments
 (0)