Skip to content

Commit 100ee97

Browse files
MarcoGorelli, dangotbanned, pre-commit-ci[bot]
authored
chore: share more WindowInputs code (#2600)
---------

Co-authored-by: dangotbanned <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 55cdea6 commit 100ee97

File tree

14 files changed

+187
-245
lines changed

14 files changed

+187
-245
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,5 +21,5 @@ help: ## Display this help screen
2121
.PHONY: typing
2222
typing: ## Run typing checks
2323
$(VENV_BIN)/uv pip install -e . --group typing
24-
$(VENV_BIN)/mypy
2524
$(VENV_BIN)/pyright
25+
$(VENV_BIN)/mypy

narwhals/_compliant/typing.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
)
2020
from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace
2121
from narwhals._compliant.series import CompliantSeries, EagerSeries
22+
from narwhals._compliant.window import UnorderableWindowInputs, WindowInputs
2223
from narwhals.typing import FillNullStrategy, NativeFrame, NativeSeries, RankMethod
2324

2425
class ScalarKwargs(TypedDict, total=False):
@@ -144,3 +145,10 @@ class ScalarKwargs(TypedDict, total=False):
144145

145146
EvalNames: TypeAlias = Callable[[CompliantFrameT], Sequence[str]]
146147
"""A function from a `Frame` to a sequence of column names *before* any aliasing takes place."""
148+
149+
WindowFunction: TypeAlias = "Callable[[WindowInputs[NativeExprT]], NativeExprT]"
150+
"""A function evaluated with `over(partition_by=..., order_by=...)`."""
151+
UnorderableWindowFunction: TypeAlias = (
152+
"Callable[[UnorderableWindowInputs[NativeExprT]], NativeExprT]"
153+
)
154+
"""A function evaluated with `over(partition_by=...)`, without `order_by` (e.g. `is_unique`, `rank`)."""

narwhals/_compliant/window.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,24 @@
1+
from __future__ import annotations
2+
3+
from typing import Generic, Sequence
4+
5+
from narwhals._compliant.typing import NativeExprT_co
6+
7+
8+
class WindowInputs(Generic[NativeExprT_co]):
9+
__slots__ = ("expr", "order_by", "partition_by")
10+
11+
def __init__(
12+
self, expr: NativeExprT_co, partition_by: Sequence[str], order_by: Sequence[str]
13+
) -> None:
14+
self.expr = expr
15+
self.partition_by = partition_by
16+
self.order_by = order_by
17+
18+
19+
class UnorderableWindowInputs(Generic[NativeExprT_co]):
20+
__slots__ = ("expr", "partition_by")
21+
22+
def __init__(self, expr: NativeExprT_co, partition_by: Sequence[str]) -> None:
23+
self.expr = expr
24+
self.partition_by = partition_by

narwhals/_duckdb/expr.py

Lines changed: 54 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -8,13 +8,12 @@
88
from duckdb.typing import DuckDBPyType
99

