Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/expr.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
- filter
- clip
- is_between
- is_close
- is_duplicated
- is_finite
- is_first_distinct
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
- hist
- implementation
- is_between
- is_close
- is_duplicated
- is_empty
- is_finite
Expand Down
19 changes: 18 additions & 1 deletion narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
LazyExprT,
NativeExprT,
)
from narwhals._compliant.utils import IsClose
from narwhals._utils import _StoresCompliant
from narwhals.dependencies import get_numpy, is_numpy_array

Expand Down Expand Up @@ -75,7 +76,7 @@ def __eq__(self, value: Any, /) -> Self: ... # type: ignore[override]
def __ne__(self, value: Any, /) -> Self: ... # type: ignore[override]


class CompliantExpr(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]):
class CompliantExpr(IsClose, Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]):
_implementation: Implementation
_version: Version
_evaluate_output_names: EvalNames[CompliantFrameT]
Expand Down Expand Up @@ -902,6 +903,22 @@ def exp(self) -> Self:
def sqrt(self) -> Self:
    """Delegate to the series-level `sqrt` method through `_reuse_series`."""
    method_name = "sqrt"
    return self._reuse_series(method_name)

def is_close(
    self,
    other: Self | NumericLiteral,
    *,
    abs_tol: float,
    rel_tol: float,
    nans_equal: bool,
) -> Self:
    """Delegate to the series-level `is_close` method through `_reuse_series`.

    Arguments:
        other: Expression or numeric literal to compare against.
        abs_tol: Absolute tolerance, forwarded unchanged.
        rel_tol: Relative tolerance, forwarded unchanged.
        nans_equal: Whether NaN values compare as equal, forwarded unchanged.

    Returns:
        Whatever `_reuse_series` returns for the `"is_close"` call.
    """
    # Every argument is forwarded by keyword, exactly as received.
    forwarded = {
        "other": other,
        "abs_tol": abs_tol,
        "rel_tol": rel_tol,
        "nans_equal": nans_equal,
    }
    return self._reuse_series("is_close", **forwarded)
Comment on lines +906 to +920
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might not be needed 🤔


@property
def cat(self) -> EagerExprCatNamespace[Self]:
    """Return an `EagerExprCatNamespace` wrapping this expression."""
    namespace = EagerExprCatNamespace(self)
    return namespace
