Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/expr.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
- filter
- clip
- is_between
- is_close
- is_duplicated
- is_finite
- is_first_distinct
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
- hist
- implementation
- is_between
- is_close
- is_duplicated
- is_empty
- is_finite
Expand Down
19 changes: 18 additions & 1 deletion narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
LazyExprT,
NativeExprT,
)
from narwhals._compliant.utils import IsClose
from narwhals._utils import _StoresCompliant
from narwhals.dependencies import get_numpy, is_numpy_array

Expand Down Expand Up @@ -75,7 +76,7 @@ def __eq__(self, value: Any, /) -> Self: ... # type: ignore[override]
def __ne__(self, value: Any, /) -> Self: ... # type: ignore[override]


class CompliantExpr(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]):
class CompliantExpr(IsClose, Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]):
_implementation: Implementation
_version: Version
_evaluate_output_names: EvalNames[CompliantFrameT]
Expand Down Expand Up @@ -902,6 +903,22 @@ def exp(self) -> Self:
def sqrt(self) -> Self:
return self._reuse_series("sqrt")

def is_close(
self,
other: Self | NumericLiteral,
*,
abs_tol: float,
rel_tol: float,
nans_equal: bool,
) -> Self:
return self._reuse_series(
"is_close",
other=other,
abs_tol=abs_tol,
rel_tol=rel_tol,
nans_equal=nans_equal,
)
Comment on lines +906 to +920
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might not be needed πŸ€”


