|
19 | 19 | from narwhals._duckdb.expr_list import DuckDBExprListNamespace |
20 | 20 | from narwhals._duckdb.expr_str import DuckDBExprStringNamespace |
21 | 21 | from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace |
| 22 | +from narwhals._duckdb.utils import UnorderableWindowInputs |
22 | 23 | from narwhals._duckdb.utils import WindowInputs |
23 | 24 | from narwhals._duckdb.utils import col |
24 | 25 | from narwhals._duckdb.utils import ensure_type |
|
41 | 42 | from narwhals._compliant.typing import EvalSeries |
42 | 43 | from narwhals._duckdb.dataframe import DuckDBLazyFrame |
43 | 44 | from narwhals._duckdb.namespace import DuckDBNamespace |
| 45 | + from narwhals._duckdb.typing import UnorderableWindowFunction |
44 | 46 | from narwhals._duckdb.typing import WindowFunction |
45 | 47 | from narwhals._expression_parsing import ExprMetadata |
46 | 48 | from narwhals.dtypes import DType |
@@ -79,6 +81,10 @@ def __init__( |
79 | 81 | # This can only be set by `_with_window_function`. |
80 | 82 | self._window_function: WindowFunction | None = None |
81 | 83 |
|
| 84 | + # These can only be set by `_with_unorderable_window_function` |
| 85 | + self._unorderable_window_function: UnorderableWindowFunction | None = None |
| 86 | + self._previous_call: EvalSeries[DuckDBLazyFrame, duckdb.Expression] | None = None |
| 87 | + |
82 | 88 | def __call__(self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: |
83 | 89 | return self._call(df) |
84 | 90 |
|
@@ -263,6 +269,22 @@ def _with_window_function(self, window_function: WindowFunction) -> Self: |
263 | 269 | result._window_function = window_function |
264 | 270 | return result |
265 | 271 |
|
| 272 | + def _with_unorderable_window_function( |
| 273 | + self, |
| 274 | + unorderable_window_function: UnorderableWindowFunction, |
| 275 | + previous_call: EvalSeries[DuckDBLazyFrame, duckdb.Expression], |
| 276 | + ) -> Self: |
| 277 | + result = self.__class__( |
| 278 | + self._call, |
| 279 | + evaluate_output_names=self._evaluate_output_names, |
| 280 | + alias_output_names=self._alias_output_names, |
| 281 | + backend_version=self._backend_version, |
| 282 | + version=self._version, |
| 283 | + ) |
| 284 | + result._unorderable_window_function = unorderable_window_function |
| 285 | + result._previous_call = previous_call |
| 286 | + return result |
| 287 | + |
266 | 288 | @classmethod |
267 | 289 | def _alias_native(cls, expr: duckdb.Expression, name: str) -> duckdb.Expression: |
268 | 290 | return expr.alias(name) |
@@ -495,6 +517,19 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: |
495 | 517 | window_function(WindowInputs(expr, partition_by, order_by)) |
496 | 518 | for expr in self._call(df) |
497 | 519 | ] |
| 520 | + elif ( |
| 521 | + unorderable_window_function := self._unorderable_window_function |
| 522 | + ) is not None: |
| 523 | + assert order_by is None # noqa: S101 |
| 524 | + |
| 525 | + def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: |
| 526 | + assert self._previous_call is not None # noqa: S101 |
| 527 | + return [ |
| 528 | + unorderable_window_function( |
| 529 | + UnorderableWindowInputs(expr, partition_by) |
| 530 | + ) |
| 531 | + for expr in self._previous_call(df) |
| 532 | + ] |
498 | 533 | else: |
499 | 534 | partition_by_sql = generate_partition_by_sql(*partition_by) |
500 | 535 | template = f"{{expr}} over ({partition_by_sql})" |
@@ -728,30 +763,54 @@ def rank(self, method: RankMethod, *, descending: bool) -> Self: |
728 | 763 | else: # method == "ordinal" |
729 | 764 | func = FunctionExpression("row_number") |
730 | 765 |
|
731 | | - def _rank(_input: duckdb.Expression) -> duckdb.Expression: |
732 | | - if descending: |
733 | | - by_sql = f"{_input} desc nulls last" |
734 | | - else: |
735 | | - by_sql = f"{_input} asc nulls last" |
736 | | - order_by_sql = f"order by {by_sql}" |
| 766 | + def _rank( |
| 767 | + _input: duckdb.Expression, |
| 768 | + *, |
| 769 | + descending: bool, |
| 770 | + partition_by: Sequence[str | duckdb.Expression] | None = None, |
| 771 | + ) -> duckdb.Expression: |
| 772 | + order_by_sql = ( |
| 773 | + f"order by {_input} desc nulls last" |
| 774 | + if descending |
| 775 | + else f"order by {_input} asc nulls last" |
| 776 | + ) |
737 | 777 | count_expr = FunctionExpression("count", StarExpression()) |
738 | | - |
| 778 | + if partition_by is not None: |
| 779 | + window = f"{generate_partition_by_sql(*partition_by)} {order_by_sql}" |
| 780 | + count_window = f"{generate_partition_by_sql(*partition_by, _input)}" |
| 781 | + else: |
| 782 | + window = order_by_sql |
| 783 | + count_window = generate_partition_by_sql(_input) |
739 | 784 | if method == "max": |
740 | 785 | expr = ( |
741 | | - SQLExpression(f"{func} OVER ({order_by_sql})") |
742 | | - + SQLExpression(f"{count_expr} OVER (PARTITION BY {_input})") |
| 786 | + SQLExpression(f"{func} OVER ({window})") |
| 787 | + + SQLExpression(f"{count_expr} over ({count_window})") |
743 | 788 | - lit(1) |
744 | 789 | ) |
745 | 790 | elif method == "average": |
746 | | - expr = SQLExpression(f"{func} OVER ({order_by_sql})") + ( |
747 | | - SQLExpression(f"{count_expr} OVER (PARTITION BY {_input})") - lit(1) |
| 791 | + expr = SQLExpression(f"{func} OVER ({window})") + ( |
| 792 | + SQLExpression(f"{count_expr} over ({count_window})") - lit(1) |
748 | 793 | ) / lit(2.0) |
749 | 794 | else: |
750 | | - expr = SQLExpression(f"{func} OVER ({order_by_sql})") |
751 | | - |
| 795 | + expr = SQLExpression(f"{func} OVER ({window})") |
752 | 796 | return when(_input.isnotnull(), expr) |
753 | 797 |
|
754 | | - return self._with_callable(_rank) |
| 798 | + def _unpartitioned_rank(_input: duckdb.Expression) -> duckdb.Expression: |
| 799 | + return _rank(_input, descending=descending) |
| 800 | + |
| 801 | + def _partitioned_rank( |
| 802 | + window_inputs: UnorderableWindowInputs, |
| 803 | + ) -> duckdb.Expression: |
| 804 | + return _rank( |
| 805 | + window_inputs.expr, |
| 806 | + descending=descending, |
| 807 | + partition_by=window_inputs.partition_by, |
| 808 | + ) |
| 809 | + |
| 810 | + return self._with_callable(_unpartitioned_rank)._with_unorderable_window_function( |
| 811 | + _partitioned_rank, |
| 812 | + self._call, |
| 813 | + ) |
755 | 814 |
|
756 | 815 | def log(self, base: float) -> Self: |
757 | 816 | def _log(_input: duckdb.Expression) -> duckdb.Expression: |
|
0 commit comments