Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from narwhals.dependencies import get_numpy, is_numpy_array

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from collections.abc import Iterable, Mapping, Sequence

from typing_extensions import Self, TypeIs

Expand Down Expand Up @@ -894,6 +894,11 @@ def fn(names: Sequence[str]) -> Sequence[str]:
@classmethod
def _alias_native(cls, expr: NativeExprT, name: str, /) -> NativeExprT: ...

@classmethod
def from_elementwise(
cls, func: Callable[[Iterable[NativeExprT]], NativeExprT], *exprs: Self
) -> Self: ...
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MarcoGorelli would you be open to adding a little doc here? πŸ™

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, another thought.

What about LazyExpr.from_horizontal_op?

That would mirror the same constructor these all use for ExprMetadata

narwhals/narwhals/functions.py

Lines 1369 to 1374 in f17acfc

return Expr(
lambda plx: apply_n_ary_operation(
plx, plx.sum_horizontal, *flat_exprs, str_as_lit=False
),
ExprMetadata.from_horizontal_op(*flat_exprs),
)

@classmethod
def from_horizontal_op(cls, *exprs: IntoExpr) -> ExprMetadata:
return combine_metadata(
*exprs, str_as_lit=False, allow_multi_output=True, to_single_output=True
)

I think the detail that we should use this for *_horizontal (and concat_str) is more relevant than the elementwise part


@property
def name(self) -> LazyExprNameNamespace[Self]:
return LazyExprNameNamespace(self)
Expand Down
1 change: 1 addition & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,3 +683,4 @@ def dt(self) -> DaskExprDateTimeNamespace:
rank = not_implemented() # pyright: ignore[reportAssignmentType]
_alias_native = not_implemented()
window_function = not_implemented() # pyright: ignore[reportAssignmentType]
from_elementwise = not_implemented()
34 changes: 32 additions & 2 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@
narwhals_to_native_dtype,
when,
)
from narwhals._expression_parsing import ExprKind
from narwhals._expression_parsing import (
ExprKind,
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._utils import Implementation, not_implemented, requires

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Iterable, Sequence

from duckdb import Expression
from typing_extensions import Self
Expand Down Expand Up @@ -213,6 +217,32 @@ def func(df: DuckDBLazyFrame) -> list[Expression]:
version=context._version,
)

@classmethod
def from_elementwise(
cls, func: Callable[[Iterable[Expression]], Expression], *exprs: Self
) -> Self:
def call(df: DuckDBLazyFrame) -> list[Expression]:
cols = (col for _expr in exprs for col in _expr(df))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these could just be:

return [func(chain.from_iterable(expr(df) for expr in exprs))]

return [func(cols)]

def window_function(
df: DuckDBLazyFrame, window_inputs: DuckDBWindowInputs
) -> list[Expression]:
cols = (
col for _expr in exprs for col in _expr.window_function(df, window_inputs)
)
return [func(cols)]

context = exprs[0]
return cls(
call=call,
window_function=window_function,
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
backend_version=context._backend_version,
version=context._version,
)

def _callable_to_eval_series(
self, call: Callable[..., Expression], /, **expressifiable_args: Self | Any
) -> EvalSeries[DuckDBLazyFrame, Expression]:
Expand Down
38 changes: 7 additions & 31 deletions narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING

import duckdb
from duckdb import CoalesceOperator, Expression, FunctionExpression
Expand Down Expand Up @@ -49,30 +49,6 @@ def _expr(self) -> type[DuckDBExpr]:
def _lazyframe(self) -> type[DuckDBLazyFrame]:
return DuckDBLazyFrame

def _expr_from_elementwise(
self, func: Callable[[Iterable[Expression]], Expression], *exprs: DuckDBExpr
) -> DuckDBExpr:
def call(df: DuckDBLazyFrame) -> list[Expression]:
cols = (col for _expr in exprs for col in _expr(df))
return [func(cols)]