@property
def cat(self) -> EagerExprCatNamespace[Self]:
return EagerExprCatNamespace(self)
Expand Down
2 changes: 2 additions & 0 deletions narwhals/_compliant/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
NativeSeriesT,
NativeSeriesT_co,
)
from narwhals._compliant.utils import IsClose
from narwhals._translate import FromIterable, FromNative, NumpyConvertible, ToNarwhals
from narwhals._typing_compat import TypeVar, assert_never
from narwhals._utils import (
Expand Down Expand Up @@ -78,6 +79,7 @@ class HistData(TypedDict, Generic[NativeSeriesT, "_CountsT_co"]):


class CompliantSeries(
IsClose,
NumpyConvertible["_1DArray", "Into1DArray"],
FromIterable,
FromNative[NativeSeriesT],
Expand Down
1 change: 1 addition & 0 deletions narwhals/_compliant/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class ScalarKwargs(TypedDict, total=False):
CompliantExprAny: TypeAlias = "CompliantExpr[Any, Any]"
CompliantSeriesAny: TypeAlias = "CompliantSeries[Any]"
CompliantSeriesOrNativeExprAny: TypeAlias = "CompliantSeriesAny | NativeExpr"
CompliantSeriesOrExprAny: TypeAlias = "CompliantSeriesAny | CompliantExprAny"
CompliantDataFrameAny: TypeAlias = "CompliantDataFrame[Any, Any, Any, Any]"
CompliantLazyFrameAny: TypeAlias = "CompliantLazyFrame[Any, Any, Any]"
CompliantFrameAny: TypeAlias = "CompliantDataFrameAny | CompliantLazyFrameAny"
Expand Down
88 changes: 88 additions & 0 deletions narwhals/_compliant/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol

if TYPE_CHECKING:
from typing_extensions import Self

from narwhals.typing import NumericLiteral, TemporalLiteral


class IsClose(Protocol):
"""Every member defined is a dependency of `is_close` method."""

def __and__(self, other: Any) -> Self: ...
def __or__(self, other: Any) -> Self: ...
def __invert__(self) -> Self: ...
def __sub__(self, other: Any) -> Self: ...
def __mul__(self, other: Any) -> Self: ...
def __eq__(self, other: Self | Any) -> Self: ... # type: ignore[override]
def __gt__(self, other: Any) -> Self: ...
def __le__(self, other: Any) -> Self: ...
def abs(self) -> Self: ...
def is_nan(self) -> Self: ...
def is_finite(self) -> Self: ...
def clip(
self,
lower_bound: Self | NumericLiteral | TemporalLiteral | None,
upper_bound: Self | NumericLiteral | TemporalLiteral | None,
) -> Self: ...
def is_close(
self,
other: Self | NumericLiteral,
*,
abs_tol: float,
rel_tol: float,
nans_equal: bool,
) -> Self:
from decimal import Decimal

other_abs: Self | NumericLiteral
other_is_nan: Self | bool
other_is_inf: Self | bool
other_is_not_inf: Self | bool

if isinstance(other, (float, int, Decimal)):
from math import isinf, isnan

# NOTE: See https://discuss.python.org/t/inferred-type-of-function-that-calls-dunder-abs-abs/101447
other_abs = other.__abs__()
other_is_nan = isnan(other)
other_is_inf = isinf(other)

# Define the other_is_not_inf variable to prevent triggering the following warning:
# > DeprecationWarning: Bitwise inversion '~' on bool is deprecated and will be
# > removed in Python 3.16.
other_is_not_inf = not other_is_inf

else:
other_abs, other_is_nan = other.abs(), other.is_nan()
other_is_not_inf = other.is_finite() | other_is_nan
other_is_inf = ~other_is_not_inf

rel_threshold = self.abs().clip(lower_bound=other_abs, upper_bound=None) * rel_tol
tolerance = rel_threshold.clip(lower_bound=abs_tol, upper_bound=None)

self_is_nan = self.is_nan()
self_is_not_inf = self.is_finite() | self_is_nan

# Values are close if abs_diff <= tolerance, and both finite
is_close = (
((self - other).abs() <= tolerance) & self_is_not_inf & other_is_not_inf
)

# Handle infinity cases: infinities are close/equal if they have the same sign
self_sign, other_sign = self > 0, other > 0
is_same_inf = (~self_is_not_inf) & other_is_inf & (self_sign == other_sign)

# Handle nan cases:
# * If any value is NaN, then False (via `& ~either_nan`)
# * However, if `nans_equals = True` and if _both_ values are NaN, then True
either_nan = self_is_nan | other_is_nan
result = (is_close | is_same_inf) & ~either_nan

if nans_equal:
both_nan = self_is_nan & other_is_nan
result = result | both_nan

return result
47 changes: 46 additions & 1 deletion narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from narwhals._polars.dataframe import Method, PolarsDataFrame
from narwhals._polars.namespace import PolarsNamespace
from narwhals._utils import Version, _LimitedContext
from narwhals.typing import IntoDType
from narwhals.typing import IntoDType, NumericLiteral


class PolarsExpr:
Expand Down Expand Up @@ -232,6 +232,51 @@ def __narwhals_namespace__(self) -> PolarsNamespace: # pragma: no cover

return PolarsNamespace(version=self._version)

def is_close(
self,
other: Self | NumericLiteral,
*,
abs_tol: float,
rel_tol: float,
nans_equal: bool,
) -> Self:
native_expr = self.native
other_expr = (
extract_native(other) if isinstance(other, PolarsExpr) else pl.lit(other)
)

if self._backend_version < (1, 32, 0):
abs_diff = (native_expr - other_expr).abs()
rel_threshold = native_expr.abs().clip(lower_bound=other_expr.abs()) * rel_tol
tolerance = rel_threshold.clip(lower_bound=pl.lit(abs_tol))

self_is_inf, other_is_inf = (
native_expr.is_infinite(),
other_expr.is_infinite(),
)

# Values are close if abs_diff <= tolerance, and both finite
is_close = (abs_diff <= tolerance) & self_is_inf.not_() & other_is_inf.not_()

# Handle infinity cases: infinities are "close" only if they have the same sign
self_sign, other_sign = native_expr.sign(), other_expr.sign()
is_same_inf = self_is_inf & other_is_inf & (self_sign == other_sign)

# Handle nan cases:
# * nans_equals = True => if both values are NaN, then True
# * nans_equals = False => if any value is NaN, then False
either_nan = native_expr.is_nan() | other_expr.is_nan()
result = (is_close | is_same_inf) & either_nan.not_()

if nans_equal:
both_nan = native_expr.is_nan() & other_expr.is_nan()
result = result | both_nan
else:
result = native_expr.is_close(
other=other_expr, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
)
return self._with_native(result)

@property
def dt(self) -> PolarsExprDateTimeNamespace:
return PolarsExprDateTimeNamespace(self)
Expand Down
45 changes: 41 additions & 4 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
IntoDType,
MultiIndexSelector,
NonNestedLiteral,
NumericLiteral,
_1DArray,
)

Expand Down Expand Up @@ -90,6 +91,7 @@
"gather_every",
"head",
"is_between",
"is_close",
"is_finite",
"is_first_distinct",
"is_in",
Expand Down Expand Up @@ -131,6 +133,11 @@
class PolarsSeries:
_implementation = Implementation.POLARS

_HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = {
True: ["breakpoint", "count"],
False: ["count"],
}

def __init__(self, series: pl.Series, *, version: Version) -> None:
self._native_series: pl.Series = series
self._version = version
Expand Down Expand Up @@ -478,10 +485,40 @@ def __contains__(self, other: Any) -> bool:
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e) from None

_HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = {
True: ["breakpoint", "count"],
False: ["count"],
}
def is_close(
self,
other: Self | NumericLiteral,
*,
abs_tol: float,
rel_tol: float,
nans_equal: bool,
) -> Self:
other_native = extract_native(other)

if self._backend_version < (1, 32, 0):
name = self.name
ns = self.__narwhals_namespace__()
result = (
self.to_frame()
.select(
ns.col(name).is_close(
other=other_native, # type: ignore[arg-type]
abs_tol=abs_tol,
rel_tol=rel_tol,
nans_equal=nans_equal,
)
)
.get_column(name)
.native
)
else:
result = self.native.is_close(
other=other_native, # pyright: ignore[reportArgumentType]
abs_tol=abs_tol,
rel_tol=rel_tol,
nans_equal=nans_equal,
)
return self._with_native(result)

def hist_from_bins(
self, bins: list[float], *, include_breakpoint: bool
Expand Down
Loading
Loading