Skip to content

Commit 1d3478f

Browse files
Improve detection of NaN values in Select op (#1663)
Signed-off-by: Tom Wildenhain <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent e00594d commit 1d3478f

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

tests/test_backend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3295,6 +3295,16 @@ def func(x):
32953295
return tf.identity(picks, name=_TFOUTPUT)
32963296
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
32973297

3298+
@check_opset_min_version(9, "IsNaN")
3299+
def test_where_isnan(self):
3300+
x_val = np.array([1, 2, -3, float('nan'), -5, -6, float('nan'), 8, 9, 0], dtype=np.float32)
3301+
true_result = np.array([111, 222, 333, 444, 555, 666, 777, 888, 999, 1000],
3302+
dtype=np.float32)
3303+
def func(x):
3304+
picks = tf.where(is_nan(x), true_result, x)
3305+
return tf.identity(picks, name=_TFOUTPUT)
3306+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3307+
32983308
@check_opset_min_version(9, "Where for strings needs opset 9")
32993309
@skip_tfjs("Technically tf where doesn't support strings and tfjs doesn't like it")
33003310
def test_where_string(self):

tf2onnx/onnx_opset/controlflow.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,15 @@ def version_9(cls, ctx, node, **kwargs):
184184
# We can't use the mul/add trick if a NaN is involved. handles_nan is added earlier in the converter.
185185
handles_nan = node.get_attr_value("handles_nan", False)
186186
if ctx.get_dtype(node.output[0]) in [TensorProto.FLOAT, TensorProto.DOUBLE]:
187+
cond_node = node.inputs[0]
188+
if cond_node.type == "IsNaN":
189+
handles_nan = True
190+
if cond_node.type == "NotEqual" and cond_node.input[0] == cond_node.input[1]:
191+
handles_nan = True
192+
if cond_node.type == "Not" and cond_node.inputs[0].type == "Equal":
193+
eq_node = cond_node.inputs[0]
194+
if eq_node.input[0] == eq_node.input[1]:
195+
handles_nan = True
187196
for inp in node.inputs[1:]:
188197
if inp.is_const() and np.any(np.isnan(inp.get_tensor_value(as_list=False))):
189198
handles_nan = True

0 commit comments

Comments
 (0)