Skip to content

Commit 24519da

Browse files
refactor: Enforce LazyExpr._from_elementwise_horizontal_op (#2718)
* refactor: Enforce `LazyExpr.from_elementwise` See #2716 (comment) * rename to _from_elementwise_horizontal_op * fix typing * actually fix typing --------- Co-authored-by: Marco Gorelli <[email protected]>
1 parent e4f7d30 commit 24519da

File tree

8 files changed

+115
-91
lines changed

8 files changed

+115
-91
lines changed

narwhals/_compliant/expr.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from narwhals.dependencies import get_numpy, is_numpy_array
3333

3434
if TYPE_CHECKING:
35-
from collections.abc import Mapping, Sequence
35+
from collections.abc import Iterable, Mapping, Sequence
3636

3737
from typing_extensions import Self, TypeIs
3838

@@ -894,6 +894,11 @@ def fn(names: Sequence[str]) -> Sequence[str]:
894894
@classmethod
895895
def _alias_native(cls, expr: NativeExprT, name: str, /) -> NativeExprT: ...
896896

897+
@classmethod
898+
def _from_elementwise_horizontal_op(
899+
cls, func: Callable[[Iterable[NativeExprT]], NativeExprT], *exprs: Self
900+
) -> Self: ...
901+
897902
@property
898903
def name(self) -> LazyExprNameNamespace[Self]:
899904
return LazyExprNameNamespace(self)

narwhals/_dask/expr.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,3 +683,4 @@ def dt(self) -> DaskExprDateTimeNamespace:
683683
rank = not_implemented() # pyright: ignore[reportAssignmentType]
684684
_alias_native = not_implemented()
685685
window_function = not_implemented() # pyright: ignore[reportAssignmentType]
686+
_from_elementwise_horizontal_op = not_implemented()

narwhals/_duckdb/expr.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,15 @@
2020
when,
2121
window_expression,
2222
)
23-
from narwhals._expression_parsing import ExprKind
23+
from narwhals._expression_parsing import (
24+
ExprKind,
25+
combine_alias_output_names,
26+
combine_evaluate_output_names,
27+
)
2428
from narwhals._utils import Implementation, not_implemented, requires
2529

2630
if TYPE_CHECKING:
27-
from collections.abc import Sequence
31+
from collections.abc import Iterable, Sequence
2832

2933
from duckdb import Expression
3034
from typing_extensions import Self
@@ -217,6 +221,32 @@ def func(df: DuckDBLazyFrame) -> list[Expression]:
217221
version=context._version,
218222
)
219223

224+
@classmethod
225+
def _from_elementwise_horizontal_op(
226+
cls, func: Callable[[Iterable[Expression]], Expression], *exprs: Self
227+
) -> Self:
228+
def call(df: DuckDBLazyFrame) -> list[Expression]:
229+
cols = (col for _expr in exprs for col in _expr(df))
230+
return [func(cols)]
231+
232+
def window_function(
233+
df: DuckDBLazyFrame, window_inputs: DuckDBWindowInputs
234+
) -> list[Expression]:
235+
cols = (
236+
col for _expr in exprs for col in _expr.window_function(df, window_inputs)
237+
)
238+
return [func(cols)]
239+
240+
context = exprs[0]
241+
return cls(
242+
call=call,
243+
window_function=window_function,
244+
evaluate_output_names=combine_evaluate_output_names(*exprs),
245+
alias_output_names=combine_alias_output_names(*exprs),
246+
backend_version=context._backend_version,
247+
version=context._version,
248+
)
249+
220250
def _callable_to_eval_series(
221251
self, call: Callable[..., Expression], /, **expressifiable_args: Self | Any
222252
) -> EvalSeries[DuckDBLazyFrame, Expression]:

narwhals/_duckdb/namespace.py

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import operator
44
from functools import reduce
55
from itertools import chain
6-
from typing import TYPE_CHECKING, Callable
6+
from typing import TYPE_CHECKING
77

88
import duckdb
99
from duckdb import CoalesceOperator, Expression
@@ -49,30 +49,6 @@ def _expr(self) -> type[DuckDBExpr]:
4949
def _lazyframe(self) -> type[DuckDBLazyFrame]:
5050
return DuckDBLazyFrame
5151

52-
def _expr_from_elementwise(
53-
self, func: Callable[[Iterable[Expression]], Expression], *exprs: DuckDBExpr
54-
) -> DuckDBExpr:
55-
def call(df: DuckDBLazyFrame) -> list[Expression]:
56-
cols = (col for _expr in exprs for col in _expr(df))
57-
return [func(cols)]
58-
59-
def window_function(
60-
df: DuckDBLazyFrame, window_inputs: DuckDBWindowInputs
61-
) -> list[Expression]:
62-
cols = (
63-
col for _expr in exprs for col in _expr.window_function(df, window_inputs)
64-
)
65-
return [func(cols)]
66-
67-
return self._expr(
68-
call=call,
69-
window_function=window_function,
70-
evaluate_output_names=combine_evaluate_output_names(*exprs),
71-
alias_output_names=combine_alias_output_names(*exprs),
72-
backend_version=self._backend_version,
73-
version=self._version,
74-
)
75-
7652
def concat(
7753
self, items: Iterable[DuckDBLazyFrame], *, how: ConcatMethod
7854
) -> DuckDBLazyFrame:
@@ -124,7 +100,7 @@ def func(cols: Iterable[Expression]) -> Expression:
124100
)
125101
return reduce(operator.and_, it)
126102

