Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions narwhals/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
is_numpy_array_1d_int,
is_pandas_like_dataframe,
is_pandas_like_series,
is_polars_series,
)
from narwhals.exceptions import ColumnNotFoundError, DuplicateError, InvalidOperationError

Expand Down Expand Up @@ -124,6 +125,7 @@
CompliantLazyFrame,
CompliantSeries,
DTypes,
EagerAllowed,
FileSource,
IntoSeriesT,
MultiIndexSelector,
Expand Down Expand Up @@ -2120,3 +2122,58 @@ def extend_bool(
Stolen from https://github.com/pola-rs/polars/blob/b8bfb07a4a37a8d449d6d1841e345817431142df/py-polars/polars/_utils/various.py#L580-L594
"""
return (value,) * n_match if isinstance(value, bool) else tuple(value)


class _CanTo_List(Protocol):  # noqa: N801
    """Structural type for any object that exposes a `to_list` method."""

    def to_list(self, *args: Any, **kwds: Any) -> list[Any]:
        """Return the object's elements as a `list`."""
        ...


class _CanToList(Protocol):
    """Structural type for any object that exposes a `tolist` method."""

    def tolist(self, *args: Any, **kwds: Any) -> list[Any]:
        """Return the object's elements as a `list`."""
        ...
Comment on lines +2127 to +2132
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW I hate this as well 😂

Copy link
Member Author

@dangotbanned dangotbanned Oct 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a better idea would be to ...

Rename _CanTo_List -> ToList, and move to _translate.py alongside:

class ToDict(Protocol[ToDictDT_co]):
def to_dict(self, *args: Any, **kwds: Any) -> ToDictDT_co: ...

Move these as well, but rename them to reflect that their naming originates from numpy and pyarrow (respectively):

narwhals/narwhals/_utils.py

Lines 2131 to 2132 in 0c66432

class _CanToList(Protocol):
def tolist(self, *args: Any, **kwds: Any) -> list[Any]: ...

narwhals/narwhals/_utils.py

Lines 2135 to 2136 in 0c66432

class _CanTo_PyList(Protocol): # noqa: N801
def to_pylist(self, *args: Any, **kwds: Any) -> list[Any]: ...

The names of the guards can still stay the same, since their implementations will (after updating the protocol names) show the link between origin, protocol, and method name:

narwhals/narwhals/_utils.py

Lines 2139 to 2144 in 0c66432

def can_to_list(obj: Any) -> TypeIs[_CanTo_List]:
return (
is_narwhals_series(obj)
or is_polars_series(obj)
or _hasattr_static(obj, "to_list")
)

narwhals/narwhals/_utils.py

Lines 2147 to 2148 in 0c66432

def can_tolist(obj: Any) -> TypeIs[_CanToList]:
return is_numpy_array_1d(obj) or _hasattr_static(obj, "tolist")

narwhals/narwhals/_utils.py

Lines 2151 to 2154 in 0c66432

def can_to_pylist(obj: Any) -> TypeIs[_CanTo_PyList]:
return (
(pa := get_pyarrow()) and isinstance(obj, (pa.Array, pa.ChunkedArray))
) or _hasattr_static(obj, "to_pylist")

Copy link
Member

@FBruzzesi FBruzzesi Oct 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated my comment (#3207 (comment)) - I would argue that native series are not ok, while numpy 1d arrays are.

My thought process for this is that if someone is doing something along the lines of:

import narwhals as nw
import polars as pl

def agnostic_func(frame: IntoDataFrameT) -> IntoDataFrameT:
    other = pl.Series([1, 2, 3])  # <- notice how this is a native series!!!
    return nw.from_native(frame).filter(nw.col("x").is_in(other)).to_native()

then the function is clearly not agnostic and polars would be required in this case.

A different case would be if a narwhals series with a different backend is provided. This could mean that the function is agnostic but a user is "mixing" backends:

def is_left_in_right(left_series: IntoSeriesT, right_series: IntoSeriesT) -> IntoSeriesT:
    left_nw = nw.from_native(left_series, series_only=True)
    right_nw = nw.from_native(right_series, series_only=True)
    return left_nw.is_in(right_nw).to_native()

# but now it is the user mixing backends, not the library itself

is_left_in_right(pl.Series([1,2,3]), pd.Series([0, 1]))

This is the case I suggested to yell at the user with a warning.


From our side, I think it would greatly simplify (read as, get rid of) most of the protocols here, the type guards as well as iterable_to_sequence function.



class _CanTo_PyList(Protocol):  # noqa: N801
    """Structural type for any object that exposes a `to_pylist` method."""

    def to_pylist(self, *args: Any, **kwds: Any) -> list[Any]:
        """Return the object's elements as a `list`."""
        ...


def can_to_list(obj: Any) -> TypeIs[_CanTo_List]:
    """Check whether `obj` can be converted with `obj.to_list()`.

    Narwhals and Polars series are tested first with their dedicated guards;
    any other object qualifies if a static attribute lookup finds `to_list`.
    """
    known_series = is_narwhals_series(obj) or is_polars_series(obj)
    return known_series or _hasattr_static(obj, "to_list")


def can_tolist(obj: Any) -> TypeIs[_CanToList]:
    """Check whether `obj` can be converted with `obj.tolist()`.

    1D numpy arrays are tested first with their dedicated guard; any other
    object qualifies if a static attribute lookup finds `tolist`.
    """
    is_1d_ndarray = is_numpy_array_1d(obj)
    return is_1d_ndarray or _hasattr_static(obj, "tolist")


def can_to_pylist(obj: Any) -> TypeIs[_CanTo_PyList]:
    """Check whether `obj` can be converted with `obj.to_pylist()`.

    `pyarrow` arrays and chunked arrays are tested first; any other object
    qualifies if a static attribute lookup finds `to_pylist`.
    """
    # NOTE(review): `get_pyarrow()` presumably returns the module, or something
    # falsy when pyarrow is unavailable - the `isinstance` check is skipped then.
    pa = get_pyarrow()
    is_arrow_array = pa and isinstance(obj, (pa.Array, pa.ChunkedArray))
    return is_arrow_array or _hasattr_static(obj, "to_pylist")


# TODO @dangotbanned: Use in `{Expr,Series}.is_in`
def iterable_to_sequence(
    iterable: Iterable[Any], /, *, backend: EagerAllowed | None = None
) -> Sequence[Any]:
    """Collect `iterable` into a `Sequence`, so it can be consumed more than once.

    Arguments:
        iterable: Any iterable of values.
        backend: When given, the values are routed through
            `Series.from_iterable(..., backend=backend).to_list()` and no other
            conversion is attempted.

    Returns:
        - The result of the `Series` round-trip, if `backend` is not None.
        - `iterable` itself (no copy), if it is a `tuple` or `list`.
        - A new `tuple`, for any other `Iterator` or `Sequence`.
        - Otherwise the result of `to_list()`, `tolist()` or `to_pylist()`
            (tried in that order), falling back to `tuple(iterable)`.
    """
    if backend is not None:
        from narwhals.series import Series

        return Series.from_iterable("", iterable, backend=backend).to_list()
    if isinstance(iterable, (tuple, list)):
        # Already a concrete, re-iterable sequence: pass through untouched.
        return iterable
    if isinstance(iterable, (Iterator, Sequence)):
        return tuple(iterable)
    # Prefer an object's own conversion method over generic iteration.
    if can_to_list(iterable):
        return iterable.to_list()
    if can_tolist(iterable):
        return iterable.tolist()
    if can_to_pylist(iterable):
        return iterable.to_pylist()
    return tuple(iterable)
14 changes: 9 additions & 5 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@
ExprMetadata,
apply_n_ary_operation,
combine_metadata,
is_series,
)
from narwhals._utils import (
_validate_rolling_arguments,
ensure_type,
flatten,
iterable_to_sequence,
)
from narwhals._utils import _validate_rolling_arguments, ensure_type, flatten
from narwhals.dtypes import _validate_dtype
from narwhals.exceptions import ComputeError, InvalidOperationError
from narwhals.expr_cat import ExprCatNamespace
Expand All @@ -19,7 +25,6 @@
from narwhals.expr_name import ExprNameNamespace
from narwhals.expr_str import ExprStringNamespace
from narwhals.expr_struct import ExprStructNamespace
from narwhals.translate import to_native

if TYPE_CHECKING:
from typing import NoReturn, TypeVar
Expand Down Expand Up @@ -991,10 +996,9 @@ def is_in(self, other: Any) -> Self:
└──────────────────┘
"""
if isinstance(other, Iterable) and not isinstance(other, (str, bytes)):
other = other.to_native() if is_series(other) else iterable_to_sequence(other)
return self._with_elementwise(
lambda plx: self._to_compliant_expr(plx).is_in(
to_native(other, pass_through=True)
)
lambda plx: self._to_compliant_expr(plx).is_in(other)
Copy link
Member Author

@dangotbanned dangotbanned Oct 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated this but it is still wrong.
Didn't notice the issue until I started trying to add typing to the compliant-level

def is_in(self, other: Any) -> Self: ...

I don't understand why we were allowing any kind of Native* to be passed to every backend? 🤔

Gonna add another test for at least the more common case of a Series from the wrong backend

Updated

test: Add test_expr_is_in_series_wrong_backend

I'm not sure of a safe way to keep the same behavior (if it is desired).
Since we don't know the backend at this stage, the options I see are:

  1. Unconditionally convert to list | tuple
  2. Raise elsewhere when a Series is passed to a lazy backend
  3. Disallow Expr.is_in(other: Series)
  4. Do nothing

Copy link
Member Author

@dangotbanned dangotbanned Oct 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MarcoGorelli @FBruzzesi
do you guys have any preference on a path forward here?

This case is a bit different to some of the other places we'd reject nw.Series for lazy backends - since the length isn't an issue.
But the safer option of converting all nw.Series is gonna be less performant than the currently unsafe version (which only works on a matching eager implementation)

Everything seems like a tradeoff to me 😔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved out of draft for visibility (#3207 (comment))

Copy link
Member

@FBruzzesi FBruzzesi Oct 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @dangotbanned I will try to take a look during the weekend, but I am not sure I can manage.

My thoughts from the context in the messages here (I didn't open the code changes just yet):

  • Realistically I don't think it's too common to pass a series from a different backend, but we should not exclude such a possibility
  • My opinion would be to:
    • Use the native series if (isinstance(other, Series) and expr._implementation == other._implementation)
    • Otherwise convert it to a list (we can do that with native methods thankfully), but warn that such conversion is happening with a UserWarning. This case includes a series passed to a lazy backend

Update: Regarding native series, then I would prefer to raise an exception in such case

)
msg = "Narwhals `is_in` doesn't accept expressions as an argument, as opposed to Polars. You should provide an iterable instead."
raise NotImplementedError(msg)
Expand Down
9 changes: 6 additions & 3 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
is_compliant_series,
is_eager_allowed,
is_index_selector,
iterable_to_sequence,
qualified_type_name,
supports_arrow_c_stream,
)
Expand All @@ -25,7 +26,6 @@
from narwhals.series_list import SeriesListNamespace
from narwhals.series_str import SeriesStringNamespace
from narwhals.series_struct import SeriesStructNamespace
from narwhals.translate import to_native
from narwhals.typing import IntoSeriesT

if TYPE_CHECKING:
Expand Down Expand Up @@ -948,9 +948,12 @@ def is_in(self, other: Any) -> Self:
]
]
"""
return self._with_compliant(
self._compliant_series.is_in(to_native(other, pass_through=True))
other = (
other.to_native()
if isinstance(other, Series)
else iterable_to_sequence(other, backend=self.implementation)
)
return self._with_compliant(self._compliant_series.is_in(other))

def arg_true(self) -> Self:
"""Find elements where boolean Series is True.
Expand Down
122 changes: 118 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,30 @@

import os
import uuid
from collections import deque
from copy import deepcopy
from functools import lru_cache
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, Callable, cast

import pytest

from narwhals._utils import Implementation, generate_temporary_column_name
from narwhals._utils import (
Implementation,
generate_temporary_column_name,
qualified_type_name,
)
from tests.utils import ID_PANDAS_LIKE, PANDAS_VERSION, pyspark_session, sqlframe_session

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import (
Generator,
Iterable,
Iterator,
KeysView,
Sequence,
ValuesView,
)

import duckdb
import ibis
Expand All @@ -27,7 +39,7 @@
from narwhals._spark_like.dataframe import SQLFrameDataFrame
from narwhals._typing import EagerAllowed
from narwhals.typing import NativeDataFrame, NativeLazyFrame
from tests.utils import Constructor, ConstructorEager, ConstructorLazy
from tests.utils import Constructor, ConstructorEager, ConstructorLazy, IntoIterable

Data: TypeAlias = "dict[str, list[Any]]"

Expand Down Expand Up @@ -317,7 +329,109 @@ def eager_backend(request: pytest.FixtureRequest) -> EagerAllowed:
return request.param # type: ignore[no-any-return]


@pytest.fixture(params=[el for el in TEST_EAGER_BACKENDS if not isinstance(el, str)])
@pytest.fixture(
    params=[el for el in TEST_EAGER_BACKENDS if not isinstance(el, str)], scope="session"
)
def eager_implementation(request: pytest.FixtureRequest) -> EagerAllowed:
    """Use if a test is heavily parametric, skips `str` backend.

    Parametrized over every element of `TEST_EAGER_BACKENDS` that is not a `str`,
    using session scope.
    """
    return request.param  # type: ignore[no-any-return]


class UserDefinedIterable:
    """Minimal duck-typed iterable: defines `__iter__` and nothing else."""

    def __init__(self, iterable: Iterable[Any]) -> None:
        # The wrapped iterable is stored as-is (not copied or consumed).
        self.iterable: Iterable[Any] = iterable

    def __iter__(self) -> Iterator[Any]:
        """Yield every element of the wrapped iterable."""
        for element in self.iterable:
            yield element


def generator_function(iterable: Iterable[Any]) -> Generator[Any, Any, None]:
    """Return a generator (created by a generator *function*) over `iterable`."""
    # `yield from` delegates iteration entirely to `iterable`.
    yield from iterable


def generator_expression(iterable: Iterable[Any]) -> Generator[Any, None, None]:
    """Return a generator (created by a generator *expression*) over `iterable`."""
    lazily_yielded = (item for item in iterable)
    return lazily_yielded


def dict_keys(iterable: Iterable[Any]) -> KeysView[Any]:
    """Return the keys view of a dict whose keys are the elements of `iterable`.

    Duplicate elements collapse into a single key; first-seen order is kept.
    """
    as_keys = {element: None for element in iterable}
    return as_keys.keys()


def dict_values(iterable: Iterable[Any]) -> ValuesView[Any]:
    """Return the values view of a dict mapping positions to the elements of `iterable`.

    Keys are the enumeration indices, so duplicate elements are all kept.
    """
    by_position = {index: element for index, element in enumerate(iterable)}
    return by_position.values()


def chunked_array(iterable: Any) -> Iterable[Any]:
    """Wrap `iterable` as the single chunk passed to `pyarrow.chunked_array`."""
    # Imported lazily so that this module can be imported without pyarrow.
    import pyarrow

    chunks = [iterable]
    return pyarrow.chunked_array(chunks)


def _ids_into_iter(obj: Any) -> str:
    """Build a readable pytest id for an iterable constructor.

    Functions, builtins and names containing `"_cython"` are labelled by their
    `__qualname__`, prefixed with their module unless they live in this module.
    Anything else uses `qualified_type_name`, with this module's name stripped.
    """
    this_module = __name__
    defined_in = obj.__module__
    # Only mention a module when there is one and it is not this file.
    prefix: str = defined_in if defined_in and defined_in != this_module else ""
    type_name = qualified_type_name(obj)
    is_plain_callable = type_name in {"function", "builtin_function_or_method"}
    if is_plain_callable or "_cython" in type_name:
        qualname = obj.__qualname__
        return f"{prefix}.{qualname}" if prefix else qualname
    return type_name.removeprefix(this_module).strip(".")


def _build_into_iter() -> Iterator[IntoIterable]:  # pragma: no cover
    """Yield callables that each turn an iterable into a different kind of iterable.

    The first nine are always produced; the remainder are only produced when
    the corresponding optional package can be found.
    """
    # 1-4: should cover `Iterable`, `Sequence`, `Iterator`
    yield list
    yield tuple
    yield iter
    yield deque
    # 5-6: cover `Generator`
    yield generator_function
    yield generator_expression
    # 7-8: `Iterable`, but quite commonly cause issues upstream as they are `Sized` but not `Sequence`
    yield dict_keys
    yield dict_values
    # 9: duck typing
    yield UserDefinedIterable
    # 10: 1D numpy
    if find_spec("numpy") is not None:
        import numpy as np

        yield np.array
    # 11-13: 1D pandas
    if find_spec("pandas") is not None:
        import pandas as pd

        yield pd.Index
        yield pd.array
        yield pd.Series
    # 14: 1D polars
    if find_spec("polars") is not None:
        import polars as pl

        yield pl.Series
    # 15-16: 1D pyarrow
    if find_spec("pyarrow") is not None:
        import pyarrow as pa

        yield pa.array
        yield chunked_array


def _into_iter_selector() -> Callable[[int], Iterator[IntoIterable]]:
    """Collect the available constructors once and return a picker over them.

    The returned function `pick(n)` yields at most the first `n` constructors.
    """
    available = tuple(_build_into_iter())

    def pick(n: int, /) -> Iterator[IntoIterable]:
        for constructor in available[:n]:
            yield constructor

    return pick


_into_iter: Callable[[int], Iterator[IntoIterable]] = _into_iter_selector()
"""`into_iter` fixtures use the suffix `_<n>` to denote the maximum number of constructors.

Anything greater than **9** may return less depending on available dependencies.
"""


@pytest.fixture(params=_into_iter(16), scope="session", ids=_ids_into_iter)
def into_iter_16(request: pytest.FixtureRequest) -> IntoIterable:
    """Provide, one per parametrization, up to the first 16 iterable constructors."""
    return cast("IntoIterable", request.param)
Comment on lines +482 to +492
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kinda outdated now since (633a06d)

45 changes: 44 additions & 1 deletion tests/expr_and_series/is_in_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
import pytest

import narwhals as nw
from tests.utils import Constructor, ConstructorEager, assert_equal_data
from tests.utils import (
Constructor,
ConstructorEager,
IntoIterable,
assert_equal_data,
assert_equal_series,
)

data = {"a": [1, 4, 2, 5]}

Expand Down Expand Up @@ -50,3 +56,40 @@ def test_filter_is_in_with_series(constructor_eager: ConstructorEager) -> None:
result = df.filter(nw.col("a").is_in(df["b"]))
expected = {"a": [1, 2], "b": [1, 2]}
assert_equal_data(result, expected)


@pytest.mark.slow
def test_expr_is_in_iterable(
    constructor: Constructor, into_iter_16: IntoIterable
) -> None:
    """`Expr.is_in` accepts each kind of iterable, and the expression is reusable."""
    frame = nw.from_native(constructor(data))
    expected = {"a": [False, True, True, False]}
    membership = nw.col("a").is_in(into_iter_16((4, 2)))
    first = frame.select(membership)
    assert_equal_data(first, expected)
    # NOTE: For an `Iterator`, this will fail if we haven't collected it first
    second = frame.select(membership)
    assert_equal_data(second, expected)


@pytest.mark.slow
def test_ser_is_in_iterable(
    constructor_eager: ConstructorEager,
    into_iter_16: IntoIterable,
    request: pytest.FixtureRequest,
) -> None:
    """`Series.is_in` accepts each kind of iterable.

    The test is marked as an expected `TypeError` failure when the test name
    contains all of "polars", "pandas" and "array".
    """
    node_name = request.node.name
    polars_with_pd_array = all(
        token in node_name for token in ("polars", "pandas", "array")
    )
    # NOTE: This *could* be supported by using `ExtensionArray.tolist` (same path as numpy)
    expected_failure = pytest.mark.xfail(
        polars_with_pd_array,
        raises=TypeError,
        reason="Polars doesn't support `pd.array`.\nhttps://github.com/pola-rs/polars/issues/22757",
    )
    request.applymarker(expected_failure)
    values = into_iter_16((4, 2))
    column = nw.from_native(constructor_eager(data)).get_column("a")
    assert_equal_series(column.is_in(values), [False, True, True, False], "a")
Loading
Loading