Skip to content

Commit f17acfc

Browse files
authored
chore: use _expr_from_elementwise in IbisNamespace (#2716)
1 parent a9879b4 commit f17acfc

File tree

3 files changed

+51
-86
lines changed

3 files changed

+51
-86
lines changed

narwhals/_duckdb/namespace.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _expr(self) -> type[DuckDBExpr]:
4949
def _lazyframe(self) -> type[DuckDBLazyFrame]:
5050
return DuckDBLazyFrame
5151

52-
def _with_elementwise(
52+
def _expr_from_elementwise(
5353
self, func: Callable[[Iterable[Expression]], Expression], *exprs: DuckDBExpr
5454
) -> DuckDBExpr:
5555
def call(df: DuckDBLazyFrame) -> list[Expression]:
@@ -119,31 +119,31 @@ def all_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
119119
def func(cols: Iterable[Expression]) -> Expression:
120120
return reduce(operator.and_, cols)
121121

122-
return self._with_elementwise(func, *exprs)
122+
return self._expr_from_elementwise(func, *exprs)
123123

124124
def any_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
125125
def func(cols: Iterable[Expression]) -> Expression:
126126
return reduce(operator.or_, cols)
127127

128-
return self._with_elementwise(func, *exprs)
128+
return self._expr_from_elementwise(func, *exprs)
129129

130130
def max_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
131131
def func(cols: Iterable[Expression]) -> Expression:
132132
return FunctionExpression("greatest", *cols)
133133

134-
return self._with_elementwise(func, *exprs)
134+
return self._expr_from_elementwise(func, *exprs)
135135

136136
def min_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
137137
def func(cols: Iterable[Expression]) -> Expression:
138138
return FunctionExpression("least", *cols)
139139

140-
return self._with_elementwise(func, *exprs)
140+
return self._expr_from_elementwise(func, *exprs)
141141

142142
def sum_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
143143
def func(cols: Iterable[Expression]) -> Expression:
144144
return reduce(operator.add, (CoalesceOperator(col, lit(0)) for col in cols))
145145

146-
return self._with_elementwise(func, *exprs)
146+
return self._expr_from_elementwise(func, *exprs)
147147

148148
def mean_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
149149
def func(cols: Iterable[Expression]) -> Expression:
@@ -152,7 +152,7 @@ def func(cols: Iterable[Expression]) -> Expression:
152152
operator.add, (CoalesceOperator(col, lit(0)) for col in cols)
153153
) / reduce(operator.add, (col.isnotnull().cast(BIGINT) for col in cols))
154154

155-
return self._with_elementwise(func, *exprs)
155+
return self._expr_from_elementwise(func, *exprs)
156156

157157
def when(self, predicate: DuckDBExpr) -> DuckDBWhen:
158158
return DuckDBWhen.from_expr(predicate, context=self)

narwhals/_ibis/namespace.py

Lines changed: 37 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import operator
44
from functools import reduce
55
from itertools import chain
6-
from typing import TYPE_CHECKING, Any, cast
6+
from typing import TYPE_CHECKING, Any, Callable, cast
77

88
import ibis
99
import ibis.expr.types as ir
@@ -33,6 +33,21 @@ def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> Non
3333
self._backend_version = backend_version
3434
self._version = version
3535

36+
def _expr_from_callable(
37+
self, func: Callable[[Iterable[ir.Value]], ir.Value], *exprs: IbisExpr
38+
) -> IbisExpr:
39+
def call(df: IbisLazyFrame) -> list[ir.Value]:
40+
cols = (col for _expr in exprs for col in _expr(df))
41+
return [func(cols)]
42+
43+
return self._expr(
44+
call=call,
45+
evaluate_output_names=combine_evaluate_output_names(*exprs),
46+
alias_output_names=combine_alias_output_names(*exprs),
47+
backend_version=self._backend_version,
48+
version=self._version,
49+
)
50+
3651
@property
3752
def selectors(self) -> IbisSelectorNamespace:
3853
return IbisSelectorNamespace.from_namespace(self)
@@ -86,94 +101,44 @@ def func(df: IbisLazyFrame) -> list[ir.Value]:
86101
)
87102

88103
def all_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
89-
def func(df: IbisLazyFrame) -> list[ir.Value]:
90-
cols = chain.from_iterable(expr(df) for expr in exprs)
91-
return [reduce(operator.and_, cols)]
104+
def func(cols: Iterable[ir.Value]) -> ir.Value:
105+
return reduce(operator.and_, cols)
92106

93-
return self._expr(
94-
call=func,
95-
evaluate_output_names=combine_evaluate_output_names(*exprs),
96-
alias_output_names=combine_alias_output_names(*exprs),
97-
backend_version=self._backend_version,
98-
version=self._version,
99-
)
107+
return self._expr_from_callable(func, *exprs)
100108

101109
def any_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
102-
def func(df: IbisLazyFrame) -> list[ir.Value]:
103-
cols = chain.from_iterable(expr(df) for expr in exprs)
104-
return [reduce(operator.or_, cols)]
110+
def func(cols: Iterable[ir.Value]) -> ir.Value:
111+
return reduce(operator.or_, cols)
105112

106-
return self._expr(
107-
call=func,
108-
evaluate_output_names=combine_evaluate_output_names(*exprs),
109-
alias_output_names=combine_alias_output_names(*exprs),
110-
backend_version=self._backend_version,
111-
version=self._version,
112-
)
113+
return self._expr_from_callable(func, *exprs)
113114

