Skip to content

Commit e860bcc

Browse files
authored
fix: align nw.nth expansion with nw.col during group_by (#3243)
* fix: align `nw.nth` expansion with `nw.col` during `group_by` * skip modin
1 parent 8be198e commit e860bcc

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

narwhals/_expression_parsing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def from_nth(cls, node: ExprNode) -> ExprMetadata:
460460
return (
461461
cls(ExpansionKind.SINGLE, current_node=node, prev=None)
462462
if len(node.kwargs["indices"]) == 1
463-
else cls.from_selector_multi_unnamed(node)
463+
else cls.from_selector_multi_named(node)
464464
)
465465

466466
@classmethod

tests/frame/group_by_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,3 +772,31 @@ def test_group_by_agg_last(
772772
df = df.sort(aggs, **pre_sort)
773773
result = df.group_by(keys).agg(nw.col(aggs).last()).sort(keys)
774774
assert_equal_data(result, expected)
775+
776+
777+
def test_multi_column_expansion(constructor: Constructor) -> None:
778+
if "polars" in str(constructor) and POLARS_VERSION < (1, 32):
779+
pytest.skip(reason="https://github.com/pola-rs/polars/issues/21773")
780+
if "modin" in str(constructor):
781+
pytest.skip(reason="Internal error")
782+
df = nw.from_native(constructor({"a": [1, 1, 2], "b": [4, 5, 6]}))
783+
result = (
784+
df.group_by("a")
785+
.agg(nw.all().sum().name.suffix("_aggregated"))
786+
.sort("a", descending=True)
787+
)
788+
expected = {"a": [2, 1], "b_aggregated": [6, 9]}
789+
assert_equal_data(result, expected)
790+
result = (
791+
df.group_by("a")
792+
.agg(nw.col("a", "b").sum().name.suffix("_aggregated"))
793+
.sort("a", descending=True)
794+
)
795+
expected = {"a": [2, 1], "a_aggregated": [2, 2], "b_aggregated": [6, 9]}
796+
assert_equal_data(result, expected)
797+
result = (
798+
df.group_by("a")
799+
.agg(nw.nth(0, 1).sum().name.suffix("_aggregated"))
800+
.sort("a", descending=True)
801+
)
802+
assert_equal_data(result, expected)

0 commit comments

Comments
 (0)