Skip to content

Commit 7ad112a

Browse files
committed
wip operator: added all changes from pair session, 1 problem remaining
1 parent 4aae069 commit 7ad112a

File tree

2 files changed

+95
-100
lines changed

2 files changed

+95
-100
lines changed

narwhals/_sql/expr.py

Lines changed: 64 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
AliasNames,
88
EvalNames,
99
EvalSeries,
10-
NativeExprT,
1110
WindowFunction,
1211
)
1312
from narwhals._compliant.window import WindowInputs
@@ -28,9 +27,6 @@
2827
from narwhals._sql.namespace import SQLNamespace
2928
from narwhals.typing import NumericLiteral, PythonLiteral, RankMethod, TemporalLiteral
3029

31-
# am I right in thinking we need to pass NativeSQLExprT here, not NativeExprT? since NativeSQLExprT
32-
# inherits from it, I can use that througout in the code and it will also have its parents functionality.
33-
# though at the moment there are still problems with the class..
3430
class SQLExpr(LazyExpr[SQLLazyFrameT, NativeSQLExprT], Protocol[SQLLazyFrameT, NativeSQLExprT]):
3531
_call: EvalSeries[SQLLazyFrameT, NativeSQLExprT]
3632
_evaluate_output_names: EvalNames[SQLLazyFrameT]
@@ -42,26 +38,26 @@ class SQLExpr(LazyExpr[SQLLazyFrameT, NativeSQLExprT], Protocol[SQLLazyFrameT, N
4238

4339
def __init__(
4440
self,
45-
call: EvalSeries[SQLLazyFrameT, NativeExprT],
46-
window_function: WindowFunction[SQLLazyFrameT, NativeExprT] | None = None,
41+
call: EvalSeries[SQLLazyFrameT, NativeSQLExprT],
42+
window_function: WindowFunction[SQLLazyFrameT, NativeSQLExprT] | None = None,
4743
*,
4844
evaluate_output_names: EvalNames[SQLLazyFrameT],
4945
alias_output_names: AliasNames | None,
5046
version: Version,
5147
implementation: Implementation = Implementation.DUCKDB,
5248
) -> None: ...
5349

54-
def __call__(self, df: SQLLazyFrameT) -> Sequence[NativeExprT]:
50+
def __call__(self, df: SQLLazyFrameT) -> Sequence[NativeSQLExprT]:
5551
return self._call(df)
5652

5753
def __narwhals_namespace__(
5854
self,
59-
) -> SQLNamespace[SQLLazyFrameT, Self, Any, NativeExprT]: ...
55+
) -> SQLNamespace[SQLLazyFrameT, Self, Any, NativeSQLExprT]: ...
6056

6157
def _callable_to_eval_series(
62-
self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any
63-
) -> EvalSeries[SQLLazyFrameT, NativeExprT]:
64-
def func(df: SQLLazyFrameT) -> list[NativeExprT]:
58+
self, call: Callable[..., NativeSQLExprT], /, **expressifiable_args: Self | Any
59+
) -> EvalSeries[SQLLazyFrameT, NativeSQLExprT]:
60+
def func(df: SQLLazyFrameT) -> list[NativeSQLExprT]:
6561
native_series_list = self(df)
6662
other_native_series = {
6763
key: df._evaluate_expr(value)
@@ -77,11 +73,11 @@ def func(df: SQLLazyFrameT) -> list[NativeExprT]:
7773
return func
7874

7975
def _push_down_window_function(
80-
self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any
81-
) -> WindowFunction[SQLLazyFrameT, NativeExprT]:
76+
self, call: Callable[..., NativeSQLExprT], /, **expressifiable_args: Self | Any
77+
) -> WindowFunction[SQLLazyFrameT, NativeSQLExprT]:
8278
def window_f(
83-
df: SQLLazyFrameT, window_inputs: WindowInputs[NativeExprT]
84-
) -> Sequence[NativeExprT]:
79+
df: SQLLazyFrameT, window_inputs: WindowInputs[NativeSQLExprT]
80+
) -> Sequence[NativeSQLExprT]:
8581
# If a function `f` is elementwise, and `g` is another function, then
8682
# - `f(g) over (window)`
8783
# - `f(g over (window))
@@ -114,7 +110,7 @@ def _with_window_function(
114110
)
115111