1010
from narwhals._compliant import LazyExpr
11+
from narwhals._compliant.window import UnorderableWindowInputs, WindowInputs
1112
from narwhals._duckdb.expr_dt import DuckDBExprDateTimeNamespace
1213
from narwhals._duckdb.expr_list import DuckDBExprListNamespace
1314
from narwhals._duckdb.expr_str import DuckDBExprStringNamespace
1415
from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace
1516
from narwhals._duckdb.utils import (
16-
UnorderableWindowInputs,
17-
WindowInputs,
1817
col,
1918
ensure_type,
2019
generate_order_by_sql,
@@ -30,10 +29,15 @@
3029
from duckdb import Expression
3130
from typing_extensions import Self
3231

33-
from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries
32+
from narwhals._compliant.typing import (
33+
AliasNames,
34+
EvalNames,
35+
EvalSeries,
36+
UnorderableWindowFunction,
37+
WindowFunction,
38+
)
3439
from narwhals._duckdb.dataframe import DuckDBLazyFrame
3540
from narwhals._duckdb.namespace import DuckDBNamespace
36-
from narwhals._duckdb.typing import UnorderableWindowFunction, WindowFunction
3741
from narwhals._expression_parsing import ExprMetadata
3842
from narwhals.dtypes import DType
3943
from narwhals.typing import (
@@ -46,6 +50,12 @@
4650
)
4751
from narwhals.utils import Version, _FullContext
4852

53+
DuckDBWindowInputs = WindowInputs[Expression]
54+
DuckDBUnorderableWindowInputs = UnorderableWindowInputs[Expression]
55+
DuckDBWindowFunction = WindowFunction[Expression]
56+
DuckDBUnorderableWindowFunction = UnorderableWindowFunction[Expression]
57+
58+
4959
with contextlib.suppress(ImportError): # requires duckdb>=1.3.0
5060
from duckdb import SQLExpression
5161

@@ -70,10 +80,10 @@ def __init__(
7080
self._metadata: ExprMetadata | None = None
7181

7282
# This can only be set by `_with_window_function`.
73-
self._window_function: WindowFunction | None = None
83+
self._window_function: DuckDBWindowFunction | None = None
7484

7585
# These can only be set by `_with_unorderable_window_function`
76-
self._unorderable_window_function: UnorderableWindowFunction | None = None
86+
self._unorderable_window_function: DuckDBUnorderableWindowFunction | None = None
7787
self._previous_call: EvalSeries[DuckDBLazyFrame, Expression] | None = None
7888

7989
def __call__(self, df: DuckDBLazyFrame) -> Sequence[Expression]:
@@ -94,14 +104,12 @@ def _cum_window_func(
94104
*,
95105
reverse: bool,
96106
func_name: Literal["sum", "max", "min", "count", "product"],
97-
) -> WindowFunction:
98-
def func(window_inputs: WindowInputs) -> Expression:
99-
order_by_sql = generate_order_by_sql(
100-
*window_inputs.order_by, ascending=not reverse
101-
)
102-
partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
107+
) -> DuckDBWindowFunction:
108+
def func(inputs: DuckDBWindowInputs) -> Expression:
109+
order_by_sql = generate_order_by_sql(*inputs.order_by, ascending=not reverse)
110+
partition_by_sql = generate_partition_by_sql(*inputs.partition_by)
103111
sql = (
104-
f"{func_name} ({window_inputs.expr}) over ({partition_by_sql} {order_by_sql} "
112+
f"{func_name} ({inputs.expr}) over ({partition_by_sql} {order_by_sql} "
105113
"rows between unbounded preceding and current row)"
106114
)
107115
return SQLExpression(sql) # type: ignore[no-any-return, unused-ignore]
@@ -116,7 +124,7 @@ def _rolling_window_func(
116124
window_size: int,
117125
min_samples: int,
118126
ddof: int | None = None,
119-
) -> WindowFunction:
127+
) -> DuckDBWindowFunction:
120128
ensure_type(window_size, int, type(None))
121129
ensure_type(min_samples, int)
122130
supported_funcs = ["sum", "mean", "std", "var"]
@@ -129,9 +137,9 @@ def _rolling_window_func(
129137
start = f"{window_size - 1} preceding"
130138
end = "current row"
131139

132-
def func(window_inputs: WindowInputs) -> Expression:
133-
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
134-
partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
140+
def func(inputs: DuckDBWindowInputs) -> Expression:
141+
order_by_sql = generate_order_by_sql(*inputs.order_by, ascending=True)
142+
partition_by_sql = generate_partition_by_sql(*inputs.partition_by)
135143
window = f"({partition_by_sql} {order_by_sql} rows between {start} and {end})"
136144
if func_name in {"sum", "mean"}:
137145
func_: str = func_name
@@ -149,9 +157,9 @@ def func(window_inputs: WindowInputs) -> Expression:
149157
else: # pragma: no cover
150158
msg = f"Only the following functions are supported: {supported_funcs}.\nGot: {func_name}."
151159
raise ValueError(msg)
152-
condition_sql = f"count({window_inputs.expr}) over {window} >= {min_samples}"
160+
condition_sql = f"count({inputs.expr}) over {window} >= {min_samples}"
153161
condition = SQLExpression(condition_sql)
154-
value = SQLExpression(f"{func_}({window_inputs.expr}) over {window}")
162+
value = SQLExpression(f"{func_}({inputs.expr}) over {window}")
155163
return when(condition, value)
156164

157165
return func
@@ -238,7 +246,7 @@ def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
238246
version=self._version,
239247
)
240248

