Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 10 additions & 14 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
expr_results = tuple(chain.from_iterable(expr(df) for expr in exprs))
series = [s.fill_null(0, strategy=None, limit=None) for s in expr_results]
non_na = [1 - s.is_null().cast(int_64) for s in expr_results]
series = (s.fill_null(0, strategy=None, limit=None) for s in expr_results)
non_na = (1 - s.is_null().cast(int_64) for s in expr_results)
return [reduce(operator.add, series) / reduce(operator.add, non_na)]

return self._expr._from_callable(
Expand All @@ -139,13 +139,11 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:

def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
init_series, *series = tuple(chain.from_iterable(expr(df) for expr in exprs))
native_series = reduce(
pc.min_element_wise, [s.native for s in series], init_series.native
series = tuple(chain.from_iterable(expr(df) for expr in exprs))
result = reduce(
lambda s1, s2: s1._with_binary(pc.min_element_wise, s2), series
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let _with_binary take care of the broadcasting for scalars

Copy link
Member

@dangotbanned dangotbanned Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pc.min_element_wise doesn't need broadcasting btw

Copy link
Member Author

@FBruzzesi FBruzzesi Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given a scalar input, by the time we get here, it's a length one array, not a scalar anymore, and we get a shape mismatch

Copy link
Member

@dangotbanned dangotbanned Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see 🤦

I was thinking of this behavior:

(Scalar, Scalar) - > Scalar
(Scalar, Array) -> Array
(Array, Array) -> Array

Which is what broadcasting is, but I forgot the Scalar preservation is only in #2572

Copy link
Member Author

@FBruzzesi FBruzzesi Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0cad6c1 should reduce the overhead of both "align"-ing first, and then reduce-ing

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@FBruzzesi 😍😍😍

)
return [
ArrowSeries(native_series, name=init_series.name, version=self._version)
]
return [result]

return self._expr._from_callable(
func=func,
Expand All @@ -156,13 +154,11 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:

def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
init_series, *series = tuple(chain.from_iterable(expr(df) for expr in exprs))
native_series = reduce(
pc.max_element_wise, [s.native for s in series], init_series.native
series = tuple(chain.from_iterable(expr(df) for expr in exprs))
result = reduce(
lambda s1, s2: s1._with_binary(pc.max_element_wise, s2), series
)
return [
ArrowSeries(native_series, name=init_series.name, version=self._version)
]
return [result]

return self._expr._from_callable(
func=func,
Expand Down
8 changes: 4 additions & 4 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,11 @@ def concat(

def mean_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
expr_results = [s for _expr in exprs for s in _expr(df)]
series = align_series_full_broadcast(df, *(s.fillna(0) for s in expr_results))
non_na = align_series_full_broadcast(
df, *(1 - s.isna() for s in expr_results)
expr_results = align_series_full_broadcast(
df, *[s for _expr in exprs for s in _expr(df)]
)
series = (s.fillna(0) for s in expr_results)
non_na = (1 - s.isna() for s in expr_results)
Comment on lines -175 to +179
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First broadcast (once) so that we ensure to have series (i.e. no more dask_expr.Scalar's), then it's possible to perform fillna() and isna() safely

num = reduce(lambda x, y: x + y, series) # pyright: ignore[reportOperatorIssue]
den = reduce(lambda x, y: x + y, non_na) # pyright: ignore[reportOperatorIssue]
return [cast("dx.Series", num / den)] # pyright: ignore[reportOperatorIssue]
Expand Down
8 changes: 6 additions & 2 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:

def min_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
series = list(chain.from_iterable(expr(df) for expr in exprs))
series = self._series._align_full_broadcast(
*chain.from_iterable(expr(df) for expr in exprs)
)
Comment on lines -238 to +240
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to broadcast scalars first

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dangotbanned unlike the arrow case, here we need to align the series for how we implement the operation after

Comment on lines -238 to +240
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'd like to check if we need to do this align-full-broadcast, eye-balling coalesce again it looks like we didn't need to do it there, can't remember why right now

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's because we don't allow scalars in coalesce?

(Disclaimer: from mobile, I only checked signature)

return [
PandasLikeSeries(
self.concat(
Expand All @@ -255,7 +257,9 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:

def max_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
series = list(chain.from_iterable(expr(df) for expr in exprs))
series = self._series._align_full_broadcast(
*chain.from_iterable(expr(df) for expr in exprs)
)
return [
PandasLikeSeries(
self.concat(
Expand Down
26 changes: 20 additions & 6 deletions narwhals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,7 +1217,9 @@ def _expr_with_horizontal_op(name: str, *exprs: IntoExpr, **kwargs: Any) -> Expr
)


def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
def sum_horizontal(
*exprs: PythonLiteral | IntoExpr | Iterable[PythonLiteral | IntoExpr],
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will comment here to exemplify the behavior. Polars definition of IntoExpr is already including python literal. They distinguish between:

# Inputs that can convert into a `col` expression
IntoExprColumn: TypeAlias = Union["Expr", "Series", str]
# Inputs that can convert into an expression
IntoExpr: TypeAlias = PythonLiteral | IntoExprColumn | None

I am not sure if it's reasonable for us to eventually align with those definition and distinction

code

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this in polars too.

I think an issue in narwhals could be the special-casing for pandas non-str column "names".
But I'm not sure how far that support extends, e.g. do we actually support it everywhere?

) -> Expr:
"""Sum all values horizontally across columns.

Warning:
Expand Down Expand Up @@ -1251,7 +1253,9 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
return _expr_with_horizontal_op("sum_horizontal", *flatten(exprs))


def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
def min_horizontal(
*exprs: PythonLiteral | IntoExpr | Iterable[PythonLiteral | IntoExpr],
) -> Expr:
"""Get the minimum value horizontally across columns.

Notes:
Expand Down Expand Up @@ -1283,7 +1287,9 @@ def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
return _expr_with_horizontal_op("min_horizontal", *flatten(exprs))


def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
def max_horizontal(
*exprs: PythonLiteral | IntoExpr | Iterable[PythonLiteral | IntoExpr],
) -> Expr:
"""Get the maximum value horizontally across columns.

Notes:
Expand Down Expand Up @@ -1381,7 +1387,10 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When:
return When(*predicates)


def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) -> Expr:
def all_horizontal(
*exprs: PythonLiteral | IntoExpr | Iterable[PythonLiteral | IntoExpr],
ignore_nulls: bool,
) -> Expr:
r"""Compute the bitwise AND horizontally across columns.

Arguments:
Expand Down Expand Up @@ -1510,7 +1519,10 @@ def lit(value: PythonLiteral, dtype: IntoDType | None = None) -> Expr:
return Expr(ExprNode(ExprKind.LITERAL, "lit", value=value, dtype=dtype))


def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) -> Expr:
def any_horizontal(
*exprs: PythonLiteral | IntoExpr | Iterable[PythonLiteral | IntoExpr],
ignore_nulls: bool,
) -> Expr:
r"""Compute the bitwise OR horizontally across columns.

Arguments:
Expand Down Expand Up @@ -1558,7 +1570,9 @@ def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool) ->
)


def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
def mean_horizontal(
*exprs: PythonLiteral | IntoExpr | Iterable[PythonLiteral | IntoExpr],
) -> Expr:
"""Compute the mean of all values horizontally across columns.

Arguments:
Expand Down
23 changes: 22 additions & 1 deletion tests/expr_and_series/all_horizontal_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from __future__ import annotations

from contextlib import nullcontext as does_not_raise
from typing import Any
from typing import TYPE_CHECKING, Any

import pytest

import narwhals as nw
from tests.utils import POLARS_VERSION, Constructor, ConstructorEager, assert_equal_data

if TYPE_CHECKING:
from narwhals.typing import PythonLiteral


def test_allh(constructor: Constructor) -> None:
data = {"a": [False, False, True], "b": [False, True, True]}
Expand Down Expand Up @@ -157,3 +160,21 @@ def test_horizontal_expressions_empty(constructor: Constructor) -> None:
ValueError, match=r"At least one expression must be passed.*min_horizontal"
):
df.select(nw.min_horizontal())


@pytest.mark.parametrize(
("exprs", "name"),
[
((nw.col("a"), True), "a"),
((nw.col("a"), nw.lit(True)), "a"),
((True, nw.col("a")), "literal"),
((nw.lit(True), nw.col("a")), "literal"),
],
)
def test_allh_with_scalars(
constructor: Constructor, exprs: tuple[PythonLiteral | nw.Expr, ...], name: str
) -> None:
data = {"a": [False, True]}
df = nw.from_native(constructor(data))
result = df.select(nw.all_horizontal(*exprs, ignore_nulls=True))
assert_equal_data(result, {name: [False, True]})
22 changes: 22 additions & 0 deletions tests/expr_and_series/any_horizontal_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

from contextlib import nullcontext as does_not_raise
from typing import TYPE_CHECKING

import pytest

import narwhals as nw
from tests.utils import Constructor, assert_equal_data

if TYPE_CHECKING:
from narwhals.typing import PythonLiteral


def test_anyh(constructor: Constructor) -> None:
data = {"a": [False, False, True], "b": [False, True, True]}
Expand Down Expand Up @@ -85,3 +89,21 @@ def test_anyh_all(constructor: Constructor) -> None:
result = df.select(nw.any_horizontal(nw.all(), ignore_nulls=False))
expected = {"a": [False, True, True]}
assert_equal_data(result, expected)


@pytest.mark.parametrize(
("exprs", "name"),
[
((nw.col("a"), False), "a"),
((nw.col("a"), nw.lit(False)), "a"),
((False, nw.col("a")), "literal"),
((nw.lit(False), nw.col("a")), "literal"),
],
)
def test_anyh_with_scalars(
constructor: Constructor, exprs: tuple[PythonLiteral | nw.Expr, ...], name: str
) -> None:
data = {"a": [False, True]}
df = nw.from_native(constructor(data))
result = df.select(nw.any_horizontal(*exprs, ignore_nulls=True))
assert_equal_data(result, {name: [False, True]})
22 changes: 22 additions & 0 deletions tests/expr_and_series/max_horizontal_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import narwhals as nw
from tests.utils import Constructor, assert_equal_data

if TYPE_CHECKING:
from narwhals.typing import PythonLiteral

data = {"a": [1, 3, None, None], "b": [4, None, 6, None], "z": [3, 1, None, None]}
expected_values = [4, 3, 6, None]

Expand All @@ -23,3 +28,20 @@ def test_maxh_all(constructor: Constructor) -> None:
result = df.select(nw.max_horizontal(nw.all()), c=nw.max_horizontal(nw.all()))
expected = {"a": expected_values, "c": expected_values}
assert_equal_data(result, expected)


@pytest.mark.parametrize(
("exprs", "name"),
[
((nw.col("a"), 2), "a"),
((nw.col("a"), nw.lit(2)), "a"),
((2, nw.col("a")), "literal"),
((nw.lit(2), nw.col("a")), "literal"),
],
)
def test_maxh_with_scalars(
constructor: Constructor, exprs: tuple[PythonLiteral | nw.Expr, ...], name: str
) -> None:
df = nw.from_native(constructor({"a": [1, 2, 3]}))
result = df.select(nw.max_horizontal(*exprs))
assert_equal_data(result, {name: [2, 2, 3]})
29 changes: 24 additions & 5 deletions tests/expr_and_series/mean_horizontal_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import narwhals as nw
from tests.utils import Constructor, assert_equal_data

if TYPE_CHECKING:
from narwhals.typing import PythonLiteral


def test_meanh(constructor: Constructor) -> None:
data = {"a": [1, 3, None, None], "b": [4, None, 6, None]}
Expand All @@ -14,11 +19,7 @@ def test_meanh(constructor: Constructor) -> None:
assert_equal_data(result, expected)


def test_meanh_with_literal(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_meanh_with_literal(constructor: Constructor) -> None:
data = {"a": [1, 3, None, None], "b": [4, None, 6, None]}
df = nw.from_native(constructor(data))
result = df.select(horizontal_mean=nw.mean_horizontal(nw.lit(1), "a", nw.col("b")))
Expand All @@ -35,3 +36,21 @@ def test_meanh_all(constructor: Constructor) -> None:
result = df.select(c=nw.mean_horizontal(nw.all()))
expected = {"c": [6, 12, 18]}
assert_equal_data(result, expected)


@pytest.mark.parametrize(
("exprs", "name"),
[
((nw.col("a"), 1), "a"),
((nw.col("a"), nw.lit(1)), "a"),
((1, nw.col("a")), "literal"),
((nw.lit(1), nw.col("a")), "literal"),
],
)
def test_meanh_with_scalars(
constructor: Constructor, exprs: tuple[PythonLiteral | nw.Expr, ...], name: str
) -> None:
data = {"a": [1, 2, 3]}
df = nw.from_native(constructor(data))
result = df.select(nw.mean_horizontal(*exprs))
assert_equal_data(result, {name: [1.0, 1.5, 2.0]})
22 changes: 22 additions & 0 deletions tests/expr_and_series/min_horizontal_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import narwhals as nw
from tests.utils import Constructor, assert_equal_data

if TYPE_CHECKING:
from narwhals.typing import PythonLiteral

data = {"a": [1, 3, None, None], "b": [4, None, 6, None], "z": [3, 1, None, None]}
expected_values = [1, 1, 6, None]

Expand All @@ -23,3 +28,20 @@ def test_minh_all(constructor: Constructor) -> None:
result = df.select(nw.min_horizontal(nw.all()), c=nw.min_horizontal(nw.all()))
expected = {"a": expected_values, "c": expected_values}
assert_equal_data(result, expected)


@pytest.mark.parametrize(
("exprs", "name"),
[
((nw.col("a"), 2), "a"),
((nw.col("a"), nw.lit(2)), "a"),
((2, nw.col("a")), "literal"),
((nw.lit(2), nw.col("a")), "literal"),
],
)
def test_minh_with_scalars(
constructor: Constructor, exprs: tuple[PythonLiteral | nw.Expr, ...], name: str
) -> None:
df = nw.from_native(constructor({"a": [1, 2, 3]}))
result = df.select(nw.min_horizontal(*exprs))
assert_equal_data(result, {name: [1, 2, 2]})
23 changes: 22 additions & 1 deletion tests/expr_and_series/sum_horizontal_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any

import pytest

import narwhals as nw
from tests.utils import DUCKDB_VERSION, Constructor, assert_equal_data

if TYPE_CHECKING:
from narwhals.typing import PythonLiteral


def test_sumh(constructor: Constructor) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
Expand Down Expand Up @@ -60,3 +63,21 @@ def test_sumh_transformations(constructor: Constructor) -> None:
result = df.select(d=nw.sum_horizontal("a", nw.lit(None, dtype=nw.Float64), "c"))
expected = {"d": [8.0, 10.0, 12.0]}
assert_equal_data(result, expected)


@pytest.mark.parametrize(
("exprs", "name"),
[
((nw.col("a"), 1), "a"),
((nw.col("a"), nw.lit(1)), "a"),
((1, nw.col("a")), "literal"),
((nw.lit(1), nw.col("a")), "literal"),
],
)
def test_sumh_with_scalars(
constructor: Constructor, exprs: tuple[PythonLiteral | nw.Expr, ...], name: str
) -> None:
data = {"a": [1, 2, 3]}
df = nw.from_native(constructor(data))
result = df.select(nw.sum_horizontal(*exprs))
assert_equal_data(result, {name: [2, 3, 4]})
Loading