Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/expr.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
- filter
- clip
- is_between
- is_close
- is_duplicated
- is_finite
- is_first_distinct
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
- hist
- implementation
- is_between
- is_close
- is_duplicated
- is_empty
- is_finite
Expand Down
30 changes: 29 additions & 1 deletion narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
LazyExprT,
NativeExprT,
)
from narwhals._utils import _StoresCompliant
from narwhals._utils import _is_close_impl, _StoresCompliant
from narwhals.dependencies import get_numpy, is_numpy_array

if TYPE_CHECKING:
Expand Down Expand Up @@ -231,6 +231,18 @@ def _evaluate_aliases(
names = self._evaluate_output_names(frame)
return alias(names) if (alias := self._alias_output_names) else names

def is_close(
self,
other: Self | NumericLiteral,
*,
abs_tol: float,
rel_tol: float,
nans_equal: bool,
) -> Self:
return _is_close_impl(
self, other, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
)

@property
def str(self) -> StringNamespace[Self]: ...
@property
Expand Down Expand Up @@ -892,6 +904,22 @@ def exp(self) -> Self:
def sqrt(self) -> Self:
return self._reuse_series("sqrt")

def is_close(
self,
other: Self | NumericLiteral,
*,
abs_tol: float,
rel_tol: float,
nans_equal: bool,
) -> Self:
return self._reuse_series(
"is_close",
other=other,
abs_tol=abs_tol,
rel_tol=rel_tol,
nans_equal=nans_equal,
)
Comment on lines +906 to +920
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might not be needed πŸ€”


@property
def cat(self) -> EagerExprCatNamespace[Self]:
return EagerExprCatNamespace(self)
Expand Down
13 changes: 13 additions & 0 deletions narwhals/_compliant/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from narwhals._translate import FromIterable, FromNative, NumpyConvertible, ToNarwhals
from narwhals._typing_compat import TypeVar, assert_never
from narwhals._utils import (
_is_close_impl,
_StoresCompliant,
_StoresNative,
is_compliant_series,
Expand Down Expand Up @@ -284,6 +285,18 @@ def hist_from_bin_count(
"""`Series.hist(bins=None, bin_count=...)`."""
...

def is_close(
self,
other: Self | NumericLiteral,
*,
abs_tol: float,
rel_tol: float,
nans_equal: bool,
) -> Self:
return _is_close_impl(
self, other, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
)

@property
def str(self) -> StringNamespace[Self]: ...
@property
Expand Down
4 changes: 4 additions & 0 deletions narwhals/_compliant/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class ScalarKwargs(TypedDict, total=False):
CompliantExprAny: TypeAlias = "CompliantExpr[Any, Any]"
CompliantSeriesAny: TypeAlias = "CompliantSeries[Any]"
CompliantSeriesOrNativeExprAny: TypeAlias = "CompliantSeriesAny | NativeExpr"
CompliantSeriesOrExprAny: TypeAlias = "CompliantSeriesAny | CompliantExprAny"
CompliantDataFrameAny: TypeAlias = "CompliantDataFrame[Any, Any, Any, Any]"
CompliantLazyFrameAny: TypeAlias = "CompliantLazyFrame[Any, Any, Any]"
CompliantFrameAny: TypeAlias = "CompliantDataFrameAny | CompliantLazyFrameAny"
Expand Down Expand Up @@ -113,6 +114,9 @@ class ScalarKwargs(TypedDict, total=False):
bound=CompliantSeriesOrNativeExprAny,
covariant=True,
)
CompliantSeriesOrExprT = TypeVar(
"CompliantSeriesOrExprT", bound="CompliantSeriesOrExprAny"
)
CompliantFrameT = TypeVar("CompliantFrameT", bound=CompliantFrameAny)
CompliantFrameT_co = TypeVar(
"CompliantFrameT_co", bound=CompliantFrameAny, covariant=True
Expand Down
47 changes: 46 additions & 1 deletion narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from narwhals._polars.dataframe import Method, PolarsDataFrame
from narwhals._polars.namespace import PolarsNamespace
from narwhals._utils import Version, _LimitedContext
from narwhals.typing import IntoDType
from narwhals.typing import IntoDType, NumericLiteral


class PolarsExpr:
Expand Down Expand Up @@ -232,6 +232,51 @@ def __narwhals_namespace__(self) -> PolarsNamespace: # pragma: no cover

return PolarsNamespace(version=self._version)

def is_close(
self,
other: Self | NumericLiteral,
*,
abs_tol: float,
rel_tol: float,
nans_equal: bool,
) -> Self:
native_expr = self.native
other_expr = (
extract_native(other) if isinstance(other, PolarsExpr) else pl.lit(other)
)

if self._backend_version < (1, 32, 0):
abs_diff = (native_expr - other_expr).abs()
rel_threshold = native_expr.abs().clip(lower_bound=other_expr.abs()) * rel_tol
tolerance = rel_threshold.clip(lower_bound=pl.lit(abs_tol))

self_is_inf, other_is_inf = (
native_expr.is_infinite(),
other_expr.is_infinite(),
)

# Values are close if abs_diff <= tolerance, and both finite
is_close = (abs_diff <= tolerance) & self_is_inf.not_() & other_is_inf.not_()

# Handle infinity cases: infinities are "close" only if they have the same sign
self_sign, other_sign = native_expr.sign(), other_expr.sign()
is_same_inf = self_is_inf & other_is_inf & (self_sign == other_sign)

# Handle nan cases:
# * nans_equals = True => if both values are NaN, then True
# * nans_equals = False => if any value is NaN, then False
either_nan = native_expr.is_nan() | other_expr.is_nan()
result = (is_close | is_same_inf) & either_nan.not_()

if nans_equal:
both_nan = native_expr.is_nan() & other_expr.is_nan()
result = result | both_nan
else:
result = native_expr.is_close(
other=other_expr, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
)
return self._with_native(result)

@property
def dt(self) -> PolarsExprDateTimeNamespace:
return PolarsExprDateTimeNamespace(self)
Expand Down
52 changes: 47 additions & 5 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@
from narwhals._utils import Version, _LimitedContext
from narwhals.dtypes import DType
from narwhals.series import Series
from narwhals.typing import Into1DArray, IntoDType, MultiIndexSelector, _1DArray
from narwhals.typing import (
Into1DArray,
IntoDType,
MultiIndexSelector,
NumericLiteral,
_1DArray,
)

T = TypeVar("T")
IncludeBreakpoint: TypeAlias = Literal[False, True]
Expand Down Expand Up @@ -84,6 +90,7 @@
"gather_every",
"head",
"is_between",
"is_close",
"is_finite",
"is_first_distinct",
"is_in",
Expand Down Expand Up @@ -125,6 +132,11 @@
class PolarsSeries:
_implementation = Implementation.POLARS

_HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = {
True: ["breakpoint", "count"],
False: ["count"],
}

def __init__(self, series: pl.Series, *, version: Version) -> None:
self._native_series: pl.Series = series
self._version = version
Expand Down Expand Up @@ -472,10 +484,40 @@ def __contains__(self, other: Any) -> bool:
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e) from None

_HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = {
True: ["breakpoint", "count"],
False: ["count"],
}
def is_close(
self,
other: Self | NumericLiteral,
*,
abs_tol: float,
rel_tol: float,
nans_equal: bool,
) -> Self:
other_native = extract_native(other)

if self._backend_version < (1, 32, 0):
name = self.name
ns = self.__narwhals_namespace__()
result = (
self.to_frame()
.select(
ns.col(name).is_close(
other=other_native, # type: ignore[arg-type]
abs_tol=abs_tol,
rel_tol=rel_tol,
nans_equal=nans_equal,
)
)
.get_column(name)
.native
)
else:
result = self.native.is_close(
other=other_native, # pyright: ignore[reportArgumentType]
abs_tol=abs_tol,
rel_tol=rel_tol,
nans_equal=nans_equal,
)
return self._with_native(result)

def hist_from_bins(
self, bins: list[float], *, include_breakpoint: bool
Expand Down
62 changes: 61 additions & 1 deletion narwhals/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
NativeFrameT_co,
NativeSeriesT_co,
)
from narwhals._compliant.typing import EvalNames
from narwhals._compliant.typing import CompliantSeriesOrExprT, EvalNames
from narwhals._namespace import EagerAllowedImplementation, Namespace
from narwhals._translate import ArrowStreamExportable, IntoArrowTable, ToNarwhalsT_co
from narwhals.dataframe import DataFrame, LazyFrame
Expand All @@ -86,6 +86,7 @@
DTypes,
IntoSeriesT,
MultiIndexSelector,
NumericLiteral,
SingleIndexSelector,
SizedMultiIndexSelector,
SizeUnit,
Expand Down Expand Up @@ -2035,3 +2036,62 @@ def deep_attrgetter(attr: str, *nested: str) -> attrgetter[Any]:
def deep_getattr(obj: Any, name_1: str, *nested: str) -> Any:
"""Perform a nested attribute lookup on `obj`."""
return deep_attrgetter(name_1, *nested)(obj)


def _is_close_impl(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle, we could use this implementation also for polars pre 1.32.0, however polars has a couple of native methods which play a bit nicer (e.g. .sign, .is_infinite, .not_). If we were to introduce those, then we can have a single implementation function

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not_ and __invert__ are identical btw

self: CompliantSeriesOrExprT,
other: CompliantSeriesOrExprT | NumericLiteral,
*,
abs_tol: float,
rel_tol: float,
nans_equal: bool,
) -> CompliantSeriesOrExprT:
from decimal import Decimal

other_abs: CompliantSeriesOrExprT | NumericLiteral
other_is_nan: CompliantSeriesOrExprT | bool
other_is_inf: CompliantSeriesOrExprT | bool
other_is_not_inf: CompliantSeriesOrExprT | bool

if isinstance(other, (float, int, Decimal)):
from math import isinf, isnan

other_abs, other_is_nan, other_is_inf = abs(other), isnan(other), isinf(other) # type: ignore[assignment]

# Define the other_is_not_inf variable to prevent triggering the following warning:
# > DeprecationWarning: Bitwise inversion '~' on bool is deprecated and will be
# > removed in Python 3.16. This returns the bitwise inversion of the
# > underlying int object and is usually not what you expect from negating
# > a bool. Use the 'not' operator for boolean negation or ~int(x) if you
# > really want the bitwise inversion of the underlying int.
other_is_not_inf = not other_is_inf

else:
other_abs, other_is_nan = other.abs(), other.is_nan()
other_is_not_inf = other.is_finite() | other_is_nan
other_is_inf = ~other_is_not_inf

rel_threshold = self.abs().clip(lower_bound=other_abs, upper_bound=None) * rel_tol # type: ignore[arg-type]
tolerance = rel_threshold.clip(lower_bound=abs_tol, upper_bound=None)

self_is_nan = self.is_nan()
self_is_not_inf = self.is_finite() | self_is_nan

# Values are close if abs_diff <= tolerance, and both finite
is_close = ((self - other).abs() <= tolerance) & self_is_not_inf & other_is_not_inf

# Handle infinity cases: infinities are "close" only if they have the same sign
self_sign, other_sign = self > 0, other > 0
is_same_inf = (~self_is_not_inf) & other_is_inf & (self_sign == other_sign)

# Handle nan cases:
# * nans_equals = True => if both values are NaN, then True
# * nans_equals = False => if any value is NaN, then False
either_nan = self_is_nan | other_is_nan
result = (is_close | is_same_inf) & ~either_nan

if nans_equal:
both_nan = self_is_nan & other_is_nan
result = result | both_nan

return result
Loading
Loading