241-
def _with_window_function(self, window_function: WindowFunction) -> Self:
249+
def _with_window_function(self, window_function: DuckDBWindowFunction) -> Self:
242250
result = self.__class__(
243251
self._call,
244252
evaluate_output_names=self._evaluate_output_names,
@@ -251,7 +259,7 @@ def _with_window_function(self, window_function: WindowFunction) -> Self:
251259

252260
def _with_unorderable_window_function(
253261
self,
254-
unorderable_window_function: UnorderableWindowFunction,
262+
unorderable_window_function: DuckDBUnorderableWindowFunction,
255263
previous_call: EvalSeries[DuckDBLazyFrame, Expression],
256264
) -> Self:
257265
result = self.__class__(
@@ -542,55 +550,51 @@ def round(self, decimals: int) -> Self:
542550
def shift(self, n: int) -> Self:
543551
ensure_type(n, int)
544552

545-
def func(window_inputs: WindowInputs) -> Expression:
546-
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
547-
partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
548-
sql = (
549-
f"lag({window_inputs.expr}, {n}) over ({partition_by_sql} {order_by_sql})"
550-
)
553+
def func(inputs: DuckDBWindowInputs) -> Expression:
554+
order_by_sql = generate_order_by_sql(*inputs.order_by, ascending=True)
555+
partition_by_sql = generate_partition_by_sql(*inputs.partition_by)
556+
sql = f"lag({inputs.expr}, {n}) over ({partition_by_sql} {order_by_sql})"
551557
return SQLExpression(sql) # type: ignore[no-any-return, unused-ignore]
552558

553559
return self._with_window_function(func)
554560

555561
@requires.backend_version((1, 3))
556562
def is_first_distinct(self) -> Self:
557-
def func(window_inputs: WindowInputs) -> Expression:
558-
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
559-
if window_inputs.partition_by:
563+
def func(inputs: DuckDBWindowInputs) -> Expression:
564+
order_by_sql = generate_order_by_sql(*inputs.order_by, ascending=True)
565+
if inputs.partition_by:
560566
partition_by_sql = (
561-
generate_partition_by_sql(*window_inputs.partition_by)
562-
+ f", {window_inputs.expr}"
567+
generate_partition_by_sql(*inputs.partition_by) + f", {inputs.expr}"
563568
)
564569
else:
565-
partition_by_sql = f"partition by {window_inputs.expr}"
570+
partition_by_sql = f"partition by {inputs.expr}"
566571
sql = f"{FunctionExpression('row_number')} over({partition_by_sql} {order_by_sql})"
567572
return SQLExpression(sql) == lit(1) # type: ignore[no-any-return, unused-ignore]
568573

569574
return self._with_window_function(func)
570575

571576
@requires.backend_version((1, 3))
572577
def is_last_distinct(self) -> Self:
573-
def func(window_inputs: WindowInputs) -> Expression:
574-
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=False)
575-
if window_inputs.partition_by:
578+
def func(inputs: DuckDBWindowInputs) -> Expression:
579+
order_by_sql = generate_order_by_sql(*inputs.order_by, ascending=False)
580+
if inputs.partition_by:
576581
partition_by_sql = (
577-
generate_partition_by_sql(*window_inputs.partition_by)
578-
+ f", {window_inputs.expr}"
582+
generate_partition_by_sql(*inputs.partition_by) + f", {inputs.expr}"
579583
)
580584
else:
581-
partition_by_sql = f"partition by {window_inputs.expr}"
585+
partition_by_sql = f"partition by {inputs.expr}"
582586
sql = f"{FunctionExpression('row_number')} over({partition_by_sql} {order_by_sql})"
583587
return SQLExpression(sql) == lit(1) # type: ignore[no-any-return, unused-ignore]
584588

585589
return self._with_window_function(func)
586590

587591
@requires.backend_version((1, 3))
588592
def diff(self) -> Self:
589-
def func(window_inputs: WindowInputs) -> Expression:
590-
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
591-
partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
592-
sql = f"lag({window_inputs.expr}) over ({partition_by_sql} {order_by_sql})"
593-
return window_inputs.expr - SQLExpression(sql) # type: ignore[no-any-return, unused-ignore]
593+
def func(inputs: DuckDBWindowInputs) -> Expression:
594+
order_by_sql = generate_order_by_sql(*inputs.order_by, ascending=True)
595+
partition_by_sql = generate_partition_by_sql(*inputs.partition_by)
596+
sql = f"lag({inputs.expr}) over ({partition_by_sql} {order_by_sql})"
597+
return inputs.expr - SQLExpression(sql) # type: ignore[no-any-return, unused-ignore]
594598

595599
return self._with_window_function(func)
596600

@@ -685,11 +689,9 @@ def fill_null(
685689
msg = f"`fill_null` with `strategy={strategy}` is only available in 'duckdb>=1.3.0'."
686690
raise NotImplementedError(msg)
687691

688-
def _fill_with_strategy(window_inputs: WindowInputs) -> Expression:
689-
order_by_sql = generate_order_by_sql(
690-
*window_inputs.order_by, ascending=True
691-
)
692-
partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
692+
def _fill_with_strategy(inputs: DuckDBWindowInputs) -> Expression:
693+
order_by_sql = generate_order_by_sql(*inputs.order_by, ascending=True)
694+
partition_by_sql = generate_partition_by_sql(*inputs.partition_by)
693695

694696
fill_func = "last_value" if strategy == "forward" else "first_value"
695697
_limit = "unbounded" if limit is None else limit
@@ -699,7 +701,7 @@ def _fill_with_strategy(window_inputs: WindowInputs) -> Expression:
699701
else f"current row and {_limit} following"
700702
)
701703
sql = (
702-
f"{fill_func}({window_inputs.expr} ignore nulls) over "
704+
f"{fill_func}({inputs.expr} ignore nulls) over "
703705
f"({partition_by_sql} {order_by_sql} rows between {rows_between})"
704706
)
705707
return SQLExpression(sql) # type: ignore[no-any-return, unused-ignore]
@@ -770,11 +772,9 @@ def _rank(
770772
def _unpartitioned_rank(expr: Expression) -> Expression:
771773
return _rank(expr, descending=descending)
772774

773-
def _partitioned_rank(window_inputs: UnorderableWindowInputs) -> Expression:
775+
def _partitioned_rank(inputs: DuckDBUnorderableWindowInputs) -> Expression:
774776
return _rank(
775-
window_inputs.expr,
776-
descending=descending,
777-
partition_by=window_inputs.partition_by,
777+
inputs.expr, descending=descending, partition_by=inputs.partition_by
778778
)
779779

780780
return self._with_callable(_unpartitioned_rank)._with_unorderable_window_function(

narwhals/_duckdb/typing.py

Lines changed: 0 additions & 14 deletions
This file was deleted.

narwhals/_duckdb/utils.py

Lines changed: 1 addition & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from functools import lru_cache
4-
from typing import TYPE_CHECKING, Any, Sequence
4+
from typing import TYPE_CHECKING, Any
55

66
import duckdb
77

@@ -38,25 +38,6 @@
3838
"""Alias for `duckdb.CaseExpression`."""
3939

4040

41-
class WindowInputs:
42-
__slots__ = ("expr", "order_by", "partition_by")
43-
44-
def __init__(
45-
self, expr: Expression, partition_by: Sequence[str], order_by: Sequence[str]
46-
) -> None:
47-
self.expr = expr
48-
self.partition_by = partition_by
49-
self.order_by = order_by
50-
51-
52-
class UnorderableWindowInputs:
53-
__slots__ = ("expr", "partition_by")
54-
55-
def __init__(self, expr: Expression, partition_by: Sequence[str]) -> None:
56-
self.expr = expr
57-
self.partition_by = partition_by
58-
59-
6041
def concat_str(*exprs: Expression, separator: str = "") -> Expression:
6142
"""Concatenate many strings, NULL inputs are skipped.
6243

0 commit comments

Comments
 (0)