Skip to content

Commit 2cba4a0

Browse files
authored
perf: Rewrite with_row_index so it's faster for pandas/pyarrow and doesn't use rank (#3239)
* wip * wip * shiny ci
1 parent b7abaf9 commit 2cba4a0

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

narwhals/_arrow/dataframe.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -502,17 +502,18 @@ def to_dict(
502502
return {ser.name: ser.to_list() for ser in it}
503503

504504
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
505-
plx = self.__narwhals_namespace__()
506-
if order_by is None:
507-
import numpy as np # ignore-banned-import
505+
import numpy as np # ignore-banned-import
508506

509-
data = pa.array(np.arange(len(self), dtype=np.int64))
507+
plx = self.__narwhals_namespace__()
508+
data = pa.array(np.arange(len(self)))
509+
row_index_s = plx._series.from_iterable(data, context=self, name=name)
510+
row_index = plx._expr._from_series(row_index_s)
511+
if order_by:
510512
row_index = plx._expr._from_series(
511-
plx._series.from_iterable(data, context=self, name=name)
513+
self.select(row_index, *(plx.col(x) for x in order_by))
514+
.sort(*order_by, descending=False, nulls_last=False)
515+
.get_column(name)
512516
)
513-
else:
514-
rank = plx.col(order_by[0]).rank("ordinal", descending=False)
515-
row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name)
516517
return self.select(row_index, plx.all())
517518

518519
def filter(self, predicate: ArrowExpr) -> Self:

narwhals/_pandas_like/dataframe.py

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -444,18 +444,17 @@ def estimated_size(self, unit: SizeUnit) -> int | float:
444444

445445
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
446446
plx = self.__narwhals_namespace__()
447-
if order_by is None:
448-
size = len(self)
449-
data = self._array_funcs.arange(size)
450-
447+
data = self._array_funcs.arange(len(self))
448+
row_index_s = plx._series.from_iterable(
449+
data, context=self, index=self.native.index, name=name
450+
)
451+
row_index = plx._expr._from_series(row_index_s)
452+
if order_by:
451453
row_index = plx._expr._from_series(
452-
plx._series.from_iterable(
453-
data, context=self, index=self.native.index, name=name
454-
)
454+
self.select(row_index, *(plx.col(x) for x in order_by))
455+
.sort(*order_by, descending=False, nulls_last=False)
456+
.get_column(name)
455457
)
456-
else:
457-
rank = plx.col(order_by[0]).rank(method="ordinal", descending=False)
458-
row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name)
459458
return self.select(row_index, plx.all())
460459

461460
def row(self, index: int) -> tuple[Any, ...]:

0 commit comments

Comments
 (0)