Skip to content

Commit ee09669

Browse files
hwangdeyufatcat-z
andauthored
Fix Select op with Mul inf input to NaN error (#2035)
* fix select mul inf to nan error Signed-off-by: Deyu Huang <[email protected]> * fix pylint Signed-off-by: Deyu Huang <[email protected]> * add break condition Signed-off-by: Deyu Huang <[email protected]> Co-authored-by: Jay Zhang <[email protected]>
1 parent 0052e50 commit ee09669

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

tests/test_backend.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3488,6 +3488,17 @@ def func(x):
34883488
return tf.identity(picks, name=_TFOUTPUT)
34893489
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
34903490

3491+
@check_opset_min_version(9, "IsNaN")
3492+
def test_where_ismulinf(self):
3493+
x_val1 = np.array([np.inf], dtype=np.float32)
3494+
x_val2 = np.array([0], dtype=np.float32)
3495+
true_result = np.array([np.inf], dtype=np.float32)
3496+
def func(x1, x2):
3497+
mul = tf.multiply(x1, x2)
3498+
picks = tf.where(x1 < mul, true_result, x2)
3499+
return tf.identity(picks, name=_TFOUTPUT)
3500+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
3501+
34913502
@check_opset_min_version(9, "Where for strings needs opset 9")
34923503
@skip_tfjs("Technically tf where doesn't support strings and tfjs doesn't like it")
34933504
def test_where_string(self):
@@ -5542,7 +5553,7 @@ def func(x):
55425553
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val0}, rtol=1e-6, atol=1e-4)
55435554
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
55445555

5545-
x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024
5556+
x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024.
55465557
x_val[0, 0] = -1024
55475558
x_val[0, 1] = -1023
55485559
x_val[0, 2] = 1024
@@ -5579,7 +5590,7 @@ def func(x):
55795590
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val0}, rtol=1e-6, atol=1e-4)
55805591
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
55815592

5582-
x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024
5593+
x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024.
55835594
x_val[0, 0] = -1024
55845595
x_val[0, 1] = -1023
55855596
x_val[0, 2] = 1024

tf2onnx/onnx_opset/controlflow.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,15 @@ def version_9(cls, ctx, node, **kwargs):
204204
if eq_node.input[0] == eq_node.input[1]:
205205
handles_nan = True
206206
for inp in node.inputs[1:]:
207-
if inp.is_const() and np.any(np.isnan(inp.get_tensor_value(as_list=False))):
207+
if handles_nan:
208+
break
209+
if inp.is_const() and (np.any(np.isnan(inp.get_tensor_value(as_list=False))) or \
210+
np.any(np.isinf(inp.get_tensor_value(as_list=False)))):
208211
handles_nan = True
212+
if inp.type == "Mul":
213+
inp0 = inp.inputs[0].is_const() and np.any(np.isinf(inp.inputs[0].get_tensor_value(as_list=False)))
214+
inp1 = inp.inputs[1].is_const() and np.any(np.isinf(inp.inputs[1].get_tensor_value(as_list=False)))
215+
handles_nan = inp0 or inp1
209216

210217
if ctx.get_dtype(node.output[0]) != TensorProto.STRING and not handles_nan:
211218
# Due to bad ORT implementation, Mul/Add ops are faster than Where op

0 commit comments

Comments
 (0)