diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index e4761cc681..ecb26b4a9a 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -585,7 +585,7 @@ def tail(self, n: int) -> Self: return self._with_native(self.native.slice(max(0, num_rows - n))) return self._with_native(self.native.slice(abs(n))) - def is_in(self, other: Any) -> Self: + def is_in(self, other: Sequence[Any] | ChunkedArrayAny) -> Self: if self._is_native(other): value_set: ArrayOrChunkedArray = other else: diff --git a/narwhals/_compliant/column.py b/narwhals/_compliant/column.py index 20f7f08f5c..ddea8a2b60 100644 --- a/narwhals/_compliant/column.py +++ b/narwhals/_compliant/column.py @@ -170,7 +170,7 @@ def is_duplicated(self) -> Self: def is_finite(self) -> Self: ... def is_first_distinct(self) -> Self: ... - def is_in(self, other: Any) -> Self: ... + def is_in(self, other: Sequence[Any]) -> Self: ... def is_last_distinct(self) -> Self: ... def is_nan(self) -> Self: ... def is_null(self) -> Self: ... diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 05e371f988..23e94f0915 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -39,7 +39,7 @@ if TYPE_CHECKING: from collections.abc import Mapping, Sequence - from typing_extensions import Self, TypeIs + from typing_extensions import Self, TypeAlias, TypeIs from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace from narwhals._compliant.series import CompliantSeries @@ -61,6 +61,8 @@ __all__ = ["CompliantExpr", "DepthTrackingExpr", "EagerExpr", "LazyExpr", "NativeExpr"] +Incomplete: TypeAlias = "Any" + class NativeExpr(Protocol): """An `Expr`-like object from a package with [Lazy-only support](https://narwhals-dev.github.io/narwhals/extending/#levels-of-support). 
@@ -613,7 +615,7 @@ def fill_null( "fill_null", value=value, scalar_kwargs={"strategy": strategy, "limit": limit} ) - def is_in(self, other: Any) -> Self: + def is_in(self, other: Sequence[Any] | Incomplete) -> Self: return self._reuse_series("is_in", other=other) def arg_true(self) -> Self: diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index bca1d80bbc..c05fac6f49 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -139,6 +139,7 @@ def head(self, n: int) -> Self: ... def is_empty(self) -> bool: return self.len() == 0 + def is_in(self, other: Sequence[Any] | NativeSeriesT) -> Self: ... def is_sorted(self, *, descending: bool) -> bool: ... def item(self, index: int | None) -> Any: ... def kurtosis(self) -> float | None: ... diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 1eb409ae09..6a3a682b9e 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -611,7 +611,7 @@ def func(expr: dx.Series) -> dx.Series: return self._with_callable(func, "is_unique") - def is_in(self, other: Any) -> Self: + def is_in(self, other: Sequence[Any]) -> Self: return self._with_callable(lambda expr: expr.isin(other), "is_in") def null_count(self) -> Self: diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 40921715f0..b3062a6fb5 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -62,7 +62,7 @@ PandasHistData: TypeAlias = "HistData[pd.Series[Any], list[float]]" - +Incomplete: TypeAlias = "Any" PANDAS_TO_NUMPY_DTYPE_NO_MISSING = { "Int64": "int64", "int64[pyarrow]": "int64", @@ -109,7 +109,7 @@ } -class PandasLikeSeries(EagerSeries[Any]): +class PandasLikeSeries(EagerSeries[Incomplete]): def __init__( self, native_series: Any, *, implementation: Implementation, version: Version ) -> None: @@ -125,7 +125,7 @@ def __init__( self._broadcast = False @property - def native(self) -> Any: + def native(self) -> Incomplete: return 
self._native_series def __native_namespace__(self) -> ModuleType: @@ -362,7 +362,7 @@ def is_between( assert_never(closed) return self._with_native(res).alias(ser.name) - def is_in(self, other: Any) -> Self: + def is_in(self, other: Sequence[Any] | Incomplete) -> Self: return self._with_native(self.native.isin(other)) def arg_true(self) -> Self: diff --git a/narwhals/_utils.py b/narwhals/_utils.py index 65a2c183f1..234990edf3 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -47,6 +47,7 @@ is_numpy_array_1d_int, is_pandas_like_dataframe, is_pandas_like_series, + is_polars_series, ) from narwhals.exceptions import ColumnNotFoundError, DuplicateError, InvalidOperationError @@ -124,6 +125,7 @@ CompliantLazyFrame, CompliantSeries, DTypes, + EagerAllowed, FileSource, IntoSeriesT, MultiIndexSelector, @@ -2120,3 +2122,57 @@ def extend_bool( Stolen from https://github.com/pola-rs/polars/blob/b8bfb07a4a37a8d449d6d1841e345817431142df/py-polars/polars/_utils/various.py#L580-L594 """ return (value,) * n_match if isinstance(value, bool) else tuple(value) + + +class _CanTo_List(Protocol): # noqa: N801 + def to_list(self, *args: Any, **kwds: Any) -> list[Any]: ... + + +class _CanToList(Protocol): + def tolist(self, *args: Any, **kwds: Any) -> list[Any]: ... + + +class _CanTo_PyList(Protocol): # noqa: N801 + def to_pylist(self, *args: Any, **kwds: Any) -> list[Any]: ... 
+ + +def can_to_list(obj: Any) -> TypeIs[_CanTo_List]: + return ( + is_narwhals_series(obj) + or is_polars_series(obj) + or _hasattr_static(obj, "to_list") + ) + + +def can_tolist(obj: Any) -> TypeIs[_CanToList]: + return is_numpy_array_1d(obj) or _hasattr_static(obj, "tolist") + + +def can_to_pylist(obj: Any) -> TypeIs[_CanTo_PyList]: + return ( + (pa := get_pyarrow()) and isinstance(obj, (pa.Array, pa.ChunkedArray)) + ) or _hasattr_static(obj, "to_pylist") + + +# TODO @dangotbanned: Add (brief) doc +def iterable_to_sequence( + iterable: Iterable[Any], /, *, backend: EagerAllowed | None = None +) -> Sequence[Any]: + result: Sequence[Any] + if backend is not None: + from narwhals.series import Series + + result = Series.from_iterable("", iterable, backend=backend).to_list() + elif isinstance(iterable, (tuple, list)): + result = iterable + elif isinstance(iterable, (Iterator, Sequence)): + result = tuple(iterable) + elif can_to_list(iterable): + result = iterable.to_list() + elif can_tolist(iterable): + result = iterable.tolist() + elif can_to_pylist(iterable): + result = iterable.to_pylist() + else: + result = tuple(iterable) + return result diff --git a/narwhals/expr.py b/narwhals/expr.py index df72bcbfea..9f59dc476c 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -9,8 +9,14 @@ ExprMetadata, apply_n_ary_operation, combine_metadata, + is_series, +) +from narwhals._utils import ( + _validate_rolling_arguments, + ensure_type, + flatten, + iterable_to_sequence, ) -from narwhals._utils import _validate_rolling_arguments, ensure_type, flatten from narwhals.dtypes import _validate_dtype from narwhals.exceptions import ComputeError, InvalidOperationError from narwhals.expr_cat import ExprCatNamespace @@ -19,7 +25,6 @@ from narwhals.expr_name import ExprNameNamespace from narwhals.expr_str import ExprStringNamespace from narwhals.expr_struct import ExprStructNamespace -from narwhals.translate import to_native if TYPE_CHECKING: from typing import NoReturn, TypeVar 
@@ -968,7 +973,7 @@ def is_between( upper_bound, ) - def is_in(self, other: Any) -> Self: + def is_in(self, other: Iterable[Any]) -> Self: """Check if elements of this expression are present in the other iterable. Arguments: @@ -991,10 +996,10 @@ def is_in(self, other: Any) -> Self: └──────────────────┘ """ if isinstance(other, Iterable) and not isinstance(other, (str, bytes)): + other = other.to_native() if is_series(other) else iterable_to_sequence(other) return self._with_elementwise( - lambda plx: self._to_compliant_expr(plx).is_in( - to_native(other, pass_through=True) - ) + # TODO @dangotbanned: Fix after getting feedback on https://github.com/narwhals-dev/narwhals/pull/3207#discussion_r2430089632 + lambda plx: self._to_compliant_expr(plx).is_in(other) # pyright: ignore[reportArgumentType] ) msg = "Narwhals `is_in` doesn't accept expressions as an argument, as opposed to Polars. You should provide an iterable instead." raise NotImplementedError(msg) diff --git a/narwhals/series.py b/narwhals/series.py index 0eea411e28..c51307533f 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -14,6 +14,7 @@ is_compliant_series, is_eager_allowed, is_index_selector, + iterable_to_sequence, qualified_type_name, supports_arrow_c_stream, ) @@ -25,7 +26,6 @@ from narwhals.series_list import SeriesListNamespace from narwhals.series_str import SeriesStringNamespace from narwhals.series_struct import SeriesStructNamespace -from narwhals.translate import to_native from narwhals.typing import IntoSeriesT if TYPE_CHECKING: @@ -926,7 +926,7 @@ def last(self) -> PythonLiteral: """ return self._compliant_series.last() - def is_in(self, other: Any) -> Self: + def is_in(self, other: Iterable[Any]) -> Self: """Check if the elements of this Series are in the other sequence. 
Arguments: @@ -948,9 +948,12 @@ def is_in(self, other: Any) -> Self: ] ] """ - return self._with_compliant( - self._compliant_series.is_in(to_native(other, pass_through=True)) + other = ( + other.to_native() + if isinstance(other, Series) + else iterable_to_sequence(other, backend=self.implementation) ) + return self._with_compliant(self._compliant_series.is_in(other)) def arg_true(self) -> Self: """Find elements where boolean Series is True. diff --git a/tests/conftest.py b/tests/conftest.py index 3f50c9717d..028aea4153 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import os import uuid +from collections import deque from copy import deepcopy from functools import lru_cache from importlib.util import find_spec @@ -10,11 +11,22 @@ import pytest import narwhals as nw -from narwhals._utils import Implementation, generate_temporary_column_name +from narwhals._utils import ( + Implementation, + generate_temporary_column_name, + qualified_type_name, +) from tests.utils import ID_PANDAS_LIKE, PANDAS_VERSION, pyspark_session, sqlframe_session if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import ( + Generator, + Iterable, + Iterator, + KeysView, + Sequence, + ValuesView, + ) import duckdb import ibis @@ -32,6 +44,7 @@ Constructor, ConstructorEager, ConstructorLazy, + IntoIterable, NestedOrEnumDType, ) @@ -324,7 +337,9 @@ def eager_backend(request: pytest.FixtureRequest) -> EagerAllowed: return request.param # type: ignore[no-any-return] -@pytest.fixture(params=[el for el in TEST_EAGER_BACKENDS if not isinstance(el, str)]) +@pytest.fixture( + params=[el for el in TEST_EAGER_BACKENDS if not isinstance(el, str)], scope="session" +) def eager_implementation(request: pytest.FixtureRequest) -> EagerAllowed: """Use if a test is heavily parametric, skips `str` backend.""" return request.param # type: ignore[no-any-return] @@ -375,3 +390,103 @@ def non_nested_type(request: pytest.FixtureRequest) -> type[NonNestedDType]: def 
nested_dtype(request: pytest.FixtureRequest) -> NestedOrEnumDType: dtype: NestedOrEnumDType = request.param return dtype + + +class UserDefinedIterable: + def __init__(self, iterable: Iterable[Any]) -> None: + self.iterable: Iterable[Any] = iterable + + def __iter__(self) -> Iterator[Any]: + yield from self.iterable + + +def generator_function(iterable: Iterable[Any]) -> Generator[Any, Any, None]: + yield from iterable + + +def generator_expression(iterable: Iterable[Any]) -> Generator[Any, None, None]: + return (element for element in iterable) + + +def dict_keys(iterable: Iterable[Any]) -> KeysView[Any]: + return dict.fromkeys(iterable).keys() + + +def dict_values(iterable: Iterable[Any]) -> ValuesView[Any]: + return dict(enumerate(iterable)).values() + + +def chunked_array(iterable: Any) -> Iterable[Any]: + import pyarrow as pa + + return pa.chunked_array([iterable]) + + +def _ids_into_iter(obj: Any) -> str: + module: str = "" + if (obj_module := obj.__module__) and obj_module != __name__: + module = obj.__module__ + name = qualified_type_name(obj) + if name in {"function", "builtin_function_or_method"} or "_cython" in name: + return f"{module}.{obj.__qualname__}" if module else obj.__qualname__ + return name.removeprefix(__name__).strip(".") + + +def _build_into_iter() -> Iterator[IntoIterable]: # pragma: no cover + yield from ( + # 1-4: should cover `Iterable`, `Sequence`, `Iterator` + list, + tuple, + iter, + deque, + # 5-6: cover `Generator` + generator_function, + generator_expression, + # 7-8: `Iterable`, but quite commonly cause issues upstream as they are `Sized` but not `Sequence` + dict_keys, + dict_values, + # 9: duck typing + UserDefinedIterable, + ) + # 10: 1D numpy + if find_spec("numpy"): + import numpy as np + + yield np.array + # 11-13: 1D pandas + if find_spec("pandas"): + import pandas as pd + + yield from (pd.Index, pd.array, pd.Series) + # 14: 1D polars + if find_spec("polars"): + import polars as pl + + yield pl.Series + # 15-16: 1D pyarrow 
+ if find_spec("pyarrow"): + import pyarrow as pa + + yield from (pa.array, chunked_array) + + +def _into_iter_selector() -> Callable[[int], Iterator[IntoIterable]]: + callables = tuple(_build_into_iter()) + + def pick(n: int, /) -> Iterator[IntoIterable]: + yield from callables[:n] + + return pick + + +_into_iter: Callable[[int], Iterator[IntoIterable]] = _into_iter_selector() +"""`into_iter` fixtures use the suffix `_<n>` to denote the maximum number of constructors. + +Anything greater than **10** may return fewer depending on available dependencies. +""" + + +@pytest.fixture(params=_into_iter(16), scope="session", ids=_ids_into_iter) +def into_iter_16(request: pytest.FixtureRequest) -> IntoIterable: + function: IntoIterable = request.param + return function diff --git a/tests/expr_and_series/is_in_test.py b/tests/expr_and_series/is_in_test.py index 2ae6cabea5..9753436037 100644 --- a/tests/expr_and_series/is_in_test.py +++ b/tests/expr_and_series/is_in_test.py @@ -5,7 +5,14 @@ import pytest import narwhals as nw -from tests.utils import Constructor, ConstructorEager, assert_equal_data +from tests.utils import ( + PANDAS_VERSION, + Constructor, + ConstructorEager, + IntoIterable, + assert_equal_data, + assert_equal_series, +) data = {"a": [1, 4, 2, 5]} @@ -50,3 +57,70 @@ def test_filter_is_in_with_series(constructor_eager: ConstructorEager) -> None: result = df.filter(nw.col("a").is_in(df["b"])) expected = {"a": [1, 2], "b": [1, 2]} assert_equal_data(result, expected) + + +def test_expr_is_in_series_wrong_backend(constructor: Constructor) -> None: + pytest.importorskip("polars") + pytest.importorskip("pyarrow") + + import polars as pl + import pyarrow as pa + + values = [5, 6, 7, 8] + native_pa = pa.chunked_array([values]) + native_pl = pl.Series(values) + df = nw.from_native(constructor(data)) + is_polars = df.implementation.is_polars() + ser = nw.from_native(native_pa if is_polars else native_pl, series_only=True) + result = 
df.select(nw.col("a").is_in(ser)).sort("a") + expected = {"a": [False, False, False, True]} + assert_equal_data(result, expected) + + +@pytest.mark.slow +def test_expr_is_in_iterable( + constructor: Constructor, into_iter_16: IntoIterable, request: pytest.FixtureRequest +) -> None: + test_name = request.node.name + request.applymarker( + pytest.mark.xfail( + all(part in test_name for part in ("duckdb", "pandas", "array")) + and PANDAS_VERSION < (2,), + reason=( + "Pandas bug produced numpy scalars on `pd.array(...).tolist()`\n" + "Not implemented Error: Unable to transform python value of type '' to DuckDB LogicalType\n" + "https://github.com/pandas-dev/pandas/pull/49890" + ), + ) + ) + df = nw.from_native(constructor(data)) + expected = {"a": [False, True, True, False]} + iterable = into_iter_16((4, 2)) + expr = nw.col("a").is_in(iterable) + result = df.select(expr) + assert_equal_data(result, expected) + # NOTE: For an `Iterator`, this will fail if we haven't collected it first + repeated = df.select(expr) + assert_equal_data(repeated, expected) + + +@pytest.mark.slow +def test_ser_is_in_iterable( + constructor_eager: ConstructorEager, + into_iter_16: IntoIterable, + request: pytest.FixtureRequest, +) -> None: + test_name = request.node.name + # NOTE: This *could* be supported by using `ExtensionArray.tolist` (same path as numpy) + request.applymarker( + pytest.mark.xfail( + all(part in test_name for part in ("polars", "pandas", "array")), + raises=TypeError, + reason="Polars doesn't support `pd.array`.\nhttps://github.com/pola-rs/polars/issues/22757", + ) + ) + iterable = into_iter_16((4, 2)) + ser = nw.from_native(constructor_eager(data)).get_column("a") + result = ser.is_in(iterable) + expected = [False, True, True, False] + assert_equal_series(result, expected, "a") diff --git a/tests/series_only/from_iterable_test.py b/tests/series_only/from_iterable_test.py index 4daac24ea8..f09c3b5ac5 100644 --- a/tests/series_only/from_iterable_test.py +++ 
b/tests/series_only/from_iterable_test.py @@ -1,113 +1,18 @@ from __future__ import annotations -from collections import deque -from importlib.util import find_spec from typing import TYPE_CHECKING, Any import pytest import narwhals as nw -from narwhals._utils import qualified_type_name from tests.utils import PANDAS_VERSION, assert_equal_series if TYPE_CHECKING: - from collections.abc import ( - Callable, - Generator, - Iterable, - Iterator, - KeysView, - Sequence, - ValuesView, - ) - - from typing_extensions import TypeAlias + from collections.abc import Sequence from narwhals._typing import EagerAllowed from narwhals.typing import IntoDType - - IntoIterable: TypeAlias = Callable[..., Iterable[Any]] - - -class UserDefinedIterable: - def __init__(self, iterable: Iterable[Any]) -> None: - self.iterable: Iterable[Any] = iterable - - def __iter__(self) -> Iterator[Any]: - yield from self.iterable - - -def generator_function(iterable: Iterable[Any]) -> Generator[Any, Any, None]: - yield from iterable - - -def generator_expression(iterable: Iterable[Any]) -> Generator[Any, None, None]: - return (element for element in iterable) - - -def dict_keys(iterable: Iterable[Any]) -> KeysView[Any]: - return dict.fromkeys(iterable).keys() - - -def dict_values(iterable: Iterable[Any]) -> ValuesView[Any]: - return dict(enumerate(iterable)).values() - - -_INTO_ITER_3RD_PARTY: list[IntoIterable] = [] - -if find_spec("numpy"): # pragma: no cover - import numpy as np - - _INTO_ITER_3RD_PARTY.append(np.array) -else: # pragma: no cover - ... -if find_spec("pandas"): # pragma: no cover - import pandas as pd - - _INTO_ITER_3RD_PARTY.extend([pd.Index, pd.array, pd.Series]) -else: # pragma: no cover - ... -if find_spec("polars"): # pragma: no cover - import polars as pl - - _INTO_ITER_3RD_PARTY.append(pl.Series) -else: # pragma: no cover - ... 
-if find_spec("pyarrow"): # pragma: no cover - import pyarrow as pa - - def chunked_array(iterable: Any) -> Iterable[Any]: - return pa.chunked_array([iterable]) - - _INTO_ITER_3RD_PARTY.extend([pa.array, chunked_array]) -else: # pragma: no cover - ... - -_INTO_ITER_STDLIB: tuple[IntoIterable, ...] = ( - list, - tuple, - iter, - deque, - generator_function, - generator_expression, -) -_INTO_ITER_STDLIB_EXOTIC: tuple[IntoIterable, ...] = dict_keys, dict_values -INTO_ITER: tuple[IntoIterable, ...] = ( - *_INTO_ITER_STDLIB, - *_INTO_ITER_STDLIB_EXOTIC, - UserDefinedIterable, - *_INTO_ITER_3RD_PARTY, -) - - -def _ids_into_iter(obj: Any) -> str: - module: str = "" - if (obj_module := obj.__module__) and obj_module != __name__: - module = obj.__module__ - name = qualified_type_name(obj) - if name in {"function", "builtin_function_or_method"} or "_cython" in name: - return f"{module}.{obj.__qualname__}" if module else obj.__qualname__ - return name.removeprefix(__name__).strip(".") + from tests.utils import IntoIterable @pytest.mark.parametrize( @@ -120,16 +25,15 @@ def _ids_into_iter(obj: Any) -> str: ], ids=["Int32", "no-dtype", "Float64", "String"], ) -@pytest.mark.parametrize("into_iter", INTO_ITER, ids=_ids_into_iter) def test_series_from_iterable( eager_implementation: EagerAllowed, + into_iter_16: IntoIterable, values: Sequence[Any], dtype: IntoDType, - into_iter: IntoIterable, request: pytest.FixtureRequest, ) -> None: name = "b" - iterable = into_iter(values) + iterable = into_iter_16(values) test_name = request.node.name request.applymarker( pytest.mark.xfail( diff --git a/tests/utils.py b/tests/utils.py index 4fc11492a1..0d39d06636 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,6 +4,7 @@ import os import sys import warnings +from collections.abc import Iterable from datetime import date, datetime from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, cast @@ -51,6 +52,8 @@ def get_module_version_as_tuple(module_name: str) -> tuple[int, 
...]: NestedOrEnumDType: TypeAlias = "nw.List | nw.Array | nw.Struct | nw.Enum" """`DType`s which **cannot** be used as bare types.""" +IntoIterable: TypeAlias = Callable[..., Iterable[Any]] + ID_PANDAS_LIKE = frozenset( ("pandas", "pandas[nullable]", "pandas[pyarrow]", "modin", "modin[pyarrow]", "cudf") )