diff --git a/docs/api-reference/expr.md b/docs/api-reference/expr.md index dbd613d134..7e182ac000 100644 --- a/docs/api-reference/expr.md +++ b/docs/api-reference/expr.md @@ -23,6 +23,7 @@ - filter - clip - is_between + - is_close - is_duplicated - is_finite - is_first_distinct diff --git a/docs/api-reference/series.md b/docs/api-reference/series.md index 57c0ab313b..01b7c1ff8b 100644 --- a/docs/api-reference/series.md +++ b/docs/api-reference/series.md @@ -35,6 +35,7 @@ - hist - implementation - is_between + - is_close - is_duplicated - is_empty - is_finite diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 5c2d9b8aba..57a790eb9e 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -27,6 +27,7 @@ LazyExprT, NativeExprT, ) +from narwhals._compliant.utils import IsClose from narwhals._utils import _StoresCompliant from narwhals.dependencies import get_numpy, is_numpy_array @@ -75,7 +76,7 @@ def __eq__(self, value: Any, /) -> Self: ... # type: ignore[override] def __ne__(self, value: Any, /) -> Self: ... 
# type: ignore[override] -class CompliantExpr(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]): +class CompliantExpr(IsClose, Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]): _implementation: Implementation _version: Version _evaluate_output_names: EvalNames[CompliantFrameT] @@ -902,6 +903,22 @@ def exp(self) -> Self: def sqrt(self) -> Self: return self._reuse_series("sqrt") + def is_close( + self, + other: Self | NumericLiteral, + *, + abs_tol: float, + rel_tol: float, + nans_equal: bool, + ) -> Self: + return self._reuse_series( + "is_close", + other=other, + abs_tol=abs_tol, + rel_tol=rel_tol, + nans_equal=nans_equal, + ) + @property def cat(self) -> EagerExprCatNamespace[Self]: return EagerExprCatNamespace(self) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index dcb376ed06..e8a88590d3 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -16,6 +16,7 @@ NativeSeriesT, NativeSeriesT_co, ) +from narwhals._compliant.utils import IsClose from narwhals._translate import FromIterable, FromNative, NumpyConvertible, ToNarwhals from narwhals._typing_compat import TypeVar, assert_never from narwhals._utils import ( @@ -78,6 +79,7 @@ class HistData(TypedDict, Generic[NativeSeriesT, "_CountsT_co"]): class CompliantSeries( + IsClose, NumpyConvertible["_1DArray", "Into1DArray"], FromIterable, FromNative[NativeSeriesT], diff --git a/narwhals/_compliant/utils.py b/narwhals/_compliant/utils.py new file mode 100644 index 0000000000..3a88e39561 --- /dev/null +++ b/narwhals/_compliant/utils.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol + +if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals.typing import NumericLiteral, TemporalLiteral + + +class IsClose(Protocol): + """Every member defined is a dependency of `is_close` method.""" + + def __and__(self, other: Any) -> Self: ... + def __or__(self, other: Any) -> Self: ... 
+ def __invert__(self) -> Self: ... + def __sub__(self, other: Any) -> Self: ... + def __mul__(self, other: Any) -> Self: ... + def __eq__(self, other: Self | Any) -> Self: ... # type: ignore[override] + def __gt__(self, other: Any) -> Self: ... + def __le__(self, other: Any) -> Self: ... + def abs(self) -> Self: ... + def is_nan(self) -> Self: ... + def is_finite(self) -> Self: ... + def clip( + self, + lower_bound: Self | NumericLiteral | TemporalLiteral | None, + upper_bound: Self | NumericLiteral | TemporalLiteral | None, + ) -> Self: ... + def is_close( + self, + other: Self | NumericLiteral, + *, + abs_tol: float, + rel_tol: float, + nans_equal: bool, + ) -> Self: + from decimal import Decimal + + other_abs: Self | NumericLiteral + other_is_nan: Self | bool + other_is_inf: Self | bool + other_is_not_inf: Self | bool + + if isinstance(other, (float, int, Decimal)): + from math import isinf, isnan + + # NOTE: See https://discuss.python.org/t/inferred-type-of-function-that-calls-dunder-abs-abs/101447 + other_abs = other.__abs__() + other_is_nan = isnan(other) + other_is_inf = isinf(other) + + # Define the other_is_not_inf variable to prevent triggering the following warning: + # > DeprecationWarning: Bitwise inversion '~' on bool is deprecated and will be + # > removed in Python 3.16. 
+ other_is_not_inf = not other_is_inf + + else: + other_abs, other_is_nan = other.abs(), other.is_nan() + other_is_not_inf = other.is_finite() | other_is_nan + other_is_inf = ~other_is_not_inf + + rel_threshold = self.abs().clip(lower_bound=other_abs, upper_bound=None) * rel_tol + tolerance = rel_threshold.clip(lower_bound=abs_tol, upper_bound=None) + + self_is_nan = self.is_nan() + self_is_not_inf = self.is_finite() | self_is_nan + + # Values are close if abs_diff <= tolerance, and both finite + is_close = ( + ((self - other).abs() <= tolerance) & self_is_not_inf & other_is_not_inf + ) + + # Handle infinity cases: infinities are close/equal if they have the same sign + self_sign, other_sign = self > 0, other > 0 + is_same_inf = (~self_is_not_inf) & other_is_inf & (self_sign == other_sign) + + # Handle nan cases: + # * If any value is NaN, then False (via `& ~either_nan`) + # * However, if `nans_equal = True` and if _both_ values are NaN, then True + either_nan = self_is_nan | other_is_nan + result = (is_close | is_same_inf) & ~either_nan + + if nans_equal: + both_nan = self_is_nan & other_is_nan + result = result | both_nan + + return result diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index a106e13d1e..1210ee08e0 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -27,7 +27,7 @@ from narwhals._polars.dataframe import Method, PolarsDataFrame from narwhals._polars.namespace import PolarsNamespace from narwhals._utils import Version, _LimitedContext - from narwhals.typing import IntoDType + from narwhals.typing import IntoDType, NumericLiteral class PolarsExpr: @@ -232,6 +232,46 @@ def __narwhals_namespace__(self) -> PolarsNamespace: # pragma: no cover return PolarsNamespace(version=self._version) + def is_close( + self, + other: Self | NumericLiteral, + *, + abs_tol: float, + rel_tol: float, + nans_equal: bool, + ) -> Self: + left = self.native + right = other.native if isinstance(other, PolarsExpr) else pl.lit(other) + + if 
self._backend_version < (1, 32, 0): + lower_bound = right.abs() + tolerance = (left.abs().clip(lower_bound) * rel_tol).clip(abs_tol) + + # Values are close if abs_diff <= tolerance, and both finite + abs_diff = (left - right).abs() + all_ = pl.all_horizontal + is_close = all_((abs_diff <= tolerance), left.is_finite(), right.is_finite()) + + # Handle infinity cases: infinities are "close" only if they have the same sign + is_same_inf = all_( + left.is_infinite(), right.is_infinite(), (left.sign() == right.sign()) + ) + + # Handle nan cases: + # * nans_equal = True => if both values are NaN, then True + # * nans_equal = False => if any value is NaN, then False + left_is_nan, right_is_nan = left.is_nan(), right.is_nan() + either_nan = left_is_nan | right_is_nan + result = (is_close | is_same_inf) & either_nan.not_() + + if nans_equal: + result = result | (left_is_nan & right_is_nan) + else: + result = left.is_close( + right, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal + ) + return self._with_native(result) + @property def dt(self) -> PolarsExprDateTimeNamespace: return PolarsExprDateTimeNamespace(self) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index d8ac0216ff..e94f712534 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -41,6 +41,7 @@ IntoDType, MultiIndexSelector, NonNestedLiteral, + NumericLiteral, _1DArray, ) @@ -90,6 +91,7 @@ "gather_every", "head", "is_between", + "is_close", "is_finite", "is_first_distinct", "is_in", @@ -131,6 +133,11 @@ class PolarsSeries: _implementation = Implementation.POLARS + _HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = { + True: ["breakpoint", "count"], + False: ["count"], + } + def __init__(self, series: pl.Series, *, version: Version) -> None: self._native_series: pl.Series = series self._version = version @@ -478,10 +485,27 @@ def __contains__(self, other: Any) -> bool: except Exception as e: # noqa: BLE001 raise catch_polars_exception(e) 
from None - _HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = { - True: ["breakpoint", "count"], - False: ["count"], - } + def is_close( + self, + other: Self | NumericLiteral, + *, + abs_tol: float, + rel_tol: float, + nans_equal: bool, + ) -> PolarsSeries: + if self._backend_version < (1, 32, 0): + name = self.name + ns = self.__narwhals_namespace__() + other_expr = other._to_expr() if isinstance(other, PolarsSeries) else other + expr = ns.col(name).is_close( + other_expr, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal + ) + return self.to_frame().select(expr).get_column(name) + other_series = other.native if isinstance(other, PolarsSeries) else other + result = self.native.is_close( + other_series, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal + ) + return self._with_native(result) def hist_from_bins( self, bins: list[float], *, include_breakpoint: bool diff --git a/narwhals/expr.py b/narwhals/expr.py index b1ca36f974..73a5324a62 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -12,7 +12,7 @@ ) from narwhals._utils import _validate_rolling_arguments, ensure_type, flatten from narwhals.dtypes import _validate_dtype -from narwhals.exceptions import InvalidOperationError +from narwhals.exceptions import ComputeError, InvalidOperationError from narwhals.expr_cat import ExprCatNamespace from narwhals.expr_dt import ExprDateTimeNamespace from narwhals.expr_list import ExprListNamespace @@ -2347,6 +2347,97 @@ def sqrt(self) -> Self: """ return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).sqrt()) + def is_close( + self, + other: Self | NumericLiteral, + *, + abs_tol: float = 0.0, + rel_tol: float = 1e-09, + nans_equal: bool = False, + ) -> Self: + r"""Check if this expression is close, i.e. almost equal, to the other expression. 
+ + Two values `a` and `b` are considered close if the following condition holds: + + $$ + |a-b| \le max \{ \text{rel\_tol} \cdot max \{ |a|, |b| \}, \text{abs\_tol} \} + $$ + + Arguments: + other: Values to compare with. + abs_tol: Absolute tolerance. This is the maximum allowed absolute difference + between two values. Must be non-negative. + rel_tol: Relative tolerance. This is the maximum allowed difference between + two values, relative to the larger absolute value. Must be in the range + [0, 1). + nans_equal: Whether NaN values should be considered equal. + + Returns: + Expression of Boolean data type. + + Notes: + The implementation of this method is symmetric and mirrors the behavior of + `math.isclose`. Specifically note that this behavior is different to + `numpy.isclose`. + + Examples: + >>> import duckdb + >>> import pyarrow as pa + >>> import narwhals as nw + >>> + >>> data = { + ... "x": [1.0, float("inf"), 1.41, None, float("nan")], + ... "y": [1.2, float("inf"), 1.40, None, float("nan")], + ... } + >>> _table = pa.table(data) + >>> df_native = duckdb.table("_table") + >>> df = nw.from_native(df_native) + >>> df.with_columns( + ... is_close=nw.col("x").is_close( + ... nw.col("y"), abs_tol=0.1, nans_equal=True + ... ) + ... 
) + ┌──────────────────────────────┐ + | Narwhals LazyFrame | + |------------------------------| + |┌────────┬────────┬──────────┐| + |│ x │ y │ is_close │| + |│ double │ double │ boolean │| + |├────────┼────────┼──────────┤| + |│ 1.0 │ 1.2 │ false │| + |│ inf │ inf │ true │| + |│ 1.41 │ 1.4 │ true │| + |│ NULL │ NULL │ NULL │| + |│ nan │ nan │ true │| + |└────────┴────────┴──────────┘| + └──────────────────────────────┘ + """ + if abs_tol < 0: + msg = f"`abs_tol` must be non-negative but got {abs_tol}" + raise ComputeError(msg) + + if not (0 <= rel_tol < 1): + msg = f"`rel_tol` must be in the range [0, 1) but got {rel_tol}" + raise ComputeError(msg) + + kwargs = {"abs_tol": abs_tol, "rel_tol": rel_tol, "nans_equal": nans_equal} + return self.__class__( + lambda plx: apply_n_ary_operation( + plx, + lambda *exprs: exprs[0].is_close(exprs[1], **kwargs), + self, + other, + str_as_lit=False, + ), + combine_metadata( + self, + other, + str_as_lit=False, + allow_multi_output=False, + to_single_output=False, + ), + ) + @property def str(self) -> ExprStringNamespace[Self]: return ExprStringNamespace(self) diff --git a/narwhals/series.py b/narwhals/series.py index 34bcb2152a..cc57ca48c7 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -17,7 +17,7 @@ ) from narwhals.dependencies import is_numpy_array_1d, is_numpy_scalar from narwhals.dtypes import _validate_dtype, _validate_into_dtype -from narwhals.exceptions import ComputeError +from narwhals.exceptions import ComputeError, InvalidOperationError from narwhals.series_cat import SeriesCatNamespace from narwhals.series_dt import SeriesDateTimeNamespace from narwhals.series_list import SeriesListNamespace @@ -2763,6 +2763,82 @@ def sqrt(self) -> Self: """ return self._with_compliant(self._compliant_series.sqrt()) + def is_close( + self, + other: Self | NumericLiteral, + *, + abs_tol: float = 0.0, + rel_tol: float = 1e-09, + nans_equal: bool = False, + ) -> Self: + r"""Get a boolean mask of the values being close to 
the other values. + + Two values `a` and `b` are considered close if the following condition holds: + + $$ + |a-b| \le max \{ \text{rel\_tol} \cdot max \{ |a|, |b| \}, \text{abs\_tol} \} + $$ + + Arguments: + other: Values to compare with. + abs_tol: Absolute tolerance. This is the maximum allowed absolute difference + between two values. Must be non-negative. + rel_tol: Relative tolerance. This is the maximum allowed difference between + two values, relative to the larger absolute value. Must be in the range + [0, 1). + nans_equal: Whether NaN values should be considered equal. + + Returns: + Series of Boolean data type. + + Notes: + The implementation of this method is symmetric and mirrors the behavior of + `math.isclose`. Specifically note that this behavior is different to + `numpy.isclose`. + + Examples: + >>> import pyarrow as pa + >>> import narwhals as nw + >>> + >>> data = [1.0, float("inf"), 1.41, None, float("nan")] + >>> s_native = pa.chunked_array([data]) + >>> s = nw.from_native(s_native, series_only=True) + >>> s.is_close(1.4, abs_tol=0.1).to_native() # doctest:+ELLIPSIS + <pyarrow.lib.ChunkedArray object at ...> + [ + [ + false, + false, + true, + null, + false + ] + ] + """ + if not self.dtype.is_numeric(): + msg = ( + f"is_close operation not supported for dtype `{self.dtype}`\n\n" + "Hint: `is_close` is only supported for numeric types" + ) + raise InvalidOperationError(msg) + + if abs_tol < 0: + msg = f"`abs_tol` must be non-negative but got {abs_tol}" + raise ComputeError(msg) + + if not (0 <= rel_tol < 1): + msg = f"`rel_tol` must be in the range [0, 1) but got {rel_tol}" + raise ComputeError(msg) + + return self._with_compliant( + self._compliant_series.is_close( + self._extract_native(other), + abs_tol=abs_tol, + rel_tol=rel_tol, + nans_equal=nans_equal, + ) + ) + @property def str(self) -> SeriesStringNamespace[Self]: return SeriesStringNamespace(self) diff --git a/pyproject.toml b/pyproject.toml index 72d499e161..8ba95a77eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ 
-245,6 +245,7 @@ filterwarnings = [ "ignore:.*np.find_common_type is deprecated:DeprecationWarning:pandas", # Warning raised when calling PandasLikeNamespace.from_arrow with old pyarrow "ignore:.*is_sparse is deprecated and will be removed in a future version.*:DeprecationWarning:pyarrow", + 'ignore:.*invalid value encountered in cast:RuntimeWarning:pandas', ] xfail_strict = true markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] diff --git a/tests/expr_and_series/is_close_test.py b/tests/expr_and_series/is_close_test.py new file mode 100644 index 0000000000..e14b623d8b --- /dev/null +++ b/tests/expr_and_series/is_close_test.py @@ -0,0 +1,232 @@ +"""Tricks to generate nan's and null's for pandas with nullable backends. + +* Square rooting a negative number will generate a NaN +* Replacing a value with None once the dtype is nullable will generate <NA>'s +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +from narwhals.exceptions import ComputeError, InvalidOperationError +from tests.conftest import ( + dask_lazy_p1_constructor, + dask_lazy_p2_constructor, + modin_constructor, + pandas_constructor, +) +from tests.utils import Constructor, ConstructorEager, assert_equal_data + +if TYPE_CHECKING: + from narwhals.typing import NumericLiteral + +NON_NULLABLE_CONSTRUCTORS = ( + pandas_constructor, + dask_lazy_p1_constructor, + dask_lazy_p2_constructor, + modin_constructor, +) +NULL_PLACEHOLDER, NAN_PLACEHOLDER = 9999.0, -1.0 +INF_POS, INF_NEG = float("inf"), float("-inf") + +data = { + "x": [1.001, NULL_PLACEHOLDER, NAN_PLACEHOLDER, INF_POS, INF_NEG, INF_POS], + "y": [1.005, NULL_PLACEHOLDER, NAN_PLACEHOLDER, INF_POS, 3.0, INF_NEG], + "non_numeric": list("number"), + "idx": list(range(6)), +} + + +# Exceptions +def test_is_close_series_raise_non_numeric(constructor_eager: ConstructorEager) -> None: + df = nw.from_native(constructor_eager(data), eager_only=True) + x, y = 
df["non_numeric"], df["y"] + + msg = "is_close operation not supported for dtype" + with pytest.raises(InvalidOperationError, match=msg): + x.is_close(y) + + +@pytest.mark.parametrize("rel_tol", [1e-09, 999]) +def test_is_close_raise_negative_abs_tol( + constructor_eager: ConstructorEager, rel_tol: float +) -> None: + df = nw.from_native(constructor_eager(data), eager_only=True) + x, y = df["x"], df["y"] + + abs_tol = -2 + msg = rf"`abs_tol` must be non-negative but got {abs_tol}" + with pytest.raises(ComputeError, match=msg): + x.is_close(y, abs_tol=abs_tol, rel_tol=rel_tol) + + with pytest.raises(ComputeError, match=msg): + df.select(nw.col("x").is_close(nw.col("y"), abs_tol=abs_tol, rel_tol=rel_tol)) + + +@pytest.mark.parametrize("rel_tol", [-0.0001, 1.0, 1.1]) +def test_is_close_raise_invalid_rel_tol( + constructor_eager: ConstructorEager, rel_tol: float +) -> None: + df = nw.from_native(constructor_eager(data), eager_only=True) + x, y = df["x"], df["y"] + + msg = rf"`rel_tol` must be in the range \[0, 1\) but got {rel_tol}" + with pytest.raises(ComputeError, match=msg): + x.is_close(y, rel_tol=rel_tol) + + with pytest.raises(ComputeError, match=msg): + df.select(nw.col("x").is_close(nw.col("y"), rel_tol=rel_tol)) + + +cases_columnar = pytest.mark.parametrize( + ("abs_tol", "rel_tol", "nans_equal", "expected"), + [ + (0.1, 0.0, False, [True, None, False, True, False, False]), + (0.0001, 0.0, True, [False, None, True, True, False, False]), + (0.0, 0.1, False, [True, None, False, True, False, False]), + (0.0, 0.001, True, [False, None, True, True, False, False]), + ], +) +cases_scalar = pytest.mark.parametrize( + ("other", "abs_tol", "rel_tol", "nans_equal", "expected"), + [ + (1.0, 0.1, 0.0, False, [True, None, False, False, False, False]), + (1.0, 0.0001, 0.0, True, [False, None, False, False, False, False]), + (2.9, 0.0, 0.1, False, [False, None, False, False, True, False]), + (2.9, 0.0, 0.001, True, [False, None, False, False, False, False]), + ], +) + + +# 
Series +@cases_columnar +def test_is_close_series_with_series( + constructor_eager: ConstructorEager, + abs_tol: float, + rel_tol: float, + *, + nans_equal: bool, + expected: list[float], +) -> None: + df = nw.from_native(constructor_eager(data), eager_only=True) + x, y = df["x"], df["y"] + + nulls = nw.new_series( + name="nulls", + values=[None] * len(x), + dtype=nw.Float64(), + backend=df.implementation, + ) + x = x.zip_with(x != NAN_PLACEHOLDER, x**0.5).zip_with(x != NULL_PLACEHOLDER, nulls) + y = y.zip_with(y != NAN_PLACEHOLDER, y**0.5).zip_with(y != NULL_PLACEHOLDER, nulls) + result = x.is_close(y, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal) + + if constructor_eager in NON_NULLABLE_CONSTRUCTORS: + expected = [v if v is not None else nans_equal for v in expected] + assert_equal_data({"result": result}, {"result": expected}) + + +@cases_scalar +def test_is_close_series_with_scalar( + constructor_eager: ConstructorEager, + other: NumericLiteral, + abs_tol: float, + rel_tol: float, + *, + nans_equal: bool, + expected: list[float], +) -> None: + df = nw.from_native(constructor_eager(data), eager_only=True) + y = df["y"] + + nulls = nw.new_series( + name="nulls", + values=[None] * len(y), + dtype=nw.Float64(), + backend=df.implementation, + ) + y = y.zip_with(y != NAN_PLACEHOLDER, y**0.5).zip_with(y != NULL_PLACEHOLDER, nulls) + result = y.is_close(other, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal) + + if constructor_eager in NON_NULLABLE_CONSTRUCTORS: + expected = [v if v is not None else False for v in expected] + assert_equal_data({"result": result}, {"result": expected}) + + +# Expr +@cases_columnar +def test_is_close_expr_with_expr( + request: pytest.FixtureRequest, + constructor: Constructor, + abs_tol: float, + rel_tol: float, + *, + nans_equal: bool, + expected: list[float], +) -> None: + if "sqlframe" in str(constructor): + # TODO(FBruzzesi): Figure out a MRE and report upstream + reason = ( + "duckdb.duckdb.ParserException: Parser 
Error: syntax error at or near '='" + ) + request.applymarker(pytest.mark.xfail(reason=reason)) + + x, y = nw.col("x"), nw.col("y") + result = ( + nw.from_native(constructor(data)) + .with_columns( + x=nw.when(x != NAN_PLACEHOLDER).then(x).otherwise(x**0.5), + y=nw.when(y != NAN_PLACEHOLDER).then(y).otherwise(y**0.5), + ) + .with_columns( + x=nw.when(x != NULL_PLACEHOLDER).then(x), + y=nw.when(y != NULL_PLACEHOLDER).then(y), + ) + .select( + "idx", + result=x.is_close(y, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal), + ) + .sort("idx") + ) + if constructor in NON_NULLABLE_CONSTRUCTORS: + expected = [v if v is not None else nans_equal for v in expected] + assert_equal_data(result, {"idx": data["idx"], "result": expected}) + + +@cases_scalar +def test_is_close_expr_with_scalar( + request: pytest.FixtureRequest, + constructor: Constructor, + other: NumericLiteral, + abs_tol: float, + rel_tol: float, + *, + nans_equal: bool, + expected: list[float], +) -> None: + if "sqlframe" in str(constructor): + # TODO(FBruzzesi): Figure out a MRE and report upstream + reason = ( + "duckdb.duckdb.ParserException: Parser Error: syntax error at or near '='" + ) + request.applymarker(pytest.mark.xfail(reason=reason)) + + y = nw.col("y") + result = ( + nw.from_native(constructor(data)) + .with_columns(y=nw.when(y != NAN_PLACEHOLDER).then(y).otherwise(y**0.5)) + .with_columns(y=nw.when(y != NULL_PLACEHOLDER).then(y)) + .select( + "idx", + result=y.is_close( + other, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal + ), + ) + .sort("idx") + ) + if constructor in NON_NULLABLE_CONSTRUCTORS: + expected = [v if v is not None else False for v in expected] + assert_equal_data(result, {"idx": data["idx"], "result": expected})