Skip to content

Commit 3ea81c4

Browse files
stakeswkyUsernameexhaustion
authored
fix(python): Allow list argument in group_by().map_groups() (#26707)
Co-authored-by: User <user@example.com> Co-authored-by: nameexhaustion <simonlin.rqmmw@slmail.me>
1 parent 45c9157 commit 3ea81c4

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

py-polars/src/polars/dataframe/group_by.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -365,14 +365,13 @@ def map_groups(self, function: Callable[[DataFrame], DataFrame]) -> DataFrame:
365365
if self.named_by:
366366
msg = "cannot call `map_groups` when grouping by named expressions"
367367
raise TypeError(msg)
368-
if not all(isinstance(c, str) for c in self.by):
368+
by = list(_parse_inputs_as_iterable(self.by))
369+
if not all(isinstance(c, str) for c in by):
369370
msg = "cannot call `map_groups` when grouping by an expression"
370371
raise TypeError(msg)
371372

372-
by_strs: list[str] = self.by # type: ignore[assignment]
373-
374373
return self.df.__class__._from_pydf(
375-
self.df._df.group_by_map_groups(by_strs, function, self.maintain_order)
374+
self.df._df.group_by_map_groups(by, function, self.maintain_order)
376375
)
377376

378377
def head(self, n: int = 5) -> DataFrame:

py-polars/tests/unit/operations/map/test_map_groups.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,12 @@ def test_map_groups_with_slice_25805() -> None:
287287
assert_frame_equal(df, pl.DataFrame({"a": [1], "b": [1]}, schema=schema))
288288

289289

290+
def test_map_groups_group_by_list_26672() -> None:
291+
df = pl.DataFrame({"a": [1, 1, 2], "b": [4, 4, 5], "x": [1, 2, 3], "y": [4, 5, 6]})
292+
result = df.group_by(["a", "b"]).map_groups(lambda df: df)
293+
assert_frame_equal(result, df, check_row_order=False)
294+
295+
290296
def test_map_groups_udf_error_does_not_panic_26647() -> None:
291297
with pytest.raises(ComputeError, match="UDF failed"):
292298
pl.select(x=1).group_by("x").map_groups(lambda x, y: x) # type: ignore[arg-type, misc]

0 commit comments

Comments
 (0)