Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from narwhals.dependencies import get_numpy, is_numpy_array

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from collections.abc import Iterable, Mapping, Sequence

from typing_extensions import Self, TypeIs

Expand Down Expand Up @@ -894,6 +894,11 @@ def fn(names: Sequence[str]) -> Sequence[str]:
@classmethod
def _alias_native(cls, expr: NativeExprT, name: str, /) -> NativeExprT: ...

@classmethod
def _from_elementwise_horizontal_op(
cls, func: Callable[[Iterable[NativeExprT]], NativeExprT], *exprs: Self
) -> Self: ...

@property
def name(self) -> LazyExprNameNamespace[Self]:
return LazyExprNameNamespace(self)
Expand Down
1 change: 1 addition & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,3 +683,4 @@ def dt(self) -> DaskExprDateTimeNamespace:
rank = not_implemented() # pyright: ignore[reportAssignmentType]
_alias_native = not_implemented()
window_function = not_implemented() # pyright: ignore[reportAssignmentType]
_from_elementwise_horizontal_op = not_implemented()
34 changes: 32 additions & 2 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@
when,
window_expression,
)
from narwhals._expression_parsing import ExprKind
from narwhals._expression_parsing import (
ExprKind,
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._utils import Implementation, not_implemented, requires

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Iterable, Sequence

from duckdb import Expression
from typing_extensions import Self
Expand Down Expand Up @@ -217,6 +221,32 @@ def func(df: DuckDBLazyFrame) -> list[Expression]:
version=context._version,
)

@classmethod
def _from_elementwise_horizontal_op(
cls, func: Callable[[Iterable[Expression]], Expression], *exprs: Self
) -> Self:
def call(df: DuckDBLazyFrame) -> list[Expression]:
cols = (col for _expr in exprs for col in _expr(df))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these could just be:

return [func(chain.from_iterable(expr(df) for expr in exprs))]

return [func(cols)]

def window_function(
df: DuckDBLazyFrame, window_inputs: DuckDBWindowInputs
) -> list[Expression]:
cols = (
col for _expr in exprs for col in _expr.window_function(df, window_inputs)
)
return [func(cols)]

context = exprs[0]
return cls(
call=call,
window_function=window_function,
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
backend_version=context._backend_version,
version=context._version,
)

def _callable_to_eval_series(
self, call: Callable[..., Expression], /, **expressifiable_args: Self | Any
) -> EvalSeries[DuckDBLazyFrame, Expression]:
Expand Down
38 changes: 7 additions & 31 deletions narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING

import duckdb
from duckdb import CoalesceOperator, Expression
Expand Down Expand Up @@ -49,30 +49,6 @@ def _expr(self) -> type[DuckDBExpr]:
def _lazyframe(self) -> type[DuckDBLazyFrame]:
return DuckDBLazyFrame

def _expr_from_elementwise(
self, func: Callable[[Iterable[Expression]], Expression], *exprs: DuckDBExpr
) -> DuckDBExpr:
def call(df: DuckDBLazyFrame) -> list[Expression]:
cols = (col for _expr in exprs for col in _expr(df))
return [func(cols)]

def window_function(
df: DuckDBLazyFrame, window_inputs: DuckDBWindowInputs
) -> list[Expression]:
cols = (
col for _expr in exprs for col in _expr.window_function(df, window_inputs)
)
return [func(cols)]

return self._expr(
call=call,
window_function=window_function,
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
backend_version=self._backend_version,
version=self._version,
)

def concat(
self, items: Iterable[DuckDBLazyFrame], *, how: ConcatMethod
) -> DuckDBLazyFrame:
Expand Down Expand Up @@ -124,7 +100,7 @@ def func(cols: Iterable[Expression]) -> Expression:
)
return reduce(operator.and_, it)

return self._expr_from_elementwise(func, *exprs)
return self._expr._from_elementwise_horizontal_op(func, *exprs)

def any_horizontal(self, *exprs: DuckDBExpr, ignore_nulls: bool) -> DuckDBExpr:
def func(cols: Iterable[Expression]) -> Expression:
Expand All @@ -135,25 +111,25 @@ def func(cols: Iterable[Expression]) -> Expression:
)
return reduce(operator.or_, it)

return self._expr_from_elementwise(func, *exprs)
return self._expr._from_elementwise_horizontal_op(func, *exprs)

def max_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
def func(cols: Iterable[Expression]) -> Expression:
return F("greatest", *cols)

return self._expr_from_elementwise(func, *exprs)
return self._expr._from_elementwise_horizontal_op(func, *exprs)

def min_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
def func(cols: Iterable[Expression]) -> Expression:
return F("least", *cols)

return self._expr_from_elementwise(func, *exprs)
return self._expr._from_elementwise_horizontal_op(func, *exprs)

def sum_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
def func(cols: Iterable[Expression]) -> Expression:
return reduce(operator.add, (CoalesceOperator(col, lit(0)) for col in cols))

return self._expr_from_elementwise(func, *exprs)
return self._expr._from_elementwise_horizontal_op(func, *exprs)

def mean_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
def func(cols: Iterable[Expression]) -> Expression:
Expand All @@ -162,7 +138,7 @@ def func(cols: Iterable[Expression]) -> Expression:
operator.add, (CoalesceOperator(col, lit(0)) for col in cols)
) / reduce(operator.add, (col.isnotnull().cast(BIGINT) for col in cols))

return self._expr_from_elementwise(func, *exprs)
return self._expr._from_elementwise_horizontal_op(func, *exprs)

def when(self, predicate: DuckDBExpr) -> DuckDBWhen:
return DuckDBWhen.from_expr(predicate, context=self)
Expand Down
23 changes: 22 additions & 1 deletion narwhals/_ibis/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