127-
return self._expr_from_elementwise(func, *exprs)
103+
return self._expr._from_elementwise_horizontal_op(func, *exprs)
128104

129105
def any_horizontal(self, *exprs: DuckDBExpr, ignore_nulls: bool) -> DuckDBExpr:
130106
def func(cols: Iterable[Expression]) -> Expression:
@@ -135,25 +111,25 @@ def func(cols: Iterable[Expression]) -> Expression:
135111
)
136112
return reduce(operator.or_, it)
137113

138-
return self._expr_from_elementwise(func, *exprs)
114+
return self._expr._from_elementwise_horizontal_op(func, *exprs)
139115

140116
def max_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
141117
def func(cols: Iterable[Expression]) -> Expression:
142118
return F("greatest", *cols)
143119

144-
return self._expr_from_elementwise(func, *exprs)
120+
return self._expr._from_elementwise_horizontal_op(func, *exprs)
145121

146122
def min_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
147123
def func(cols: Iterable[Expression]) -> Expression:
148124
return F("least", *cols)
149125

150-
return self._expr_from_elementwise(func, *exprs)
126+
return self._expr._from_elementwise_horizontal_op(func, *exprs)
151127

152128
def sum_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
153129
def func(cols: Iterable[Expression]) -> Expression:
154130
return reduce(operator.add, (CoalesceOperator(col, lit(0)) for col in cols))
155131

156-
return self._expr_from_elementwise(func, *exprs)
132+
return self._expr._from_elementwise_horizontal_op(func, *exprs)
157133

158134
def mean_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
159135
def func(cols: Iterable[Expression]) -> Expression:
@@ -162,7 +138,7 @@ def func(cols: Iterable[Expression]) -> Expression:
162138
operator.add, (CoalesceOperator(col, lit(0)) for col in cols)
163139
) / reduce(operator.add, (col.isnotnull().cast(BIGINT) for col in cols))
164140

165-
return self._expr_from_elementwise(func, *exprs)
141+
return self._expr._from_elementwise_horizontal_op(func, *exprs)
166142

167143
def when(self, predicate: DuckDBExpr) -> DuckDBWhen:
168144
return DuckDBWhen.from_expr(predicate, context=self)

narwhals/_ibis/expr.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88

99
from narwhals._compliant import LazyExpr
1010
from narwhals._compliant.window import WindowInputs
11+
from narwhals._expression_parsing import (
12+
combine_alias_output_names,
13+
combine_evaluate_output_names,
14+
)
1115
from narwhals._ibis.expr_dt import IbisExprDateTimeNamespace
1216
from narwhals._ibis.expr_list import IbisExprListNamespace
1317
from narwhals._ibis.expr_str import IbisExprStringNamespace
@@ -16,7 +20,7 @@
1620
from narwhals._utils import Implementation, not_implemented
1721

1822
if TYPE_CHECKING:
19-
from collections.abc import Iterator, Sequence
23+
from collections.abc import Iterable, Iterator, Sequence
2024

2125
import ibis.expr.types as ir
2226
from typing_extensions import Self
@@ -203,6 +207,23 @@ def func(df: IbisLazyFrame) -> list[ir.Column]:
203207
version=context._version,
204208
)
205209

210+
@classmethod
211+
def _from_elementwise_horizontal_op(
212+
cls, func: Callable[[Iterable[ir.Value]], ir.Value], *exprs: Self
213+
) -> Self:
214+
def call(df: IbisLazyFrame) -> list[ir.Value]:
215+
cols = (col for _expr in exprs for col in _expr(df))
216+
return [func(cols)]
217+
218+
context = exprs[0]
219+
return cls(
220+
call=call,
221+
evaluate_output_names=combine_evaluate_output_names(*exprs),
222+
alias_output_names=combine_alias_output_names(*exprs),
223+
backend_version=context._backend_version,
224+
version=context._version,
225+
)
226+
206227
def _with_callable(
207228
self, call: Callable[..., ir.Value], /, **expressifiable_args: Self | Any
208229
) -> Self:

narwhals/_ibis/namespace.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import operator
44
from functools import reduce
55
from itertools import chain
6-
from typing import TYPE_CHECKING, Any, Callable, cast
6+
from typing import TYPE_CHECKING, Any, cast
77

88
import ibis
99
import ibis.expr.types as ir
@@ -33,21 +33,6 @@ def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> Non
3333
self._backend_version = backend_version
3434
self._version = version
3535

