Skip to content

Commit f96130c

Browse files
authored
enh: allow for over(order_by=...) for rank (#2746)
1 parent 4e43f54 commit f96130c

File tree

13 files changed

+189
-57
lines changed

13 files changed

+189
-57
lines changed

narwhals/_duckdb/dataframe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,9 @@ def unpivot(
477477

478478
@requires.backend_version((1, 3))
479479
def with_row_index(self, name: str, order_by: Sequence[str]) -> Self:
480+
if order_by is None:
481+
msg = "Cannot pass `order_by` to `with_row_index` for DuckDB"
482+
raise TypeError(msg)
480483
expr = (window_expression(F("row_number"), order_by=order_by) - lit(1)).alias(
481484
name
482485
)

narwhals/_duckdb/expr.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ def func(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> list[Expression]:
112112
F(func_name, expr),
113113
inputs.partition_by,
114114
inputs.order_by,
115-
descending=reverse,
116-
nulls_last=reverse,
115+
descending=[reverse] * len(inputs.order_by),
116+
nulls_last=[reverse] * len(inputs.order_by),
117117
rows_start="unbounded preceding",
118118
rows_end="current row",
119119
)
@@ -564,8 +564,8 @@ def func(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> Sequence[Expression
564564
F("row_number"),
565565
(*inputs.partition_by, expr),
566566
inputs.order_by,
567-
descending=True,
568-
nulls_last=True,
567+
descending=[True] * len(inputs.order_by),
568+
nulls_last=[True] * len(inputs.order_by),
569569
)
570570
== lit(1)
571571
for expr in self(df)
@@ -731,16 +731,18 @@ def rank(self, method: RankMethod, *, descending: bool) -> Self:
731731

732732
def _rank(
733733
expr: Expression,
734+
partition_by: Sequence[str | Expression] = (),
735+
order_by: Sequence[str | Expression] = (),
734736
*,
735-
descending: bool,
736-
partition_by: Sequence[str | Expression],
737+
descending: Sequence[bool],
738+
nulls_last: Sequence[bool],
737739
) -> Expression:
738740
count_expr = F("count", StarExpression())
739741
window_kwargs: WindowExpressionKwargs = {
740742
"partition_by": partition_by,
741-
"order_by": (expr,),
743+
"order_by": (expr, *order_by),
742744
"descending": descending,
743-
"nulls_last": True,
745+
"nulls_last": nulls_last,
744746
}
745747
count_window_kwargs: WindowExpressionKwargs = {
746748
"partition_by": (*partition_by, expr)
@@ -760,14 +762,21 @@ def _rank(
760762
return when(expr.isnotnull(), rank_expr)
761763

762764
def _unpartitioned_rank(expr: Expression) -> Expression:
763-
return _rank(expr, partition_by=(), descending=descending)
765+
return _rank(expr, descending=[descending], nulls_last=[True])
764766

765767
def _partitioned_rank(
766768
df: DuckDBLazyFrame, inputs: DuckDBWindowInputs
767769
) -> Sequence[Expression]:
768-
assert not inputs.order_by # noqa: S101
770+
# note: when `descending` / `nulls_last` are supported in `.over`, they should be respected here
771+
# https://github.com/narwhals-dev/narwhals/issues/2790
769772
return [
770-
_rank(expr, descending=descending, partition_by=inputs.partition_by)
773+
_rank(
774+
expr,
775+
inputs.partition_by,
776+
inputs.order_by,
777+
descending=[descending] + [False] * len(inputs.order_by),
778+
nulls_last=[True] + [False] * len(inputs.order_by),
779+
)
771780
for expr in self(df)
772781
]
773782

narwhals/_duckdb/typing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ class WindowExpressionKwargs(TypedDict, total=False):
1313
order_by: Sequence[str | Expression]
1414
rows_start: str
1515
rows_end: str
16-
descending: bool
17-
nulls_last: bool
16+
descending: Sequence[bool]
17+
nulls_last: Sequence[bool]
1818
ignore_nulls: bool

narwhals/_duckdb/utils.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
"us": "TIMESTAMP",
3737
"ns": "TIMESTAMP_NS",
3838
}
39+
DESCENDING_TO_ORDER = {True: "desc", False: "asc"}
40+
NULLS_LAST_TO_NULLS_POS = {True: "nulls last", False: "nulls first"}
3941

4042
col = duckdb.ColumnExpression
4143
"""Alias for `duckdb.ColumnExpression`."""
@@ -304,15 +306,14 @@ def generate_partition_by_sql(*partition_by: str | Expression) -> str:
304306

305307

306308
def generate_order_by_sql(
307-
*order_by: str | Expression, descending: bool, nulls_last: bool
309+
*order_by: str | Expression, descending: Sequence[bool], nulls_last: Sequence[bool]
308310
) -> str:
309311
if not order_by:
310312
return ""
311-
nulls = "nulls last" if nulls_last else "nulls first"
312-
if descending:
313-
by_sql = ", ".join([f"{parse_into_expression(x)} desc {nulls}" for x in order_by])
314-
else:
315-
by_sql = ", ".join([f"{parse_into_expression(x)} asc {nulls}" for x in order_by])
313+
by_sql = ",".join(
314+
f"{parse_into_expression(x)} {DESCENDING_TO_ORDER[_descending]} {NULLS_LAST_TO_NULLS_POS[_nulls_last]}"
315+
for x, _descending, _nulls_last in zip(order_by, descending, nulls_last)
316+
)
316317
return f"order by {by_sql}"
317318

318319

@@ -323,8 +324,8 @@ def window_expression(
323324
rows_start: str = "",
324325
rows_end: str = "",
325326
*,
326-
descending: bool = False,
327-
nulls_last: bool = False,
327+
descending: Sequence[bool] | None = None,
328+
nulls_last: Sequence[bool] | None = None,
328329
ignore_nulls: bool = False,
329330
) -> Expression:
330331
# TODO(unassigned): Replace with `duckdb.WindowExpression` when they release it.
@@ -335,6 +336,8 @@ def window_expression(
335336
msg = f"DuckDB>=1.3.0 is required for this operation. Found: DuckDB {duckdb.__version__}"
336337
raise NotImplementedError(msg) from exc
337338
pb = generate_partition_by_sql(*partition_by)
339+
descending = descending or [False] * len(order_by)
340+
nulls_last = nulls_last or [False] * len(order_by)
338341
ob = generate_order_by_sql(*order_by, descending=descending, nulls_last=nulls_last)
339342

340343
if rows_start and rows_end:

narwhals/_expression_parsing.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ class ExprKind(Enum):
131131
ORDERABLE_WINDOW = auto()
132132
"""Depends on the rows around it and on their order, e.g. `diff`."""
133133

134-
UNORDERABLE_WINDOW = auto()
135-
"""Depends on the rows around it but not on their order, e.g. `rank`."""
134+
WINDOW = auto()
135+
"""Depends on the rows around it and possibly their order, e.g. `rank`."""
136136

137137
FILTRATION = auto()
138138
"""Changes length, not affected by row order, e.g. `drop_nulls`."""
@@ -217,6 +217,22 @@ def __and__(self, other: ExpansionKind) -> Literal[ExpansionKind.MULTI_UNNAMED]:
217217

218218

219219
class ExprMetadata:
220+
"""Expression metadata.
221+
222+
Parameters:
223+
expansion_kind: What kind of expansion the expression performs.
224+
has_windows: Whether it already contains window functions.
225+
is_elementwise: Whether it can operate row-by-row without context
226+
of the other rows around it.
227+
is_literal: Whether it is just a literal wrapped in an expression.
228+
is_scalar_like: Whether it is a literal or an aggregation.
229+
last_node: The ExprKind of the last node.
230+
n_orderable_ops: The number of order-dependent operations. In the
231+
lazy case, this number must be `0` by the time the expression
232+
is evaluated.
233+
preserves_length: Whether the expression preserves the input length.
234+
"""
235+
220236
__slots__ = (
221237
"expansion_kind",
222238
"has_windows",
@@ -317,14 +333,17 @@ def with_elementwise_op(self) -> ExprMetadata:
317333
is_literal=self.is_literal,
318334
)
319335

320-
def with_unorderable_window(self) -> ExprMetadata:
336+
def with_window(self) -> ExprMetadata:
337+
# Window function which may (but doesn't have to) be used with `over(order_by=...)`.
321338
if self.is_scalar_like:
322-
msg = "Can't apply unorderable window (`rank`, `is_unique`) to scalar-like expression."
339+
msg = "Can't apply window (e.g. `rank`) to scalar-like expression."
323340
raise InvalidOperationError(msg)
324341
return ExprMetadata(
325342
self.expansion_kind,
326-
ExprKind.UNORDERABLE_WINDOW,
343+
ExprKind.WINDOW,
327344
has_windows=self.has_windows,
345+
# The function isn't order-dependent (but, users can still use `order_by` if they wish!),
346+
# so we don't increment `n_orderable_ops`.
328347
n_orderable_ops=self.n_orderable_ops,
329348
preserves_length=self.preserves_length,
330349
is_elementwise=False,
@@ -333,6 +352,7 @@ def with_unorderable_window(self) -> ExprMetadata:
333352
)
334353

335354
def with_orderable_window(self) -> ExprMetadata:
355+
# Window function which must be used with `over(order_by=...)`.
336356
if self.is_scalar_like:
337357
msg = "Can't apply orderable window (e.g. `diff`, `shift`) to scalar-like expression."
338358
raise InvalidOperationError(msg)
@@ -358,8 +378,16 @@ def with_ordered_over(self) -> ExprMetadata:
358378
)
359379
raise InvalidOperationError(msg)
360380
n_orderable_ops = self.n_orderable_ops
361-
if not n_orderable_ops:
362-
msg = "Cannot use `order_by` in `over` on expression which isn't orderable."
381+
if not n_orderable_ops and self.last_node is not ExprKind.WINDOW:
382+
msg = (
383+
"Cannot use `order_by` in `over` on expression which isn't orderable.\n"
384+
"If your expression is orderable, then make sure that `over(order_by=...)`\n"
385+
"comes immediately after the order-dependent expression.\n\n"
386+
"Hint: instead of\n"
387+
" - `(nw.col('price').diff() + 1).over(order_by='date')`\n"
388+
"write:\n"
389+
" + `nw.col('price').diff().over(order_by='date') + 1`\n"
390+
)
363391
raise InvalidOperationError(msg)
364392
if self.last_node.is_orderable_window:
365393
n_orderable_ops -= 1

narwhals/_ibis/expr.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,13 @@ def window_function(self) -> IbisWindowFunction:
6565
def default_window_func(
6666
df: IbisLazyFrame, window_inputs: IbisWindowInputs
6767
) -> list[ir.Value]:
68-
assert not window_inputs.order_by # noqa: S101
6968
return [
70-
expr.over(ibis.window(group_by=window_inputs.partition_by))
69+
expr.over(
70+
ibis.window(
71+
group_by=window_inputs.partition_by,
72+
order_by=self._sort(*window_inputs.order_by),
73+
)
74+
)
7175
for expr in self(df)
7276
]
7377

narwhals/_spark_like/dataframe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,9 @@ def unpivot(
509509
return self._with_native(unpivoted_native_frame)
510510

511511
def with_row_index(self, name: str, order_by: Sequence[str]) -> Self:
512+
if order_by is None:
513+
msg = "Cannot pass `order_by` to `with_row_index` for PySpark-like"
514+
raise TypeError(msg)
512515
row_index_expr = (
513516
self._F.row_number().over(
514517
self._Window.partitionBy(self._F.lit(1)).orderBy(*order_by)

narwhals/_spark_like/expr.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -129,17 +129,24 @@ def _Window(self) -> type[Window]: # noqa: N802
129129
return import_window(self._implementation)
130130

131131
def _sort(
132-
self, *cols: Column | str, descending: bool = False, nulls_last: bool = False
132+
self,
133+
*cols: Column | str,
134+
descending: Sequence[bool] | None = None,
135+
nulls_last: Sequence[bool] | None = None,
133136
) -> Iterator[Column]:
134137
F = self._F # noqa: N806
138+
descending = descending or [False] * len(cols)
139+
nulls_last = nulls_last or [False] * len(cols)
135140
mapping = {
136141
(False, False): F.asc_nulls_first,
137142
(False, True): F.asc_nulls_last,
138143
(True, False): F.desc_nulls_first,
139144
(True, True): F.desc_nulls_last,
140145
}
141-
sort = mapping[(descending, nulls_last)]
142-
yield from (sort(col) for col in cols)
146+
yield from (
147+
mapping[(_desc, _nulls_last)](col)
148+
for col, _desc, _nulls_last in zip(cols, descending, nulls_last)
149+
)
143150

144151
def partition_by(self, *cols: Column | str) -> WindowSpec:
145152
"""Wraps `Window().paritionBy`, with default and `WindowInputs` handling."""
@@ -178,7 +185,9 @@ def func(df: SparkLikeLazyFrame, inputs: SparkWindowInputs) -> Sequence[Column]:
178185
window = (
179186
self.partition_by(*inputs.partition_by)
180187
.orderBy(
181-
*self._sort(*inputs.order_by, descending=reverse, nulls_last=reverse)
188+
*self._sort(
189+
*inputs.order_by, descending=[reverse], nulls_last=[reverse]
190+
)
182191
)
183192
.rowsBetween(self._Window.unboundedPreceding, 0)
184193
)
@@ -695,7 +704,9 @@ def func(df: SparkLikeLazyFrame, inputs: SparkWindowInputs) -> Sequence[Column]:
695704
return [
696705
self._F.row_number().over(
697706
self.partition_by(*inputs.partition_by, expr).orderBy(
698-
*self._sort(*inputs.order_by, descending=True, nulls_last=True)
707+
*self._sort(
708+
*inputs.order_by, descending=[True], nulls_last=[True]
709+
)
699710
)
700711
)
701712
== 1
@@ -823,17 +834,17 @@ def rank(self, method: RankMethod, *, descending: bool) -> Self:
823834

824835
def _rank(
825836
expr: Column,
837+
partition_by: Sequence[str | Column] = (),
838+
order_by: Sequence[str | Column] = (),
826839
*,
827-
descending: bool,
828-
partition_by: Sequence[str | Column] | None = None,
840+
descending: Sequence[bool],
841+
nulls_last: Sequence[bool],
829842
) -> Column:
830-
order_by = self._sort(expr, descending=descending, nulls_last=True)
831-
if partition_by is not None:
832-
window = self.partition_by(*partition_by).orderBy(*order_by)
833-
count_window = self.partition_by(*partition_by, expr)
834-
else:
835-
window = self.partition_by().orderBy(*order_by)
836-
count_window = self.partition_by(expr)
843+
_order_by = self._sort(
844+
expr, *order_by, descending=descending, nulls_last=nulls_last
845+
)
846+
window = self.partition_by(*partition_by).orderBy(*_order_by)
847+
count_window = self.partition_by(*partition_by, expr)
837848
if method == "max":
838849
rank_expr = (
839850
getattr(self._F, func_name)().over(window)
@@ -852,14 +863,21 @@ def _rank(
852863
return self._F.when(expr.isNotNull(), rank_expr)
853864

854865
def _unpartitioned_rank(expr: Column) -> Column:
855-
return _rank(expr, descending=descending)
866+
return _rank(expr, descending=[descending], nulls_last=[True])
856867

857868
def _partitioned_rank(
858869
df: SparkLikeLazyFrame, inputs: SparkWindowInputs
859870
) -> Sequence[Column]:
860-
assert not inputs.order_by # noqa: S101
871+
# node: when `descending` / `nulls_last` are supported in `.over`, they should be respected here
872+
# https://github.com/narwhals-dev/narwhals/issues/2790
861873
return [
862-
_rank(expr, descending=descending, partition_by=inputs.partition_by)
874+
_rank(
875+
expr,
876+
inputs.partition_by,
877+
inputs.order_by,
878+
descending=[descending] + [False] * len(inputs.order_by),
879+
nulls_last=[True] + [False] * len(inputs.order_by),
880+
)
863881
for expr in self(df)
864882
]
865883

narwhals/expr.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def _with_orderable_aggregation(
7979
def _with_orderable_window(self, to_compliant_expr: Callable[[Any], Any]) -> Self:
8080
return self.__class__(to_compliant_expr, self._metadata.with_orderable_window())
8181

82-
def _with_unorderable_window(self, to_compliant_expr: Callable[[Any], Any]) -> Self:
83-
return self.__class__(to_compliant_expr, self._metadata.with_unorderable_window())
82+
def _with_window(self, to_compliant_expr: Callable[[Any], Any]) -> Self:
83+
return self.__class__(to_compliant_expr, self._metadata.with_window())
8484

8585
def _with_filtration(self, to_compliant_expr: Callable[[Any], Any]) -> Self:
8686
return self.__class__(to_compliant_expr, self._metadata.with_filtration())
@@ -1646,9 +1646,7 @@ def is_unique(self) -> Self:
16461646
|3 1 c False True|
16471647
└─────────────────────────────────┘
16481648
"""
1649-
return self._with_unorderable_window(
1650-
lambda plx: self._to_compliant_expr(plx).is_unique()
1651-
)
1649+
return self._with_window(lambda plx: self._to_compliant_expr(plx).is_unique())
16521650

16531651
def null_count(self) -> Self:
16541652
r"""Count null values.
@@ -2472,7 +2470,7 @@ def rank(self, method: RankMethod = "average", *, descending: bool = False) -> S
24722470
)
24732471
raise ValueError(msg)
24742472

2475-
return self._with_unorderable_window(
2473+
return self._with_window(
24762474
lambda plx: self._to_compliant_expr(plx).rank(
24772475
method=method, descending=descending
24782476
)

narwhals/stable/v1/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self:
405405
Returns:
406406
A new expression.
407407
"""
408-
return self._with_unorderable_window(
408+
return self._with_window(
409409
lambda plx: self._to_compliant_expr(plx).sort(
410410
descending=descending, nulls_last=nulls_last
411411
)

0 commit comments

Comments
 (0)