88from duckdb .typing import DuckDBPyType
99
1010from narwhals ._compliant import LazyExpr
11+ from narwhals ._compliant .window import UnorderableWindowInputs , WindowInputs
1112from narwhals ._duckdb .expr_dt import DuckDBExprDateTimeNamespace
1213from narwhals ._duckdb .expr_list import DuckDBExprListNamespace
1314from narwhals ._duckdb .expr_str import DuckDBExprStringNamespace
1415from narwhals ._duckdb .expr_struct import DuckDBExprStructNamespace
1516from narwhals ._duckdb .utils import (
16- UnorderableWindowInputs ,
17- WindowInputs ,
1817 col ,
1918 ensure_type ,
2019 generate_order_by_sql ,
3029 from duckdb import Expression
3130 from typing_extensions import Self
3231
33- from narwhals ._compliant .typing import AliasNames , EvalNames , EvalSeries
32+ from narwhals ._compliant .typing import (
33+ AliasNames ,
34+ EvalNames ,
35+ EvalSeries ,
36+ UnorderableWindowFunction ,
37+ WindowFunction ,
38+ )
3439 from narwhals ._duckdb .dataframe import DuckDBLazyFrame
3540 from narwhals ._duckdb .namespace import DuckDBNamespace
36- from narwhals ._duckdb .typing import UnorderableWindowFunction , WindowFunction
3741 from narwhals ._expression_parsing import ExprMetadata
3842 from narwhals .dtypes import DType
3943 from narwhals .typing import (
4650 )
4751 from narwhals .utils import Version , _FullContext
4852
53+ DuckDBWindowInputs = WindowInputs [Expression ]
54+ DuckDBUnorderableWindowInputs = UnorderableWindowInputs [Expression ]
55+ DuckDBWindowFunction = WindowFunction [Expression ]
56+ DuckDBUnorderableWindowFunction = UnorderableWindowFunction [Expression ]
57+
58+
4959with contextlib .suppress (ImportError ): # requires duckdb>=1.3.0
5060 from duckdb import SQLExpression
5161
@@ -70,10 +80,10 @@ def __init__(
7080 self ._metadata : ExprMetadata | None = None
7181
7282 # This can only be set by `_with_window_function`.
73- self ._window_function : WindowFunction | None = None
83+ self ._window_function : DuckDBWindowFunction | None = None
7484
7585 # These can only be set by `_with_unorderable_window_function`
76- self ._unorderable_window_function : UnorderableWindowFunction | None = None
86+ self ._unorderable_window_function : DuckDBUnorderableWindowFunction | None = None
7787 self ._previous_call : EvalSeries [DuckDBLazyFrame , Expression ] | None = None
7888
7989 def __call__ (self , df : DuckDBLazyFrame ) -> Sequence [Expression ]:
@@ -94,14 +104,12 @@ def _cum_window_func(
94104 * ,
95105 reverse : bool ,
96106 func_name : Literal ["sum" , "max" , "min" , "count" , "product" ],
97- ) -> WindowFunction :
98- def func (window_inputs : WindowInputs ) -> Expression :
99- order_by_sql = generate_order_by_sql (
100- * window_inputs .order_by , ascending = not reverse
101- )
102- partition_by_sql = generate_partition_by_sql (* window_inputs .partition_by )
107+ ) -> DuckDBWindowFunction :
108+ def func (inputs : DuckDBWindowInputs ) -> Expression :
109+ order_by_sql = generate_order_by_sql (* inputs .order_by , ascending = not reverse )
110+ partition_by_sql = generate_partition_by_sql (* inputs .partition_by )
103111 sql = (
104- f"{ func_name } ({ window_inputs .expr } ) over ({ partition_by_sql } { order_by_sql } "
112+ f"{ func_name } ({ inputs .expr } ) over ({ partition_by_sql } { order_by_sql } "
105113 "rows between unbounded preceding and current row)"
106114 )
107115 return SQLExpression (sql ) # type: ignore[no-any-return, unused-ignore]
@@ -116,7 +124,7 @@ def _rolling_window_func(
116124 window_size : int ,
117125 min_samples : int ,
118126 ddof : int | None = None ,
119- ) -> WindowFunction :
127+ ) -> DuckDBWindowFunction :
120128 ensure_type (window_size , int , type (None ))
121129 ensure_type (min_samples , int )
122130 supported_funcs = ["sum" , "mean" , "std" , "var" ]
@@ -129,9 +137,9 @@ def _rolling_window_func(
129137 start = f"{ window_size - 1 } preceding"
130138 end = "current row"
131139
132- def func (window_inputs : WindowInputs ) -> Expression :
133- order_by_sql = generate_order_by_sql (* window_inputs .order_by , ascending = True )
134- partition_by_sql = generate_partition_by_sql (* window_inputs .partition_by )
140+ def func (inputs : DuckDBWindowInputs ) -> Expression :
141+ order_by_sql = generate_order_by_sql (* inputs .order_by , ascending = True )
142+ partition_by_sql = generate_partition_by_sql (* inputs .partition_by )
135143 window = f"({ partition_by_sql } { order_by_sql } rows between { start } and { end } )"
136144 if func_name in {"sum" , "mean" }:
137145 func_ : str = func_name
@@ -149,9 +157,9 @@ def func(window_inputs: WindowInputs) -> Expression:
149157 else : # pragma: no cover
150158 msg = f"Only the following functions are supported: { supported_funcs } .\n Got: { func_name } ."
151159 raise ValueError (msg )
152- condition_sql = f"count({ window_inputs .expr } ) over { window } >= { min_samples } "
160+ condition_sql = f"count({ inputs .expr } ) over { window } >= { min_samples } "
153161 condition = SQLExpression (condition_sql )
154- value = SQLExpression (f"{ func_ } ({ window_inputs .expr } ) over { window } " )
162+ value = SQLExpression (f"{ func_ } ({ inputs .expr } ) over { window } " )
155163 return when (condition , value )
156164
157165 return func
@@ -238,7 +246,7 @@ def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
238246 version = self ._version ,
239247 )
240248
241- def _with_window_function (self , window_function : WindowFunction ) -> Self :
249+ def _with_window_function (self , window_function : DuckDBWindowFunction ) -> Self :
242250 result = self .__class__ (
243251 self ._call ,
244252 evaluate_output_names = self ._evaluate_output_names ,
@@ -251,7 +259,7 @@ def _with_window_function(self, window_function: WindowFunction) -> Self:
251259
252260 def _with_unorderable_window_function (
253261 self ,
254- unorderable_window_function : UnorderableWindowFunction ,
262+ unorderable_window_function : DuckDBUnorderableWindowFunction ,
255263 previous_call : EvalSeries [DuckDBLazyFrame , Expression ],
256264 ) -> Self :
257265 result = self .__class__ (
@@ -542,55 +550,51 @@ def round(self, decimals: int) -> Self:
542550 def shift (self , n : int ) -> Self :
543551 ensure_type (n , int )
544552
545- def func (window_inputs : WindowInputs ) -> Expression :
546- order_by_sql = generate_order_by_sql (* window_inputs .order_by , ascending = True )
547- partition_by_sql = generate_partition_by_sql (* window_inputs .partition_by )
548- sql = (
549- f"lag({ window_inputs .expr } , { n } ) over ({ partition_by_sql } { order_by_sql } )"
550- )
553+ def func (inputs : DuckDBWindowInputs ) -> Expression :
554+ order_by_sql = generate_order_by_sql (* inputs .order_by , ascending = True )
555+ partition_by_sql = generate_partition_by_sql (* inputs .partition_by )
556+ sql = f"lag({ inputs .expr } , { n } ) over ({ partition_by_sql } { order_by_sql } )"
551557 return SQLExpression (sql ) # type: ignore[no-any-return, unused-ignore]
552558
553559 return self ._with_window_function (func )
554560
555561 @requires .backend_version ((1 , 3 ))
556562 def is_first_distinct (self ) -> Self :
557- def func (window_inputs : WindowInputs ) -> Expression :
558- order_by_sql = generate_order_by_sql (* window_inputs .order_by , ascending = True )
559- if window_inputs .partition_by :
563+ def func (inputs : DuckDBWindowInputs ) -> Expression :
564+ order_by_sql = generate_order_by_sql (* inputs .order_by , ascending = True )
565+ if inputs .partition_by :
560566 partition_by_sql = (
561- generate_partition_by_sql (* window_inputs .partition_by )
562- + f", { window_inputs .expr } "
567+ generate_partition_by_sql (* inputs .partition_by ) + f", { inputs .expr } "
563568 )
564569 else :
565- partition_by_sql = f"partition by { window_inputs .expr } "
570+ partition_by_sql = f"partition by { inputs .expr } "
566571 sql = f"{ FunctionExpression ('row_number' )} over({ partition_by_sql } { order_by_sql } )"
567572 return SQLExpression (sql ) == lit (1 ) # type: ignore[no-any-return, unused-ignore]
568573
569574 return self ._with_window_function (func )
570575
571576 @requires .backend_version ((1 , 3 ))
572577 def is_last_distinct (self ) -> Self :
573- def func (window_inputs : WindowInputs ) -> Expression :
574- order_by_sql = generate_order_by_sql (* window_inputs .order_by , ascending = False )
575- if window_inputs .partition_by :
578+ def func (inputs : DuckDBWindowInputs ) -> Expression :
579+ order_by_sql = generate_order_by_sql (* inputs .order_by , ascending = False )
580+ if inputs .partition_by :
576581 partition_by_sql = (
577- generate_partition_by_sql (* window_inputs .partition_by )
578- + f", { window_inputs .expr } "
582+ generate_partition_by_sql (* inputs .partition_by ) + f", { inputs .expr } "
579583 )
580584 else :
581- partition_by_sql = f"partition by { window_inputs .expr } "
585+ partition_by_sql = f"partition by { inputs .expr } "
582586 sql = f"{ FunctionExpression ('row_number' )} over({ partition_by_sql } { order_by_sql } )"
583587 return SQLExpression (sql ) == lit (1 ) # type: ignore[no-any-return, unused-ignore]
584588
585589 return self ._with_window_function (func )
586590
587591 @requires .backend_version ((1 , 3 ))
588592 def diff (self ) -> Self :
589- def func (window_inputs : WindowInputs ) -> Expression :
590- order_by_sql = generate_order_by_sql (* window_inputs .order_by , ascending = True )
591- partition_by_sql = generate_partition_by_sql (* window_inputs .partition_by )
592- sql = f"lag({ window_inputs .expr } ) over ({ partition_by_sql } { order_by_sql } )"
593- return window_inputs .expr - SQLExpression (sql ) # type: ignore[no-any-return, unused-ignore]
593+ def func (inputs : DuckDBWindowInputs ) -> Expression :
594+ order_by_sql = generate_order_by_sql (* inputs .order_by , ascending = True )
595+ partition_by_sql = generate_partition_by_sql (* inputs .partition_by )
596+ sql = f"lag({ inputs .expr } ) over ({ partition_by_sql } { order_by_sql } )"
597+ return inputs .expr - SQLExpression (sql ) # type: ignore[no-any-return, unused-ignore]
594598
595599 return self ._with_window_function (func )
596600
@@ -685,11 +689,9 @@ def fill_null(
685689 msg = f"`fill_null` with `strategy={ strategy } ` is only available in 'duckdb>=1.3.0'."
686690 raise NotImplementedError (msg )
687691
688- def _fill_with_strategy (window_inputs : WindowInputs ) -> Expression :
689- order_by_sql = generate_order_by_sql (
690- * window_inputs .order_by , ascending = True
691- )
692- partition_by_sql = generate_partition_by_sql (* window_inputs .partition_by )
692+ def _fill_with_strategy (inputs : DuckDBWindowInputs ) -> Expression :
693+ order_by_sql = generate_order_by_sql (* inputs .order_by , ascending = True )
694+ partition_by_sql = generate_partition_by_sql (* inputs .partition_by )
693695
694696 fill_func = "last_value" if strategy == "forward" else "first_value"
695697 _limit = "unbounded" if limit is None else limit
@@ -699,7 +701,7 @@ def _fill_with_strategy(window_inputs: WindowInputs) -> Expression:
699701 else f"current row and { _limit } following"
700702 )
701703 sql = (
702- f"{ fill_func } ({ window_inputs .expr } ignore nulls) over "
704+ f"{ fill_func } ({ inputs .expr } ignore nulls) over "
703705 f"({ partition_by_sql } { order_by_sql } rows between { rows_between } )"
704706 )
705707 return SQLExpression (sql ) # type: ignore[no-any-return, unused-ignore]
@@ -770,11 +772,9 @@ def _rank(
770772 def _unpartitioned_rank (expr : Expression ) -> Expression :
771773 return _rank (expr , descending = descending )
772774
773- def _partitioned_rank (window_inputs : UnorderableWindowInputs ) -> Expression :
775+ def _partitioned_rank (inputs : DuckDBUnorderableWindowInputs ) -> Expression :
774776 return _rank (
775- window_inputs .expr ,
776- descending = descending ,
777- partition_by = window_inputs .partition_by ,
777+ inputs .expr , descending = descending , partition_by = inputs .partition_by
778778 )
779779
780780 return self ._with_callable (_unpartitioned_rank )._with_unorderable_window_function (
0 commit comments