Expand Down
2 changes: 2 additions & 0 deletions narwhals/_compliant/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
NativeSeriesT,
NativeSeriesT_co,
)
from narwhals._compliant.utils import IsClose
from narwhals._translate import FromIterable, FromNative, NumpyConvertible, ToNarwhals
from narwhals._typing_compat import TypeVar, assert_never
from narwhals._utils import (
Expand Down Expand Up @@ -78,6 +79,7 @@ class HistData(TypedDict, Generic[NativeSeriesT, "_CountsT_co"]):


class CompliantSeries(
IsClose,
NumpyConvertible["_1DArray", "Into1DArray"],
FromIterable,
FromNative[NativeSeriesT],
Expand Down
88 changes: 88 additions & 0 deletions narwhals/_compliant/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol

if TYPE_CHECKING:
from typing_extensions import Self

from narwhals.typing import NumericLiteral, TemporalLiteral


class IsClose(Protocol):
    """Protocol bundling the operations that the default `is_close` relies on.

    All members other than `is_close` are declared as stubs; `is_close` itself
    is implemented here purely in terms of those stubs.
    """

    # Boolean combinators.
    def __and__(self, other: Any) -> Self: ...
    def __or__(self, other: Any) -> Self: ...
    def __invert__(self) -> Self: ...

    # Arithmetic.
    def __sub__(self, other: Any) -> Self: ...
    def __mul__(self, other: Any) -> Self: ...

    # Comparisons.
    def __eq__(self, other: Self | Any) -> Self: ...  # type: ignore[override]
    def __gt__(self, other: Any) -> Self: ...
    def __le__(self, other: Any) -> Self: ...

    # Element-wise helpers.
    def abs(self) -> Self: ...
    def is_nan(self) -> Self: ...
    def is_finite(self) -> Self: ...
    def clip(
        self,
        lower_bound: Self | NumericLiteral | TemporalLiteral | None,
        upper_bound: Self | NumericLiteral | TemporalLiteral | None,
    ) -> Self: ...

    def is_close(
        self,
        other: Self | NumericLiteral,
        *,
        abs_tol: float,
        rel_tol: float,
        nans_equal: bool,
    ) -> Self:
        """Check element-wise whether `self` is approximately equal to `other`.

        Finite values are close when
        `|self - other| <= max(rel_tol * max(|self|, |other|), abs_tol)`.
        Non-finite, non-NaN values are close only when both sides are
        non-finite and agree on `> 0`. Any NaN yields False, unless
        `nans_equal` is set and both sides are NaN.

        Arguments:
            other: Object of the same kind as `self`, or a `float`, `int` or
                `Decimal` scalar.
            abs_tol: Absolute tolerance.
            rel_tol: Relative tolerance.
            nans_equal: Whether two NaN values are considered close.

        Returns:
            The boolean result, built with the operators declared above.
        """
        from decimal import Decimal
        from math import isinf, isnan

        rhs_abs: Self | NumericLiteral
        rhs_nan: Self | bool
        rhs_inf: Self | bool
        rhs_not_inf: Self | bool

        if not isinstance(other, (float, int, Decimal)):
            # Non-scalar operand: derive the masks through its own methods.
            rhs_abs = other.abs()
            rhs_nan = other.is_nan()
            rhs_not_inf = other.is_finite() | rhs_nan
            rhs_inf = ~rhs_not_inf
        else:
            # NOTE: See https://discuss.python.org/t/inferred-type-of-function-that-calls-dunder-abs-abs/101447
            rhs_abs = other.__abs__()
            rhs_nan, rhs_inf = isnan(other), isinf(other)
            # Use `not` instead of `~`: bitwise inversion on bool is deprecated
            # (DeprecationWarning, scheduled for removal in Python 3.16).
            rhs_not_inf = not rhs_inf

        # tolerance = max(rel_tol * max(|self|, |other|), abs_tol)
        larger_magnitude = self.abs().clip(lower_bound=rhs_abs, upper_bound=None)
        tolerance = (larger_magnitude * rel_tol).clip(
            lower_bound=abs_tol, upper_bound=None
        )

        lhs_nan = self.is_nan()
        lhs_not_inf = self.is_finite() | lhs_nan

        # Regular case: the difference is within tolerance and neither side
        # is infinite.
        distance = (self - other).abs()
        within_tolerance = (distance <= tolerance) & lhs_not_inf & rhs_not_inf

        # Infinite case: both sides infinite and agreeing on positivity.
        lhs_positive = self > 0
        rhs_positive = other > 0
        matching_inf = (~lhs_not_inf) & rhs_inf & (lhs_positive == rhs_positive)

        # NaN case: a NaN on either side forces False ...
        any_nan = lhs_nan | rhs_nan
        outcome = (within_tolerance | matching_inf) & ~any_nan

        # ... unless NaNs are requested to be equal and both sides are NaN.
        if nans_equal:
            outcome = outcome | (lhs_nan & rhs_nan)

        return outcome
42 changes: 41 additions & 1 deletion narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from narwhals._polars.dataframe import Method, PolarsDataFrame
from narwhals._polars.namespace import PolarsNamespace
from narwhals._utils import Version, _LimitedContext
from narwhals.typing import IntoDType
from narwhals.typing import IntoDType, NumericLiteral


class PolarsExpr:
Expand Down Expand Up @@ -232,6 +232,46 @@ def __narwhals_namespace__(self) -> PolarsNamespace: # pragma: no cover

return PolarsNamespace(version=self._version)

def is_close(
    self,
    other: Self | NumericLiteral,
    *,
    abs_tol: float,
    rel_tol: float,
    nans_equal: bool,
) -> Self:
    """Check element-wise whether this expression is close to `other`.

    When `self._backend_version` is at least `(1, 32, 0)` the native
    expression's own `is_close` is used; otherwise the same check is composed
    from more basic native expression operations.

    Arguments:
        other: A `PolarsExpr`, or a numeric literal that is wrapped in `pl.lit`.
        abs_tol: Absolute tolerance.
        rel_tol: Relative tolerance.
        nans_equal: Whether two NaN values are considered close.

    Returns:
        A new expression built with `_with_native` from the native result.
    """
    lhs = self.native
    rhs = pl.lit(other) if not isinstance(other, PolarsExpr) else other.native

    if self._backend_version >= (1, 32, 0):
        native_result = lhs.is_close(
            rhs, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
        )
        return self._with_native(native_result)

    # Fallback: tolerance = max(rel_tol * max(|lhs|, |rhs|), abs_tol)
    larger_magnitude = lhs.abs().clip(rhs.abs())
    tolerance = (larger_magnitude * rel_tol).clip(abs_tol)

    # Regular case: difference within tolerance and both sides finite.
    distance = (lhs - rhs).abs()
    within_tolerance = pl.all_horizontal(
        (distance <= tolerance), lhs.is_finite(), rhs.is_finite()
    )

    # Infinite case: infinities are "close" only if they share the same sign.
    matching_inf = pl.all_horizontal(
        lhs.is_infinite(), rhs.is_infinite(), (lhs.sign() == rhs.sign())
    )

    # NaN case:
    # * nans_equal = False => a NaN on either side gives False
    # * nans_equal = True  => additionally, NaN on both sides gives True
    lhs_nan = lhs.is_nan()
    rhs_nan = rhs.is_nan()
    any_nan = lhs_nan | rhs_nan
    outcome = (within_tolerance | matching_inf) & any_nan.not_()

    if nans_equal:
        outcome = outcome | (lhs_nan & rhs_nan)
    return self._with_native(outcome)

@property
def dt(self) -> PolarsExprDateTimeNamespace:
    """Return a `PolarsExprDateTimeNamespace` wrapping this expression."""
    namespace = PolarsExprDateTimeNamespace(self)
    return namespace
Expand Down
32 changes: 28 additions & 4 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
IntoDType,
MultiIndexSelector,
NonNestedLiteral,
NumericLiteral,
_1DArray,
)

Expand Down Expand Up @@ -90,6 +91,7 @@
"gather_every",
"head",
"is_between",
"is_close",
"is_finite",
"is_first_distinct",
"is_in",
Expand Down Expand Up @@ -131,6 +133,11 @@
class PolarsSeries:
_implementation = Implementation.POLARS

_HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = {
True: ["breakpoint", "count"],
False: ["count"],
}

def __init__(self, series: pl.Series, *, version: Version) -> None:
self._native_series: pl.Series = series
self._version = version
Expand Down Expand Up @@ -478,10 +485,27 @@ def __contains__(self, other: Any) -> bool:
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e) from None

_HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = {
True: ["breakpoint", "count"],
False: ["count"],
}
def is_close(
    self,
    other: Self | NumericLiteral,
    *,
    abs_tol: float,
    rel_tol: float,
    nans_equal: bool,
) -> PolarsSeries:
    """Check element-wise whether this series is close to `other`.

    When `self._backend_version` is at least `(1, 32, 0)` the native series'
    own `is_close` is called directly. Otherwise the check is routed through
    the expression-level `is_close`, evaluated on a one-column frame made
    from this series.

    Arguments:
        other: A `PolarsSeries` or a numeric literal.
        abs_tol: Absolute tolerance.
        rel_tol: Relative tolerance.
        nans_equal: Whether two NaN values are considered close.

    Returns:
        The result of the comparison as a series.
    """
    if self._backend_version >= (1, 32, 0):
        rhs_native = other.native if isinstance(other, PolarsSeries) else other
        native_result = self.native.is_close(
            rhs_native, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
        )
        return self._with_native(native_result)

    # Older backend: select the expression-level result on a frame built
    # from this series, then pull the column back out by name.
    column_name = self.name
    namespace = self.__narwhals_namespace__()
    rhs_expr = other._to_expr() if isinstance(other, PolarsSeries) else other
    close_expr = namespace.col(column_name).is_close(
        rhs_expr, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
    )
    return self.to_frame().select(close_expr).get_column(column_name)