114115
def max_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
115-
def func(df: IbisLazyFrame) -> list[ir.Value]:
116-
cols = chain.from_iterable(expr(df) for expr in exprs)
117-
return [ibis.greatest(*cols)]
116+
def func(cols: Iterable[ir.Value]) -> ir.Value:
117+
return ibis.greatest(*cols)
118118

119-
return self._expr(
120-
call=func,
121-
evaluate_output_names=combine_evaluate_output_names(*exprs),
122-
alias_output_names=combine_alias_output_names(*exprs),
123-
backend_version=self._backend_version,
124-
version=self._version,
125-
)
119+
return self._expr_from_callable(func, *exprs)
126120

127121
def min_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
128-
def func(df: IbisLazyFrame) -> list[ir.Value]:
129-
cols = chain.from_iterable(expr(df) for expr in exprs)
130-
return [ibis.least(*cols)]
122+
def func(cols: Iterable[ir.Value]) -> ir.Value:
123+
return ibis.least(*cols)
131124

132-
return self._expr(
133-
call=func,
134-
evaluate_output_names=combine_evaluate_output_names(*exprs),
135-
alias_output_names=combine_alias_output_names(*exprs),
136-
backend_version=self._backend_version,
137-
version=self._version,
138-
)
125+
return self._expr_from_callable(func, *exprs)
139126

140127
def sum_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
141-
def func(df: IbisLazyFrame) -> list[ir.Value]:
142-
cols = [e.fill_null(lit(0)) for _expr in exprs for e in _expr(df)]
143-
return [reduce(operator.add, cols)]
128+
def func(cols: Iterable[ir.Value]) -> ir.Value:
129+
cols = (col.fill_null(lit(0)) for col in cols)
130+
return reduce(operator.add, cols)
144131

145-
return self._expr(
146-
call=func,
147-
evaluate_output_names=combine_evaluate_output_names(*exprs),
148-
alias_output_names=combine_alias_output_names(*exprs),
149-
backend_version=self._backend_version,
150-
version=self._version,
151-
)
132+
return self._expr_from_callable(func, *exprs)
152133

153134
def mean_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
154-
def func(df: IbisLazyFrame) -> list[ir.Value]:
155-
expr = (
156-
cast("ir.NumericColumn", e.fill_null(lit(0)))
157-
for _expr in exprs
158-
for e in _expr(df)
159-
)
160-
non_null = (
161-
cast("ir.NumericColumn", e.isnull().ifelse(lit(0), lit(1)))
162-
for _expr in exprs
163-
for e in _expr(df)
135+
def func(cols: Iterable[ir.Value]) -> ir.Value:
136+
cols = list(cols)
137+
return reduce(operator.add, (col.fill_null(lit(0)) for col in cols)) / reduce(
138+
operator.add, (col.isnull().ifelse(lit(0), lit(1)) for col in cols)
164139
)
165140

166-
return [
167-
(reduce(lambda x, y: x + y, expr) / reduce(lambda x, y: x + y, non_null))
168-
]
169-
170-
return self._expr(
171-
call=func,
172-
evaluate_output_names=combine_evaluate_output_names(*exprs),
173-
alias_output_names=combine_alias_output_names(*exprs),
174-
backend_version=self._backend_version,
175-
version=self._version,
176-
)
141+
return self._expr_from_callable(func, *exprs)
177142

178143
@requires.backend_version((10, 0))
179144
def when(self, predicate: IbisExpr) -> IbisWhen:

narwhals/_spark_like/namespace.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _native_dtypes(self): # type: ignore[no-untyped-def] # noqa: ANN202
7373
else:
7474
return import_native_dtypes(self._implementation)
7575

76-
def _with_elementwise(
76+
def _expr_from_callable(
7777
self, func: Callable[[Iterable[Column]], Column], *exprs: SparkLikeExpr
7878
) -> SparkLikeExpr:
7979
def call(df: SparkLikeLazyFrame) -> list[Column]:
@@ -135,33 +135,33 @@ def all_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr:
135135
def func(cols: Iterable[Column]) -> Column:
136136
return reduce(operator.and_, cols)
137137

138-
return self._with_elementwise(func, *exprs)
138+
return self._expr_from_callable(func, *exprs)
139139

140140
def any_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr:
141141
def func(cols: Iterable[Column]) -> Column:
142142
return reduce(operator.or_, cols)
143143

144-
return self._with_elementwise(func, *exprs)
144+
return self._expr_from_callable(func, *exprs)
145145

146146
def max_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr:
147147
def func(cols: Iterable[Column]) -> Column:
148148
return self._F.greatest(*cols)
149149

150-
return self._with_elementwise(func, *exprs)
150+
return self._expr_from_callable(func, *exprs)
151151

152152
def min_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr:
153153
def func(cols: Iterable[Column]) -> Column:
154154
return self._F.least(*cols)
155155

156-
return self._with_elementwise(func, *exprs)
156+
return self._expr_from_callable(func, *exprs)
157157

158158
def sum_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr:
159159
def func(cols: Iterable[Column]) -> Column:
160160
return reduce(
161161
operator.add, (self._F.coalesce(col, self._F.lit(0)) for col in cols)
162162
)
163163

164-
return self._with_elementwise(func, *exprs)
164+
return self._expr_from_callable(func, *exprs)
165165

166166
def mean_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr:
167167
def func(cols: Iterable[Column]) -> Column:
@@ -182,7 +182,7 @@ def func(cols: Iterable[Column]) -> Column:
182182
),
183183
)
184184

185-
return self._with_elementwise(func, *exprs)
185+
return self._expr_from_callable(func, *exprs)
186186

187187
def concat(
188188
self, items: Iterable[SparkLikeLazyFrame], *, how: ConcatMethod

0 commit comments

Comments
 (0)