116112
def _with_callable(
117-
self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any
113+
self, call: Callable[..., NativeSQLExprT], /, **expressifiable_args: Self | Any
118114
) -> Self:
119115
return self.__class__(
120116
self._callable_to_eval_series(call, **expressifiable_args),
@@ -125,7 +121,7 @@ def _with_callable(
125121
)
126122

127123
def _with_elementwise(
128-
self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any
124+
self, call: Callable[..., NativeSQLExprT], /, **expressifiable_args: Self | Any
129125
) -> Self:
130126
return self.__class__(
131127
self._callable_to_eval_series(call, **expressifiable_args),
@@ -136,7 +132,7 @@ def _with_elementwise(
136132
implementation=self._implementation,
137133
)
138134

139-
def _with_binary(self, op: Callable[..., NativeExprT], other: Self | Any) -> Self:
135+
def _with_binary(self, op: Callable[..., NativeSQLExprT], other: Self | Any) -> Self:
140136
return self.__class__(
141137
self._callable_to_eval_series(op, other=other),
142138
self._push_down_window_function(op, other=other),
@@ -177,46 +173,46 @@ def default_window_func(
177173

178174
return self._window_function or default_window_func
179175

180-
def _function(self, name: str, *args: NativeExprT | PythonLiteral) -> NativeExprT:
176+
def _function(self, name: str, *args: NativeSQLExprT | PythonLiteral) -> NativeSQLExprT:
181177
return self.__narwhals_namespace__()._function(name, *args)
182178

183-
def _lit(self, value: Any) -> NativeExprT:
179+
def _lit(self, value: Any) -> NativeSQLExprT:
184180
return self.__narwhals_namespace__()._lit(value)
185181

186-
def _coalesce(self, *expr: NativeExprT) -> NativeExprT:
182+
def _coalesce(self, *expr: NativeSQLExprT) -> NativeSQLExprT:
187183
return self.__narwhals_namespace__()._coalesce(*expr)
188184

189-
def _count_star(self) -> NativeExprT: ...
185+
def _count_star(self) -> NativeSQLExprT: ...
190186

191187
def _when(
192188
self,
193-
condition: NativeExprT,
194-
value: NativeExprT,
195-
otherwise: NativeExprT | None = None,
196-
) -> NativeExprT:
189+
condition: NativeSQLExprT,
190+
value: NativeSQLExprT,
191+
otherwise: NativeSQLExprT | None = None,
192+
) -> NativeSQLExprT:
197193
return self.__narwhals_namespace__()._when(condition, value, otherwise)
198194

199195
def _window_expression(
200196
self,
201-
expr: NativeExprT,
202-
partition_by: Sequence[str | NativeExprT] = (),
203-
order_by: Sequence[str | NativeExprT] = (),
197+
expr: NativeSQLExprT,
198+
partition_by: Sequence[str | NativeSQLExprT] = (),
199+
order_by: Sequence[str | NativeSQLExprT] = (),
204200
rows_start: int | None = None,
205201
rows_end: int | None = None,
206202
*,
207203
descending: Sequence[bool] | None = None,
208204
nulls_last: Sequence[bool] | None = None,
209-
) -> NativeExprT: ...
205+
) -> NativeSQLExprT: ...
210206

211207
def _cum_window_func(
212208
self,
213209
func_name: Literal["sum", "max", "min", "count", "product"],
214210
*,
215211
reverse: bool,
216-
) -> WindowFunction[SQLLazyFrameT, NativeExprT]:
212+
) -> WindowFunction[SQLLazyFrameT, NativeSQLExprT]:
217213
def func(
218-
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
219-
) -> Sequence[NativeExprT]:
214+
df: SQLLazyFrameT, inputs: WindowInputs[NativeSQLExprT]
215+
) -> Sequence[NativeSQLExprT]:
220216
return [
221217
self._window_expression(
222218
self._function(func_name, expr),
@@ -239,7 +235,7 @@ def _rolling_window_func(
239235
ddof: int | None = None,
240236
*,
241237
center: bool,
242-
) -> WindowFunction[SQLLazyFrameT, NativeExprT]:
238+
) -> WindowFunction[SQLLazyFrameT, NativeSQLExprT]:
243239
supported_funcs = ["sum", "mean", "std", "var"]
244240
if center:
245241
half = (window_size - 1) // 2
@@ -251,8 +247,8 @@ def _rolling_window_func(
251247
end = 0
252248

253249
def func(
254-
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
255-
) -> Sequence[NativeExprT]:
250+
df: SQLLazyFrameT, inputs: WindowInputs[NativeSQLExprT]
251+
) -> Sequence[NativeSQLExprT]:
256252
if func_name in {"sum", "mean"}:
257253
func_: str = func_name
258254
elif func_name == "var" and ddof == 0:
@@ -297,13 +293,13 @@ def _backend_version(self) -> tuple[int, ...]:
297293
return self._implementation._backend_version()
298294

299295
@classmethod
300-
def _alias_native(cls, expr: NativeExprT, name: str, /) -> NativeExprT: ...
296+
def _alias_native(cls, expr: NativeSQLExprT, name: str, /) -> NativeSQLExprT: ...
301297

302298
@classmethod
303299
def _from_elementwise_horizontal_op(
304-
cls, func: Callable[[Iterable[NativeExprT]], NativeExprT], *exprs: Self
300+
cls, func: Callable[[Iterable[NativeSQLExprT]], NativeSQLExprT], *exprs: Self
305301
) -> Self:
306-
def call(df: SQLLazyFrameT) -> Sequence[NativeExprT]:
302+
def call(df: SQLLazyFrameT) -> Sequence[NativeSQLExprT]:
307303
cols = (col for _expr in exprs for col in _expr(df))
308304
return [func(cols)]
309305

@@ -390,12 +386,12 @@ def __or__(self, other: Self) -> Self:
390386

391387
# Aggregations
392388
def all(self) -> Self:
393-
def f(expr: NativeExprT) -> NativeExprT:
389+
def f(expr: NativeSQLExprT) -> NativeSQLExprT:
394390
return self._coalesce(self._function("bool_and", expr), self._lit(True)) # noqa: FBT003
395391

396392
def window_f(
397-
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
398-
) -> Sequence[NativeExprT]:
393+
df: SQLLazyFrameT, inputs: WindowInputs[NativeSQLExprT]
394+
) -> Sequence[NativeSQLExprT]:
399395
return [
400396
self._coalesce(
401397
self._window_expression(
@@ -409,12 +405,12 @@ def window_f(
409405
return self._with_callable(f)._with_window_function(window_f)
410406

411407
def any(self) -> Self:
412-
def f(expr: NativeExprT) -> NativeExprT:
408+
def f(expr: NativeSQLExprT) -> NativeSQLExprT:
413409
return self._coalesce(self._function("bool_or", expr), self._lit(False)) # noqa: FBT003
414410

415411
def window_f(
416-
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
417-
) -> Sequence[NativeExprT]:
412+
df: SQLLazyFrameT, inputs: WindowInputs[NativeSQLExprT]
413+
) -> Sequence[NativeSQLExprT]:
418414
return [
419415
self._coalesce(
420416
self._window_expression(
@@ -443,12 +439,12 @@ def count(self) -> Self:
443439
return self._with_callable(lambda expr: self._function("count", expr))
444440

445441
def sum(self) -> Self:
446-
def f(expr: NativeExprT) -> NativeExprT:
442+
def f(expr: NativeSQLExprT) -> NativeSQLExprT:
447443
return self._coalesce(self._function("sum", expr), self._lit(0))
448444

449445
def window_f(
450-
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
451-
) -> Sequence[NativeExprT]:
446+
df: SQLLazyFrameT, inputs: WindowInputs[NativeSQLExprT]
447+
) -> Sequence[NativeSQLExprT]:
452448
return [
453449
self._coalesce(
454450
self._window_expression(
@@ -470,15 +466,15 @@ def clip(
470466
lower_bound: Self | NumericLiteral | TemporalLiteral | None,
471467
upper_bound: Self | NumericLiteral | TemporalLiteral | None,
472468
) -> Self:
473-
def _clip_lower(expr: NativeExprT, lower_bound: Any) -> NativeExprT:
469+
def _clip_lower(expr: NativeSQLExprT, lower_bound: Any) -> NativeSQLExprT:
474470
return self._function("greatest", expr, lower_bound)
475471

476-
def _clip_upper(expr: NativeExprT, upper_bound: Any) -> NativeExprT:
472+
def _clip_upper(expr: NativeSQLExprT, upper_bound: Any) -> NativeSQLExprT:
477473
return self._function("least", expr, upper_bound)
478474

479475
def _clip_both(
480-
expr: NativeExprT, lower_bound: Any, upper_bound: Any
481-
) -> NativeExprT:
476+
expr: NativeSQLExprT, lower_bound: Any, upper_bound: Any
477+
) -> NativeSQLExprT:
482478
return self._function(
483479
"greatest", self._function("least", expr, upper_bound), lower_bound
484480
)
@@ -590,8 +586,8 @@ def func(
590586

591587
def shift(self, n: int) -> Self:
592588
def func(
593-
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
594-
) -> Sequence[NativeExprT]:
589+
df: SQLLazyFrameT, inputs: WindowInputs[NativeSQLExprT]
590+
) -> Sequence[NativeSQLExprT]:
595591
return [
596592
self._window_expression(
597593
self._function("lag", expr, n), inputs.partition_by, inputs.order_by
@@ -603,12 +599,12 @@ def func(
603599

604600
def is_first_distinct(self) -> Self:
605601
def func(
606-
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
607-
) -> Sequence[NativeExprT]:
602+
df: SQLLazyFrameT, inputs: WindowInputs[NativeSQLExprT]
603+
) -> Sequence[NativeSQLExprT]:
608604
# pyright checkers think the return type is `list[bool]` because of `==`
609605
return [
610606
cast(
611-
"NativeExprT",
607+
"NativeSQLExprT",
612608
self._window_expression(
613609
self._function("row_number"),
614610
(*inputs.partition_by, expr),
@@ -623,11 +619,11 @@ def func(
623619

624620
def is_last_distinct(self) -> Self:
625621
def func(
626-
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
627-
) -> Sequence[NativeExprT]:
622+
df: SQLLazyFrameT, inputs: WindowInputs[NativeSQLExprT]
623+
) -> Sequence[NativeSQLExprT]:
628624
return [
629625
cast(
630-
"NativeExprT",
626+
"NativeSQLExprT",
631627
self._window_expression(
632628
self._function("row_number"),
633629
(*inputs.partition_by, expr),
@@ -707,20 +703,20 @@ def _partitioned_rank(
707703

708704
def is_unique(self) -> Self:
709705
def _is_unique(
710-
expr: NativeExprT, *partition_by: str | NativeExprT
711-
) -> NativeExprT:
706+
expr: NativeSQLExprT, *partition_by: str | NativeSQLExprT
707+
) -> NativeSQLExprT:
712708
return cast(
713-
"NativeExprT",
709+
"NativeSQLExprT",
714710
self._window_expression(self._count_star(), (expr, *partition_by))
715711
== self._lit(1),
716712
)
717713

718-
def _unpartitioned_is_unique(expr: NativeExprT) -> NativeExprT:
714+
def _unpartitioned_is_unique(expr: NativeSQLExprT) -> NativeSQLExprT:
719715
return _is_unique(expr)
720716

721717
def _partitioned_is_unique(
722-
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
723-
) -> Sequence[NativeExprT]:
718+
df: SQLLazyFrameT, inputs: WindowInputs[NativeSQLExprT]
719+
) -> Sequence[NativeSQLExprT]:
724720
assert not inputs.order_by # noqa: S101
725721
return [_is_unique(expr, *inputs.partition_by) for expr in self(df)]
726722

@@ -730,9 +726,9 @@ def _partitioned_is_unique(
730726

731727
# Other
732728
def over(
733-
self, partition_by: Sequence[str | NativeExprT], order_by: Sequence[str]
729+
self, partition_by: Sequence[str | NativeSQLExprT], order_by: Sequence[str]
734730
) -> Self:
735-
def func(df: SQLLazyFrameT) -> Sequence[NativeExprT]:
731+
def func(df: SQLLazyFrameT) -> Sequence[NativeSQLExprT]:
736732
return self.window_function(df, WindowInputs(partition_by, order_by))
737733

738734
return self.__class__(

0 commit comments

Comments
 (0)