from narwhals._compliant import LazyExpr
from narwhals._compliant.window import WindowInputs
from narwhals._expression_parsing import (
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._ibis.expr_dt import IbisExprDateTimeNamespace
from narwhals._ibis.expr_list import IbisExprListNamespace
from narwhals._ibis.expr_str import IbisExprStringNamespace
Expand All @@ -16,7 +20,7 @@
from narwhals._utils import Implementation, not_implemented

if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from collections.abc import Iterable, Iterator, Sequence

import ibis.expr.types as ir
from typing_extensions import Self
Expand Down Expand Up @@ -203,6 +207,23 @@ def func(df: IbisLazyFrame) -> list[ir.Column]:
version=context._version,
)

@classmethod
def _from_elementwise_horizontal_op(
cls, func: Callable[[Iterable[ir.Value]], ir.Value], *exprs: Self
) -> Self:
def call(df: IbisLazyFrame) -> list[ir.Value]:
cols = (col for _expr in exprs for col in _expr(df))
return [func(cols)]

context = exprs[0]
return cls(
call=call,
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
backend_version=context._backend_version,
version=context._version,
)

def _with_callable(
self, call: Callable[..., ir.Value], /, **expressifiable_args: Self | Any
) -> Self:
Expand Down
29 changes: 7 additions & 22 deletions narwhals/_ibis/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, cast
from typing import TYPE_CHECKING, Any, cast

import ibis
import ibis.expr.types as ir
Expand Down Expand Up @@ -33,21 +33,6 @@ def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> Non
self._backend_version = backend_version
self._version = version

def _expr_from_callable(
self, func: Callable[[Iterable[ir.Value]], ir.Value], *exprs: IbisExpr
) -> IbisExpr:
def call(df: IbisLazyFrame) -> list[ir.Value]:
cols = (col for _expr in exprs for col in _expr(df))
return [func(cols)]

return self._expr(
call=call,
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
backend_version=self._backend_version,
version=self._version,
)

@property
def selectors(self) -> IbisSelectorNamespace:
return IbisSelectorNamespace.from_namespace(self)
Expand Down Expand Up @@ -109,7 +94,7 @@ def func(cols: Iterable[ir.Value]) -> ir.Value:
)
return reduce(operator.and_, it)

return self._expr_from_callable(func, *exprs)
return self._expr._from_elementwise_horizontal_op(func, *exprs)

def any_horizontal(self, *exprs: IbisExpr, ignore_nulls: bool) -> IbisExpr:
def func(cols: Iterable[ir.Value]) -> ir.Value:
Expand All @@ -120,26 +105,26 @@ def func(cols: Iterable[ir.Value]) -> ir.Value:
)
return reduce(operator.or_, it)

return self._expr_from_callable(func, *exprs)
return self._expr._from_elementwise_horizontal_op(func, *exprs)

def max_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
def func(cols: Iterable[ir.Value]) -> ir.Value:
return ibis.greatest(*cols)

return self._expr_from_callable(func, *exprs)
return self._expr._from_elementwise_horizontal_op(func, *exprs)

def min_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
def func(cols: Iterable[ir.Value]) -> ir.Value:
return ibis.least(*cols)

return self._expr_from_callable(func, *exprs)
return self._expr._from_elementwise_horizontal_op(func, *exprs)

def sum_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
def func(cols: Iterable[ir.Value]) -> ir.Value:
cols = (col.fill_null(lit(0)) for col in cols)
return reduce(operator.add, cols)

return self._expr_from_callable(func, *exprs)
return self._expr._from_elementwise_horizontal_op(func, *exprs)

def mean_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
def func(cols: Iterable[ir.Value]) -> ir.Value:
Expand All @@ -148,7 +133,7 @@ def func(cols: Iterable[ir.Value]) -> ir.Value:
operator.add, (col.isnull().ifelse(lit(0), lit(1)) for col in cols)
)

return self._expr_from_callable(func, *exprs)
return self._expr._from_elementwise_horizontal_op(func, *exprs)

@requires.backend_version((10, 0))
def when(self, predicate: IbisExpr) -> IbisWhen:
Expand Down
35 changes: 33 additions & 2 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

from narwhals._compliant import LazyExpr
from narwhals._compliant.window import WindowInputs
from narwhals._expression_parsing import ExprKind
from narwhals._expression_parsing import (
ExprKind,
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace
from narwhals._spark_like.expr_list import SparkLikeExprListNamespace
from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
Expand All @@ -20,7 +24,7 @@
from narwhals.dependencies import get_pyspark

if TYPE_CHECKING:
from collections.abc import Iterator, Mapping, Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence

from sqlframe.base.column import Column
from sqlframe.base.window import Window, WindowSpec
Expand Down Expand Up @@ -277,6 +281,33 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
implementation=context._implementation,
)

@classmethod
def _from_elementwise_horizontal_op(
cls, func: Callable[[Iterable[Column]], Column], *exprs: Self
) -> Self:
def call(df: SparkLikeLazyFrame) -> list[Column]:
cols = (col for _expr in exprs for col in _expr(df))
return [func(cols)]

def window_function(
df: SparkLikeLazyFrame, window_inputs: SparkWindowInputs
) -> list[Column]:
cols = (
col for _expr in exprs for col in _expr.window_function(df, window_inputs)
)
return [func(cols)]

context = exprs[0]
return cls(
call=call,
window_function=window_function,
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
backend_version=context._backend_version,
version=context._version,
implementation=context._implementation,
)

def _callable_to_eval_series(
self, call: Callable[..., Column], /, **expressifiable_args: Self | Any
) -> EvalSeries[SparkLikeLazyFrame, Column]:
Expand Down
Loading
Loading