36-
def _expr_from_callable(
37-
self, func: Callable[[Iterable[ir.Value]], ir.Value], *exprs: IbisExpr
38-
) -> IbisExpr:
39-
def call(df: IbisLazyFrame) -> list[ir.Value]:
40-
cols = (col for _expr in exprs for col in _expr(df))
41-
return [func(cols)]
42-
43-
return self._expr(
44-
call=call,
45-
evaluate_output_names=combine_evaluate_output_names(*exprs),
46-
alias_output_names=combine_alias_output_names(*exprs),
47-
backend_version=self._backend_version,
48-
version=self._version,
49-
)
50-
5136
@property
5237
def selectors(self) -> IbisSelectorNamespace:
5338
return IbisSelectorNamespace.from_namespace(self)
@@ -109,7 +94,7 @@ def func(cols: Iterable[ir.Value]) -> ir.Value:
10994
)
11095
return reduce(operator.and_, it)
11196

112-
return self._expr_from_callable(func, *exprs)
97+
return self._expr._from_elementwise_horizontal_op(func, *exprs)
11398

11499
def any_horizontal(self, *exprs: IbisExpr, ignore_nulls: bool) -> IbisExpr:
115100
def func(cols: Iterable[ir.Value]) -> ir.Value:
@@ -120,26 +105,26 @@ def func(cols: Iterable[ir.Value]) -> ir.Value:
120105
)
121106
return reduce(operator.or_, it)
122107

123-
return self._expr_from_callable(func, *exprs)
108+
return self._expr._from_elementwise_horizontal_op(func, *exprs)
124109

125110
def max_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
126111
def func(cols: Iterable[ir.Value]) -> ir.Value:
127112
return ibis.greatest(*cols)
128113

129-
return self._expr_from_callable(func, *exprs)
114+
return self._expr._from_elementwise_horizontal_op(func, *exprs)
130115

131116
def min_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
132117
def func(cols: Iterable[ir.Value]) -> ir.Value:
133118
return ibis.least(*cols)
134119

135-
return self._expr_from_callable(func, *exprs)
120+
return self._expr._from_elementwise_horizontal_op(func, *exprs)
136121

137122
def sum_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
138123
def func(cols: Iterable[ir.Value]) -> ir.Value:
139124
cols = (col.fill_null(lit(0)) for col in cols)
140125
return reduce(operator.add, cols)
141126

142-
return self._expr_from_callable(func, *exprs)
127+
return self._expr._from_elementwise_horizontal_op(func, *exprs)
143128

144129
def mean_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
145130
def func(cols: Iterable[ir.Value]) -> ir.Value:
@@ -148,7 +133,7 @@ def func(cols: Iterable[ir.Value]) -> ir.Value:
148133
operator.add, (col.isnull().ifelse(lit(0), lit(1)) for col in cols)
149134
)
150135

151-
return self._expr_from_callable(func, *exprs)
136+
return self._expr._from_elementwise_horizontal_op(func, *exprs)
152137

153138
@requires.backend_version((10, 0))
154139
def when(self, predicate: IbisExpr) -> IbisWhen:

narwhals/_spark_like/expr.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55

66
from narwhals._compliant import LazyExpr
77
from narwhals._compliant.window import WindowInputs
8-
from narwhals._expression_parsing import ExprKind
8+
from narwhals._expression_parsing import (
9+
ExprKind,
10+
combine_alias_output_names,
11+
combine_evaluate_output_names,
12+
)
913
from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace
1014
from narwhals._spark_like.expr_list import SparkLikeExprListNamespace
1115
from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
@@ -20,7 +24,7 @@
2024
from narwhals.dependencies import get_pyspark
2125

2226
if TYPE_CHECKING:
23-
from collections.abc import Iterator, Mapping, Sequence
27+
from collections.abc import Iterable, Iterator, Mapping, Sequence
2428

2529
from sqlframe.base.column import Column
2630
from sqlframe.base.window import Window, WindowSpec
@@ -277,6 +281,33 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
277281
implementation=context._implementation,
278282
)
279283

284+
@classmethod
285+
def _from_elementwise_horizontal_op(
286+
cls, func: Callable[[Iterable[Column]], Column], *exprs: Self
287+
) -> Self:
288+
def call(df: SparkLikeLazyFrame) -> list[Column]:
289+
cols = (col for _expr in exprs for col in _expr(df))
290+
return [func(cols)]
291+
292+
def window_function(
293+
df: SparkLikeLazyFrame, window_inputs: SparkWindowInputs
294+
) -> list[Column]:
295+
cols = (
296+
col for _expr in exprs for col in _expr.window_function(df, window_inputs)
297+
)
298+
return [func(cols)]
299+
300+
context = exprs[0]
301+
return cls(
302+
call=call,
303+
window_function=window_function,
304+
evaluate_output_names=combine_evaluate_output_names(*exprs),
305+
alias_output_names=combine_alias_output_names(*exprs),
306+
backend_version=context._backend_version,
307+
version=context._version,
308+
implementation=context._implementation,
309+
)
310+
280311
def _callable_to_eval_series(
281312
self, call: Callable[..., Column], /, **expressifiable_args: Self | Any
282313
) -> EvalSeries[SparkLikeLazyFrame, Column]:

0 commit comments

Comments
 (0)