diff --git a/tests/test_backend.py b/tests/test_backend.py index 428587ca8..a3a27cdc5 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -3673,6 +3673,22 @@ def func(x): return tf.identity(picks, name=_TFOUTPUT) self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + @check_opset_min_version(10, "IsInf") + def test_where_with_isinf_condition(self): + def func(x, y, z): + # Use is_inf as condition to trigger the IsInf code path + condition = tf.math.is_inf(x) + result = tf.where(condition, y, z) + return tf.identity(result, name=_TFOUTPUT) + + # Create test data with some infinite values + x_val = np.array([1.0, np.inf, 3.0, -np.inf, 5.0], dtype=np.float32) + y_val = np.array([0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32) + z_val = np.array([100.0, 200.0, 300.0, 400.0, 500.0], dtype=np.float32) + + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val}) + + @check_opset_min_version(9, "IsNaN") def test_where_isnan(self): x_val = np.array([1, 2, -3, float('nan'), -5, -6, float('nan'), 8, 9, 0], dtype=np.float32) diff --git a/tf2onnx/onnx_opset/controlflow.py b/tf2onnx/onnx_opset/controlflow.py index 4f577facc..75f99afea 100644 --- a/tf2onnx/onnx_opset/controlflow.py +++ b/tf2onnx/onnx_opset/controlflow.py @@ -195,7 +195,8 @@ def version_9(cls, ctx, node, **kwargs): handles_nan = node.get_attr_value("handles_nan", False) if ctx.get_dtype(node.output[0]) in [TensorProto.FLOAT, TensorProto.DOUBLE]: cond_node = node.inputs[0] - if cond_node.type == "IsNaN": + if cond_node.type in {"IsNaN", "IsInf"}: + # We can't use the mul trick if Inf is involved since Inf * 0 = NaN as per IEEE 754. handles_nan = True if cond_node.type == "NotEqual" and cond_node.input[0] == cond_node.input[1]: handles_nan = True