Skip to content

Commit 0cced88

Browse files
committed
add op IsNaN and test case test_isnan
1 parent d6f1411 commit 0cced88

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

tests/test_backend.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,6 +1702,19 @@ def test_zeros_like(self):
17021702
_ = tf.identity(res1, name=_TFOUTPUT1)
17031703
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: input_val})
17041704

1705+
@check_opset_min_version(9, "isnan")
1706+
def test_isnan(self):
1707+
# only compatible with dtype `float32`
1708+
x_val1 = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2,2))
1709+
x_val2 = np.array([np.nan, np.nan, np.nan, np.nan], dtype=np.float32).reshape((2, 2))
1710+
x_val3 = np.array([1.0, np.nan, -3.0, np.nan], dtype=np.float32).reshape((2, 2))
1711+
for x_val in [x_val1, x_val2, x_val3]:
1712+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1713+
x_ = tf.is_nan(x)
1714+
_ = tf.identity(x_, name=_TFOUTPUT)
1715+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1716+
tf.reset_default_graph()
1717+
17051718

17061719
if __name__ == '__main__':
17071720
unittest_main()

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2296,7 +2296,7 @@ def rewrite_incomplete_type_support_rs5(g, ops):
22962296

22972297

22982298
def rewrite_incomplete_type_support_rs6(g, ops):
2299-
return rewrite_incomplete_type_support(g, ops, ["Div", "ReduceSum", "Slice", "Split", "Tile", "Transpose"])
2299+
return rewrite_incomplete_type_support(g, ops, ["Div", "IsNaN", "ReduceSum", "Slice", "Split", "Tile", "Transpose"])
23002300

23012301

23022302
def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):

0 commit comments

Comments
 (0)