Skip to content

Commit 1642744

Browse files
authored
perf: Only use first element to sniff types in sequences in __getitem__ and filter (#2384)
* perf: Only use first element to sniff types in sequences in `__getitem__` and `filter` * use is_list_of
1 parent 41871f7 commit 1642744

File tree

6 files changed

+18
-25
lines changed

6 files changed

+18
-25
lines changed

narwhals/_arrow/dataframe.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -355,11 +355,7 @@ def __getitem__(
355355
return self._with_native(self.native.slice(start, stop - start))
356356

357357
elif isinstance(item, Sequence) or is_numpy_array_1d(item):
358-
if (
359-
isinstance(item, Sequence)
360-
and all(isinstance(x, str) for x in item)
361-
and len(item) > 0
362-
):
358+
if isinstance(item, Sequence) and len(item) > 0 and isinstance(item[0], str):
363359
return self._with_native(self.native.select(cast("Indices", item)))
364360
if isinstance(item, Sequence) and len(item) == 0:
365361
return self._with_native(self.native.slice(0, 0))

narwhals/_arrow/series.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from narwhals.utils import Implementation
3434
from narwhals.utils import generate_temporary_column_name
3535
from narwhals.utils import import_dtypes_module
36+
from narwhals.utils import is_list_of
3637
from narwhals.utils import not_implemented
3738
from narwhals.utils import validate_backend_version
3839

@@ -306,9 +307,7 @@ def len(self: Self, *, _return_py_scalar: bool = True) -> int:
306307
return maybe_extract_py_scalar(len(self.native), _return_py_scalar)
307308

308309
def filter(self: Self, predicate: ArrowSeries | list[bool | None]) -> Self:
309-
if not (
310-
isinstance(predicate, list) and all(isinstance(x, bool) for x in predicate)
311-
):
310+
if not is_list_of(predicate, bool):
312311
_, other_native = extract_native(self, predicate)
313312
else:
314313
other_native = predicate

narwhals/_pandas_like/dataframe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,11 @@ def __getitem__(
331331
return self._with_native(
332332
self.native.__class__(), validate_column_names=False
333333
)
334-
if all(isinstance(x, int) for x in item[1]): # type: ignore[var-annotated]
334+
if isinstance(item[1][0], int):
335335
return self._with_native(
336336
self.native.iloc[item], validate_column_names=False
337337
)
338-
if all(isinstance(x, str) for x in item[1]): # type: ignore[var-annotated]
338+
if isinstance(item[1][0], str):
339339
indexer = (
340340
item[0],
341341
self.native.columns.get_indexer(item[1]),
@@ -383,7 +383,7 @@ def __getitem__(
383383
return PandasLikeSeries.from_native(native_series, context=self)
384384

385385
elif is_sequence_but_not_str(item) or is_numpy_array_1d(item):
386-
if all(isinstance(x, str) for x in item) and len(item) > 0:
386+
if len(item) > 0 and isinstance(item[0], str):
387387
return self._with_native(
388388
select_columns_by_name(
389389
self.native,

narwhals/_pandas_like/series.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from narwhals.exceptions import InvalidOperationError
3232
from narwhals.utils import Implementation
3333
from narwhals.utils import import_dtypes_module
34+
from narwhals.utils import is_list_of
3435
from narwhals.utils import parse_version
3536
from narwhals.utils import validate_backend_version
3637

@@ -376,9 +377,7 @@ def arg_max(self: Self) -> int:
376377
# Binary comparisons
377378

378379
def filter(self: Self, predicate: Any) -> PandasLikeSeries:
379-
if not (
380-
isinstance(predicate, list) and all(isinstance(x, bool) for x in predicate)
381-
):
380+
if not is_list_of(predicate, bool):
382381
_, other_native = align_and_extract_native(self, predicate)
383382
else:
384383
other_native = predicate

narwhals/dataframe.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from narwhals.utils import flatten
3535
from narwhals.utils import generate_repr
3636
from narwhals.utils import is_compliant_lazyframe
37+
from narwhals.utils import is_list_of
3738
from narwhals.utils import is_sequence_but_not_str
3839
from narwhals.utils import issue_deprecation_warning
3940
from narwhals.utils import parse_version
@@ -192,11 +193,9 @@ def filter(
192193
*predicates: IntoExpr | Iterable[IntoExpr] | list[bool],
193194
**constraints: Any,
194195
) -> Self:
195-
if not (
196-
len(predicates) == 1
197-
and isinstance(predicates[0], list)
198-
and all(isinstance(x, bool) for x in predicates[0])
199-
):
196+
if len(predicates) == 1 and is_list_of(predicates[0], bool):
197+
predicate = predicates[0]
198+
else:
200199
from narwhals.functions import col
201200

202201
flat_predicates = flatten(predicates)
@@ -210,8 +209,6 @@ def filter(
210209
predicate = plx.all_horizontal(
211210
*chain(compliant_predicates, compliant_constraints)
212211
)
213-
else:
214-
predicate = predicates[0]
215212
return self._with_compliant(self._compliant_frame.filter(predicate))
216213

217214
def sort(
@@ -2786,10 +2783,7 @@ def filter(
27862783
<BLANKLINE>
27872784
"""
27882785
if (
2789-
len(predicates) == 1
2790-
and isinstance(predicates[0], list)
2791-
and all(isinstance(x, bool) for x in predicates[0])
2792-
and not constraints
2786+
len(predicates) == 1 and is_list_of(predicates[0], bool) and not constraints
27932787
): # pragma: no cover
27942788
msg = "`LazyFrame.filter` is not supported with Python boolean masks - use expressions instead."
27952789
raise TypeError(msg)

narwhals/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,6 +1271,11 @@ def is_sequence_but_not_str(sequence: Any | Sequence[_T]) -> TypeIs[Sequence[_T]
12711271
return isinstance(sequence, Sequence) and not isinstance(sequence, str)
12721272

12731273

1274+
def is_list_of(obj: Any, tp: type[_T]) -> TypeIs[list[type[_T]]]:
1275+
# Check if an object is a list of `tp`, only sniffing the first element.
1276+
return bool(isinstance(obj, list) and obj and isinstance(obj[0], tp))
1277+
1278+
12741279
def find_stacklevel() -> int:
12751280
"""Find the first place in the stack that is not inside narwhals.
12761281

0 commit comments

Comments
 (0)