def window_function(
df: DuckDBLazyFrame, window_inputs: DuckDBWindowInputs
) -> list[Expression]:
cols = (
col for _expr in exprs for col in _expr.window_function(df, window_inputs)
)
return [func(cols)]

return self._expr(
call=call,
window_function=window_function,
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
backend_version=self._backend_version,
version=self._version,
)

def concat(
self, items: Iterable[DuckDBLazyFrame], *, how: ConcatMethod
) -> DuckDBLazyFrame:
Expand Down Expand Up @@ -119,31 +95,31 @@ def all_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
def func(cols: Iterable[Expression]) -> Expression:
return reduce(operator.and_, cols)

return self._expr_from_elementwise(func, *exprs)
return self._expr.from_elementwise(func, *exprs)

def any_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
def func(cols: Iterable[Expression]) -> Expression:
return reduce(operator.or_, cols)

return self._expr_from_elementwise(func, *exprs)
return self._expr.from_elementwise(func, *exprs)

def max_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
def func(cols: Iterable[Expression]) -> Expression:
return FunctionExpression("greatest", *cols)

return self._expr_from_elementwise(func, *exprs)
return self._expr.from_elementwise(func, *exprs)

def min_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
def func(cols: Iterable[Expression]) -> Expression:
return FunctionExpression("least", *cols)

return self._expr_from_elementwise(func, *exprs)
return self._expr.from_elementwise(func, *exprs)

def sum_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
def func(cols: Iterable[Expression]) -> Expression:
return reduce(operator.add, (CoalesceOperator(col, lit(0)) for col in cols))

return self._expr_from_elementwise(func, *exprs)
return self._expr.from_elementwise(func, *exprs)

def mean_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
def func(cols: Iterable[Expression]) -> Expression:
Expand All @@ -152,7 +128,7 @@ def func(cols: Iterable[Expression]) -> Expression:
operator.add, (CoalesceOperator(col, lit(0)) for col in cols)
) / reduce(operator.add, (col.isnotnull().cast(BIGINT) for col in cols))

return self._expr_from_elementwise(func, *exprs)
return self._expr.from_elementwise(func, *exprs)

def when(self, predicate: DuckDBExpr) -> DuckDBWhen:
return DuckDBWhen.from_expr(predicate, context=self)
Expand Down
23 changes: 22 additions & 1 deletion narwhals/_ibis/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

from narwhals._compliant import LazyExpr
from narwhals._compliant.window import WindowInputs
from narwhals._expression_parsing import (
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._ibis.expr_dt import IbisExprDateTimeNamespace
from narwhals._ibis.expr_list import IbisExprListNamespace
from narwhals._ibis.expr_str import IbisExprStringNamespace
Expand All @@ -16,7 +20,7 @@
from narwhals._utils import Implementation, not_implemented

if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from collections.abc import Iterable, Iterator, Sequence

import ibis.expr.types as ir
from typing_extensions import Self
Expand Down Expand Up @@ -203,6 +207,23 @@ def func(df: IbisLazyFrame) -> list[ir.Column]:
version=context._version,
)

@classmethod
def from_elementwise(
cls, func: Callable[[Iterable[ir.Value]], ir.Value], *exprs: Self
) -> Self:
def call(df: IbisLazyFrame) -> list[ir.Value]:
cols = (col for _expr in exprs for col in _expr(df))
return [func(cols)]

context = exprs[0]
return cls(
call=call,
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
backend_version=context._backend_version,
version=context._version,
)

def _with_callable(
self, call: Callable[..., ir.Value], /, **expressifiable_args: Self | Any
) -> Self:
Expand Down
29 changes: 7 additions & 22 deletions narwhals/_ibis/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, cast
from typing import TYPE_CHECKING, Any, cast

import ibis
import ibis.expr.types as ir
Expand Down Expand Up @@ -33,21 +33,6 @@ def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> Non
self._backend_version = backend_version
self._version = version

