diff --git a/py-polars/src/polars/dataframe/group_by.py b/py-polars/src/polars/dataframe/group_by.py index 80899c7043e8..4fdcd595b751 100644 --- a/py-polars/src/polars/dataframe/group_by.py +++ b/py-polars/src/polars/dataframe/group_by.py @@ -365,14 +365,13 @@ def map_groups(self, function: Callable[[DataFrame], DataFrame]) -> DataFrame: if self.named_by: msg = "cannot call `map_groups` when grouping by named expressions" raise TypeError(msg) - if not all(isinstance(c, str) for c in self.by): + by = list(_parse_inputs_as_iterable(self.by)) + if not all(isinstance(c, str) for c in by): msg = "cannot call `map_groups` when grouping by an expression" raise TypeError(msg) - by_strs: list[str] = self.by # type: ignore[assignment] - return self.df.__class__._from_pydf( - self.df._df.group_by_map_groups(by_strs, function, self.maintain_order) + self.df._df.group_by_map_groups(by, function, self.maintain_order) ) def head(self, n: int = 5) -> DataFrame: diff --git a/py-polars/tests/unit/operations/map/test_map_groups.py b/py-polars/tests/unit/operations/map/test_map_groups.py index 70c731ba883d..dccee909cb77 100644 --- a/py-polars/tests/unit/operations/map/test_map_groups.py +++ b/py-polars/tests/unit/operations/map/test_map_groups.py @@ -287,6 +287,12 @@ def test_map_groups_with_slice_25805() -> None: assert_frame_equal(df, pl.DataFrame({"a": [1], "b": [1]}, schema=schema)) +def test_map_groups_group_by_list_26672() -> None: + df = pl.DataFrame({"a": [1, 1, 2], "b": [4, 4, 5], "x": [1, 2, 3], "y": [4, 5, 6]}) + result = df.group_by(["a", "b"]).map_groups(lambda df: df) + assert_frame_equal(result, df, check_row_order=False) + + def test_map_groups_udf_error_does_not_panic_26647() -> None: with pytest.raises(ComputeError, match="UDF failed"): pl.select(x=1).group_by("x").map_groups(lambda x, y: x) # type: ignore[arg-type, misc]