def hist_from_bins(
self, bins: list[float], *, include_breakpoint: bool
Expand Down
93 changes: 92 additions & 1 deletion narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from narwhals._utils import _validate_rolling_arguments, ensure_type, flatten
from narwhals.dtypes import _validate_dtype
from narwhals.exceptions import InvalidOperationError
from narwhals.exceptions import ComputeError, InvalidOperationError
from narwhals.expr_cat import ExprCatNamespace
from narwhals.expr_dt import ExprDateTimeNamespace
from narwhals.expr_list import ExprListNamespace
Expand Down Expand Up @@ -2347,6 +2347,97 @@ def sqrt(self) -> Self:
"""
return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).sqrt())

def is_close(
self,
other: Self | NumericLiteral,
*,
abs_tol: float = 0.0,
rel_tol: float = 1e-09,
nans_equal: bool = False,
) -> Self:
r"""Check if this expression is close, i.e. almost equal, to the other expression.

Two values `a` and `b` are considered close if the following condition holds:

$$
|a-b| \le max \{ \text{rel\_tol} \cdot max \{ |a|, |b| \}, \text{abs\_tol} \}
$$

Arguments:
other: Values to compare with.
abs_tol: Absolute tolerance. This is the maximum allowed absolute difference
between two values. Must be non-negative.
rel_tol: Relative tolerance. This is the maximum allowed difference between
two values, relative to the larger absolute value. Must be in the range
[0, 1).
nans_equal: Whether NaN values should be considered equal.

Returns:
Expression of Boolean data type.

Notes:
The implementation of this method is symmetric and mirrors the behavior of
`math.isclose`. Specifically note that this behavior is different to
`numpy.isclose`.

Examples:
>>> import duckdb
>>> import pyarrow as pa
>>> import narwhals as nw
>>>
>>> data = {
... "x": [1.0, float("inf"), 1.41, None, float("nan")],
... "y": [1.2, float("inf"), 1.40, None, float("nan")],
... }
>>> _table = pa.table(data)
>>> df_native = duckdb.table("_table")
>>> df = nw.from_native(df_native)
>>> df.with_columns(
... is_close=nw.col("x").is_close(
... nw.col("y"), abs_tol=0.1, nans_equal=True
... )
... )
┌──────────────────────────────┐
|      Narwhals LazyFrame      |
|------------------------------|
|┌────────┬────────┬──────────┐|
|│   x    │   y    │ is_close │|
|│ double │ double │ boolean  │|
|├────────┼────────┼──────────┤|
|│    1.0 │    1.2 │ false    │|
|│    inf │    inf │ true     │|
|│   1.41 │    1.4 │ true     │|
|│   NULL │   NULL │ NULL     │|
|│    nan │    nan │ true     │|
|└────────┴────────┴──────────┘|
└──────────────────────────────┘
"""
if abs_tol < 0:
msg = f"`abs_tol` must be non-negative but got {abs_tol}"
raise ComputeError(msg)

if not (0 <= rel_tol < 1):
msg = f"`rel_tol` must be in the range [0, 1) but got {rel_tol}"
raise ComputeError(msg)
Comment on lines +2414 to +2421
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For lazy backends, we are not raising for non-numeric dtypes


kwargs = {"abs_tol": abs_tol, "rel_tol": rel_tol, "nans_equal": nans_equal}
return self.__class__(
lambda plx: apply_n_ary_operation(
plx,
lambda *exprs: exprs[0].is_close(exprs[1], **kwargs),
self,
other,
str_as_lit=False,
),
combine_metadata(
self,
other,
str_as_lit=False,
allow_multi_output=False,
to_single_output=False,
),
)

@property
def str(self) -> ExprStringNamespace[Self]:
    """Return an `ExprStringNamespace` wrapping this expression."""
    namespace = ExprStringNamespace(self)
    return namespace
Expand Down
Loading
Loading