Skip to content

Commit c7bc710

Browse files
authored
chore: use Expression instead of duckdb.Expression for typing (#2579)
1 parent 808eb1c commit c7bc710

File tree

7 files changed: 75 additions (+75) and 76 deletions (-76).

narwhals/_duckdb/dataframe.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535

3636
import pandas as pd
3737
import pyarrow as pa
38+
from duckdb import Expression
3839
from typing_extensions import Self
3940
from typing_extensions import TypeIs
4041

@@ -125,7 +126,7 @@ def get_column(self, name: str) -> DuckDBInterchangeSeries:
125126

126127
return DuckDBInterchangeSeries(self.native.select(name), version=self._version)
127128

128-
def _iter_columns(self) -> Iterator[duckdb.Expression]:
129+
def _iter_columns(self) -> Iterator[Expression]:
129130
for name in self.columns:
130131
yield col(name)
131132

@@ -303,7 +304,7 @@ def join(
303304
col(f'lhs."{left}"') == col(f'rhs."{right}"')
304305
for left, right in zip(left_on, right_on)
305306
)
306-
condition: duckdb.Expression = reduce(and_, it)
307+
condition: Expression = reduce(and_, it)
307308
rel = self.native.set_alias("lhs").join(
308309
other.native.set_alias("rhs"),
309310
# NOTE: Fixed in `--pre` https://github.com/duckdb/duckdb/pull/16933
@@ -342,7 +343,7 @@ def join_asof(
342343
) -> Self:
343344
lhs = self.native
344345
rhs = other.native
345-
conditions: list[duckdb.Expression] = []
346+
conditions: list[Expression] = []
346347
if by_left is not None and by_right is not None:
347348
conditions.extend(
348349
col(f'lhs."{left}"') == col(f'rhs."{right}"')
@@ -357,7 +358,7 @@ def join_asof(
357358
else:
358359
msg = "Only 'backward' and 'forward' strategies are currently supported for DuckDB"
359360
raise NotImplementedError(msg)
360-
condition: duckdb.Expression = reduce(and_, conditions)
361+
condition: Expression = reduce(and_, conditions)
361362
select = ["lhs.*"]
362363
for name in rhs.columns:
363364
if name in lhs.columns and (

narwhals/_duckdb/expr.py

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from narwhals.utils import requires
3535

3636
if TYPE_CHECKING:
37-
import duckdb
37+
from duckdb import Expression
3838
from typing_extensions import Self
3939

4040
from narwhals._compliant.typing import AliasNames
@@ -59,12 +59,12 @@
5959
from duckdb import SQLExpression # type: ignore[attr-defined, unused-ignore]
6060

6161

62-
class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]):
62+
class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "Expression"]):
6363
_implementation = Implementation.DUCKDB
6464

6565
def __init__(
6666
self,
67-
call: EvalSeries[DuckDBLazyFrame, duckdb.Expression],
67+
call: EvalSeries[DuckDBLazyFrame, Expression],
6868
*,
6969
evaluate_output_names: EvalNames[DuckDBLazyFrame],
7070
alias_output_names: AliasNames | None,
@@ -83,9 +83,9 @@ def __init__(
8383

8484
# These can only be set by `_with_unorderable_window_function`
8585
self._unorderable_window_function: UnorderableWindowFunction | None = None
86-
self._previous_call: EvalSeries[DuckDBLazyFrame, duckdb.Expression] | None = None
86+
self._previous_call: EvalSeries[DuckDBLazyFrame, Expression] | None = None
8787

88-
def __call__(self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
88+
def __call__(self, df: DuckDBLazyFrame) -> Sequence[Expression]:
8989
return self._call(df)
9090

9191
def __narwhals_expr__(self) -> None: ...
@@ -104,7 +104,7 @@ def _cum_window_func(
104104
reverse: bool,
105105
func_name: Literal["sum", "max", "min", "count", "product"],
106106
) -> WindowFunction:
107-
def func(window_inputs: WindowInputs) -> duckdb.Expression:
107+
def func(window_inputs: WindowInputs) -> Expression:
108108
order_by_sql = generate_order_by_sql(
109109
*window_inputs.order_by, ascending=not reverse
110110
)
@@ -138,7 +138,7 @@ def _rolling_window_func(
138138
start = f"{window_size - 1} preceding"
139139
end = "current row"
140140

141-
def func(window_inputs: WindowInputs) -> duckdb.Expression:
141+
def func(window_inputs: WindowInputs) -> Expression:
142142
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
143143
partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
144144
window = f"({partition_by_sql} {order_by_sql} rows between {start} and {end})"
@@ -174,7 +174,7 @@ def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Se
174174

175175
template = "{expr} over ()"
176176

177-
def func(df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
177+
def func(df: DuckDBLazyFrame) -> Sequence[Expression]:
178178
return [SQLExpression(template.format(expr=expr)) for expr in self(df)]
179179

180180
return self.__class__(
@@ -193,7 +193,7 @@ def from_column_names(
193193
*,
194194
context: _FullContext,
195195
) -> Self:
196-
def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
196+
def func(df: DuckDBLazyFrame) -> list[Expression]:
197197
return [col(name) for name in evaluate_column_names(df)]
198198

199199
return cls(
@@ -206,7 +206,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
206206

207207
@classmethod
208208
def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self:
209-
def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
209+
def func(df: DuckDBLazyFrame) -> list[Expression]:
210210
columns = df.columns
211211
return [col(columns[i]) for i in column_indices]
212212

@@ -219,7 +219,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
219219
)
220220

221221
def _with_callable(
222-
self, call: Callable[..., duckdb.Expression], /, **expressifiable_args: Self | Any
222+
self, call: Callable[..., Expression], /, **expressifiable_args: Self | Any
223223
) -> Self:
224224
"""Create expression from callable.
225225
@@ -230,7 +230,7 @@ def _with_callable(
230230
as expressions (e.g. in `nw.col('a').is_between('b', 'c')`)
231231
"""
232232

233-
def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
233+
def func(df: DuckDBLazyFrame) -> list[Expression]:
234234
native_series_list = self(df)
235235
other_native_series = {
236236
key: df._evaluate_expr(value) if self._is_expr(value) else lit(value)
@@ -272,7 +272,7 @@ def _with_window_function(self, window_function: WindowFunction) -> Self:
272272
def _with_unorderable_window_function(
273273
self,
274274
unorderable_window_function: UnorderableWindowFunction,
275-
previous_call: EvalSeries[DuckDBLazyFrame, duckdb.Expression],
275+
previous_call: EvalSeries[DuckDBLazyFrame, Expression],
276276
) -> Self:
277277
result = self.__class__(
278278
self._call,
@@ -286,7 +286,7 @@ def _with_unorderable_window_function(
286286
return result
287287

288288
@classmethod
289-
def _alias_native(cls, expr: duckdb.Expression, name: str) -> duckdb.Expression:
289+
def _alias_native(cls, expr: Expression, name: str) -> Expression:
290290
return expr.alias(name)
291291

292292
def __and__(self, other: DuckDBExpr) -> Self:
@@ -364,7 +364,7 @@ def __ne__(self, other: DuckDBExpr) -> Self: # type: ignore[override]
364364
return self._with_callable(lambda _input, other: _input != other, other=other)
365365

366366
def __invert__(self) -> Self:
367-
invert = cast("Callable[..., duckdb.Expression]", operator.invert)
367+
invert = cast("Callable[..., Expression]", operator.invert)
368368
return self._with_callable(invert)
369369

370370
def abs(self) -> Self:
@@ -374,7 +374,7 @@ def mean(self) -> Self:
374374
return self._with_callable(lambda _input: FunctionExpression("mean", _input))
375375

376376
def skew(self) -> Self:
377-
def func(_input: duckdb.Expression) -> duckdb.Expression:
377+
def func(_input: Expression) -> Expression:
378378
count = FunctionExpression("count", _input)
379379
# Adjust population skewness by correction factor to get sample skewness
380380
sample_skewness = (
@@ -402,7 +402,7 @@ def any(self) -> Self:
402402
def quantile(
403403
self, quantile: float, interpolation: RollingInterpolationMethod
404404
) -> Self:
405-
def func(_input: duckdb.Expression) -> duckdb.Expression:
405+
def func(_input: Expression) -> Expression:
406406
if interpolation == "linear":
407407
return FunctionExpression("quantile_cont", _input, lit(quantile))
408408
msg = "Only linear interpolation methods are supported for DuckDB quantile."
@@ -415,15 +415,15 @@ def clip(
415415
lower_bound: Self | NumericLiteral | TemporalLiteral | None,
416416
upper_bound: Self | NumericLiteral | TemporalLiteral | None,
417417
) -> Self:
418-
def _clip_lower(_input: duckdb.Expression, lower_bound: Any) -> duckdb.Expression:
418+
def _clip_lower(_input: Expression, lower_bound: Any) -> Expression:
419419
return FunctionExpression("greatest", _input, lower_bound)
420420

421-
def _clip_upper(_input: duckdb.Expression, upper_bound: Any) -> duckdb.Expression:
421+
def _clip_upper(_input: Expression, upper_bound: Any) -> Expression:
422422
return FunctionExpression("least", _input, upper_bound)
423423

424424
def _clip_both(
425-
_input: duckdb.Expression, lower_bound: Any, upper_bound: Any
426-
) -> duckdb.Expression:
425+
_input: Expression, lower_bound: Any, upper_bound: Any
426+
) -> Expression:
427427
return FunctionExpression(
428428
"greatest", FunctionExpression("least", _input, upper_bound), lower_bound
429429
)
@@ -440,7 +440,7 @@ def sum(self) -> Self:
440440
return self._with_callable(lambda _input: FunctionExpression("sum", _input))
441441

442442
def n_unique(self) -> Self:
443-
def func(_input: duckdb.Expression) -> duckdb.Expression:
443+
def func(_input: Expression) -> Expression:
444444
# https://stackoverflow.com/a/79338887/4451315
445445
return FunctionExpression(
446446
"array_unique", FunctionExpression("array_agg", _input)
@@ -466,7 +466,7 @@ def std(self, ddof: int) -> Self:
466466
lambda _input: FunctionExpression("stddev_samp", _input)
467467
)
468468

469-
def _std(_input: duckdb.Expression) -> duckdb.Expression:
469+
def _std(_input: Expression) -> Expression:
470470
n_samples = FunctionExpression("count", _input)
471471
return (
472472
FunctionExpression("stddev_pop", _input)
@@ -486,7 +486,7 @@ def var(self, ddof: int) -> Self:
486486
lambda _input: FunctionExpression("var_samp", _input)
487487
)
488488

489-
def _var(_input: duckdb.Expression) -> duckdb.Expression:
489+
def _var(_input: Expression) -> Expression:
490490
n_samples = FunctionExpression("count", _input)
491491
return (
492492
FunctionExpression("var_pop", _input)
@@ -512,7 +512,7 @@ def over(self, partition_by: Sequence[str], order_by: Sequence[str] | None) -> S
512512
if (window_function := self._window_function) is not None:
513513
assert order_by is not None # noqa: S101
514514

515-
def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
515+
def func(df: DuckDBLazyFrame) -> list[Expression]:
516516
return [
517517
window_function(WindowInputs(expr, partition_by, order_by))
518518
for expr in self._call(df)
@@ -522,7 +522,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
522522
) is not None:
523523
assert order_by is None # noqa: S101
524524

525-
def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
525+
def func(df: DuckDBLazyFrame) -> list[Expression]:
526526
assert self._previous_call is not None # noqa: S101
527527
return [
528528
unorderable_window_function(
@@ -534,7 +534,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
534534
partition_by_sql = generate_partition_by_sql(*partition_by)
535535
template = f"{{expr}} over ({partition_by_sql})"
536536

537-
def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
537+
def func(df: DuckDBLazyFrame) -> list[Expression]:
538538
return [
539539
SQLExpression(template.format(expr=expr)) for expr in self._call(df)
540540
]
@@ -570,7 +570,7 @@ def round(self, decimals: int) -> Self:
570570
def shift(self, n: int) -> Self:
571571
ensure_type(n, int)
572572

573-
def func(window_inputs: WindowInputs) -> duckdb.Expression:
573+
def func(window_inputs: WindowInputs) -> Expression:
574574
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
575575
partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
576576
sql = (
@@ -582,7 +582,7 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression:
582582

583583
@requires.backend_version((1, 3))
584584
def is_first_distinct(self) -> Self:
585-
def func(window_inputs: WindowInputs) -> duckdb.Expression:
585+
def func(window_inputs: WindowInputs) -> Expression:
586586
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
587587
if window_inputs.partition_by:
588588
partition_by_sql = (
@@ -598,7 +598,7 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression:
598598

599599
@requires.backend_version((1, 3))
600600
def is_last_distinct(self) -> Self:
601-
def func(window_inputs: WindowInputs) -> duckdb.Expression:
601+
def func(window_inputs: WindowInputs) -> Expression:
602602
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=False)
603603
if window_inputs.partition_by:
604604
partition_by_sql = (
@@ -614,7 +614,7 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression:
614614

615615
@requires.backend_version((1, 3))
616616
def diff(self) -> Self:
617-
def func(window_inputs: WindowInputs) -> duckdb.Expression:
617+
def func(window_inputs: WindowInputs) -> Expression:
618618
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
619619
partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
620620
sql = f"lag({window_inputs.expr}) over ({partition_by_sql} {order_by_sql})"
@@ -713,7 +713,7 @@ def fill_null(
713713
msg = f"`fill_null` with `strategy={strategy}` is only available in 'duckdb>=1.3.0'."
714714
raise NotImplementedError(msg)
715715

716-
def _fill_with_strategy(window_inputs: WindowInputs) -> duckdb.Expression:
716+
def _fill_with_strategy(window_inputs: WindowInputs) -> Expression:
717717
order_by_sql = generate_order_by_sql(
718718
*window_inputs.order_by, ascending=True
719719
)
@@ -734,21 +734,21 @@ def _fill_with_strategy(window_inputs: WindowInputs) -> duckdb.Expression:
734734

735735
return self._with_window_function(_fill_with_strategy)
736736

737-
def _fill_constant(_input: duckdb.Expression, value: Any) -> duckdb.Expression:
737+
def _fill_constant(_input: Expression, value: Any) -> Expression:
738738
return CoalesceOperator(_input, value)
739739

740740
return self._with_callable(_fill_constant, value=value)
741741

742742
def cast(self, dtype: DType | type[DType]) -> Self:
743-
def func(_input: duckdb.Expression) -> duckdb.Expression:
743+
def func(_input: Expression) -> Expression:
744744
native_dtype = narwhals_to_native_dtype(dtype, self._version)
745745
return _input.cast(DuckDBPyType(native_dtype))
746746

747747
return self._with_callable(func)
748748

749749
@requires.backend_version((1, 3))
750750
def is_unique(self) -> Self:
751-
def func(_input: duckdb.Expression) -> duckdb.Expression:
751+
def func(_input: Expression) -> Expression:
752752
sql = f"count(*) over (partition by {_input})"
753753
return SQLExpression(sql) == lit(1) # type: ignore[no-any-return, unused-ignore]
754754

@@ -764,11 +764,11 @@ def rank(self, method: RankMethod, *, descending: bool) -> Self:
764764
func = FunctionExpression("row_number")
765765

766766
def _rank(
767-
_input: duckdb.Expression,
767+
_input: Expression,
768768
*,
769769
descending: bool,
770-
partition_by: Sequence[str | duckdb.Expression] | None = None,
771-
) -> duckdb.Expression:
770+
partition_by: Sequence[str | Expression] | None = None,
771+
) -> Expression:
772772
order_by_sql = (
773773
f"order by {_input} desc nulls last"
774774
if descending
@@ -795,12 +795,12 @@ def _rank(
795795
expr = SQLExpression(f"{func} OVER ({window})")
796796
return when(_input.isnotnull(), expr)
797797

798-
def _unpartitioned_rank(_input: duckdb.Expression) -> duckdb.Expression:
798+
def _unpartitioned_rank(_input: Expression) -> Expression:
799799
return _rank(_input, descending=descending)
800800

801801
def _partitioned_rank(
802802
window_inputs: UnorderableWindowInputs,
803-
) -> duckdb.Expression:
803+
) -> Expression:
804804
return _rank(
805805
window_inputs.expr,
806806
descending=descending,
@@ -813,7 +813,7 @@ def _partitioned_rank(
813813
)
814814

815815
def log(self, base: float) -> Self:
816-
def _log(_input: duckdb.Expression) -> duckdb.Expression:
816+
def _log(_input: Expression) -> Expression:
817817
log = FunctionExpression("log", _input)
818818
return (
819819
when(_input < lit(0), lit(float("nan")))

narwhals/_duckdb/expr_str.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from narwhals.utils import not_implemented
99

1010
if TYPE_CHECKING:
11-
import duckdb
11+
from duckdb import Expression
1212

1313
from narwhals._duckdb.expr import DuckDBExpr
1414

@@ -28,15 +28,15 @@ def ends_with(self, suffix: str) -> DuckDBExpr:
2828
)
2929

3030
def contains(self, pattern: str, *, literal: bool) -> DuckDBExpr:
31-
def func(_input: duckdb.Expression) -> duckdb.Expression:
31+
def func(_input: Expression) -> Expression:
3232
if literal:
3333
return FunctionExpression("contains", _input, lit(pattern))
3434
return FunctionExpression("regexp_matches", _input, lit(pattern))
3535

3636
return self._compliant_expr._with_callable(func)
3737

3838
def slice(self, offset: int, length: int) -> DuckDBExpr:
39-
def func(_input: duckdb.Expression) -> duckdb.Expression:
39+
def func(_input: Expression) -> Expression:
4040
offset_lit = lit(offset)
4141
return FunctionExpression(
4242
"array_slice",

0 commit comments

Comments (0)