def _expr_from_callable(
self, func: Callable[[Iterable[ir.Value]], ir.Value], *exprs: IbisExpr
) -> IbisExpr:
def call(df: IbisLazyFrame) -> list[ir.Value]:
cols = (col for _expr in exprs for col in _expr(df))
return [func(cols)]

return self._expr(
call=call,
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
backend_version=self._backend_version,
version=self._version,
)

@property
def selectors(self) -> IbisSelectorNamespace:
return IbisSelectorNamespace.from_namespace(self)
Expand Down Expand Up @@ -104,32 +89,32 @@ def all_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
def func(cols: Iterable[ir.Value]) -> ir.Value:
return reduce(operator.and_, cols)

return self._expr_from_callable(func, *exprs)
return self._expr.from_elementwise(func, *exprs)

def any_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
def func(cols: Iterable[ir.Value]) -> ir.Value:
return reduce(operator.or_, cols)

return self._expr_from_callable(func, *exprs)
return self._expr.from_elementwise(func, *exprs)

def max_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
def func(cols: Iterable[ir.Value]) -> ir.Value:
return ibis.greatest(*cols)

return self._expr_from_callable(func, *exprs)
return self._expr.from_elementwise(func, *exprs)

def min_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
def func(cols: Iterable[ir.Value]) -> ir.Value:
return ibis.least(*cols)

return self._expr_from_callable(func, *exprs)
return self._expr.from_elementwise(func, *exprs)

def sum_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
def func(cols: Iterable[ir.Value]) -> ir.Value:
cols = (col.fill_null(lit(0)) for col in cols)
return reduce(operator.add, cols)

return self._expr_from_callable(func, *exprs)
return self._expr.from_elementwise(func, *exprs)

def mean_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
def func(cols: Iterable[ir.Value]) -> ir.Value:
Expand All @@ -138,7 +123,7 @@ def func(cols: Iterable[ir.Value]) -> ir.Value:
operator.add, (col.isnull().ifelse(lit(0), lit(1)) for col in cols)
)

return self._expr_from_callable(func, *exprs)
return self._expr.from_elementwise(func, *exprs)

@requires.backend_version((10, 0))
def when(self, predicate: IbisExpr) -> IbisWhen:
Expand Down
35 changes: 33 additions & 2 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

from narwhals._compliant import LazyExpr
from narwhals._compliant.window import WindowInputs
from narwhals._expression_parsing import ExprKind
from narwhals._expression_parsing import (
ExprKind,
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace
from narwhals._spark_like.expr_list import SparkLikeExprListNamespace
from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
Expand All @@ -20,7 +24,7 @@
from narwhals.dependencies import get_pyspark

if TYPE_CHECKING:
from collections.abc import Iterator, Mapping, Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence

from sqlframe.base.column import Column
from sqlframe.base.window import Window, WindowSpec
Expand Down Expand Up @@ -277,6 +281,33 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
implementation=context._implementation,
)

@classmethod
def from_elementwise(
cls, func: Callable[[Iterable[Column]], Column], *exprs: Self
) -> Self:
def call(df: SparkLikeLazyFrame) -> list[Column]:
cols = (col for _expr in exprs for col in _expr(df))
return [func(cols)]

def window_function(
df: SparkLikeLazyFrame, window_inputs: SparkWindowInputs
) -> list[Column]:
cols = (
col for _expr in exprs for col in _expr.window_function(df, window_inputs)
)
return [func(cols)]

context = exprs[0]
return cls(
call=call,
window_function=window_function,
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
backend_version=context._backend_version,
version=context._version,
implementation=context._implementation,
)

def _callable_to_eval_series(
self, call: Callable[..., Column], /, **expressifiable_args: Self | Any
) -> EvalSeries[SparkLikeLazyFrame, Column]:
Expand Down
Loading
Loading