Skip to content

Commit f2c9c4c

Browse files
authored
Fix IndexError raised by aggregation of DataFrameGroupBy (#2641)
1 parent 1f6c3d4 commit f2c9c4c

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

mars/dataframe/groupby/aggregation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -825,9 +825,10 @@ def _execute_map(cls, ctx, op: "DataFrameGroupByAgg"):
825825

826826
for input_key, output_key, cols, func in op.pre_funcs:
827827
if input_key == output_key:
828-
ret_map_groupbys[output_key] = (
829-
grouped if cols is None else grouped[cols]
830-
)
828+
if cols is None or grouped._selection is not None:
829+
ret_map_groupbys[output_key] = grouped
830+
else:
831+
ret_map_groupbys[output_key] = grouped[cols]
831832
else:
832833

833834
def _wrapped_func(col):

mars/dataframe/groupby/tests/test_groupby_execution.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,32 @@ def test_groupby_getitem(setup):
273273
raw.groupby(0, as_index=False)[0].agg({"cnt": "count"}),
274274
)
275275

276+
# test groupby getitem then agg(#GH 2640)
277+
rs = np.random.RandomState(0)
278+
raw = pd.DataFrame(
279+
{
280+
"c1": np.arange(100).astype(np.int64),
281+
"c2": rs.choice(["a", "b", "c"], (100,)),
282+
"c3": rs.rand(100),
283+
"c4": rs.rand(100),
284+
}
285+
)
286+
mdf = md.DataFrame(raw, chunk_size=20)
287+
r = mdf.groupby(["c2"])[["c1", "c3"]].agg({"c1": "max", "c3": "min"}, method="tree")
288+
pd.testing.assert_frame_equal(
289+
r.execute().fetch(),
290+
raw.groupby(["c2"])[["c1", "c3"]].agg({"c1": "max", "c3": "min"}),
291+
)
292+
293+
mdf = md.DataFrame(raw.copy(), chunk_size=30)
294+
r = mdf.groupby(["c2"])[["c1", "c4"]].agg(
295+
{"c1": "max", "c4": "mean"}, method="shuffle"
296+
)
297+
pd.testing.assert_frame_equal(
298+
r.execute().fetch(),
299+
raw.groupby(["c2"])[["c1", "c4"]].agg({"c1": "max", "c4": "mean"}),
300+
)
301+
276302

277303
def test_dataframe_groupby_agg(setup):
278304
agg_funs = [

0 commit comments

Comments
 (0)