Commit 77373af

Optimize GroupBy's aggregation algorithm (#2696)
1 parent df1492c commit 77373af

3 files changed: +50, -14 lines

mars/dataframe/groupby/aggregation.py

Lines changed: 23 additions & 10 deletions
@@ -20,6 +20,7 @@

 import numpy as np
 import pandas as pd
+from scipy.stats import variation

 from ... import opcodes as OperandDef
 from ...config import options

@@ -59,12 +60,12 @@

 class SizeRecorder:
     def __init__(self):
-        self._raw_records = 0
-        self._agg_records = 0
+        self._raw_records = []
+        self._agg_records = []

-    def record(self, raw_records: int, agg_records: int):
-        self._raw_records += raw_records
-        self._agg_records += agg_records
+    def record(self, raw_record: int, agg_record: int):
+        self._raw_records.append(raw_record)
+        self._agg_records.append(agg_record)

     def get(self):
         return self._raw_records, self._agg_records
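
Note: the recorder now keeps one entry per map chunk instead of two running totals,
because the tiling heuristic below needs the distribution of sizes across chunks,
not just their sums. A minimal standalone sketch (hypothetical numbers, not the Mars
class itself) of what the lists preserve and totals lose:

    import numpy as np

    raw_sizes = [1000, 1000, 1000, 1000]  # hypothetical pre-aggregation chunk sizes
    agg_sizes = [100, 120, 90, 110]       # hypothetical post-aggregation chunk sizes

    # per-chunk shrink ratios are only available with the list-based records
    print([a / r for a, r in zip(agg_sizes, raw_sizes)])  # [0.1, 0.12, 0.09, 0.11]
    # dispersion of aggregated sizes across chunks (what variation() measures)
    print(np.std(agg_sizes) / np.mean(agg_sizes))         # ~0.11
    # a pair of running totals could only ever yield this single global ratio
    print(sum(agg_sizes) / sum(raw_sizes))                # 0.105
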
@@ -659,15 +660,27 @@ def _tile_auto(
         # yield to trigger execution
         yield chunks

-        raw_size, agg_size = size_recorder.get()
+        raw_sizes, agg_sizes = size_recorder.get()
         # destroy size recorder
         ctx.destroy_remote_object(size_recorder_name)

         left_chunks = in_df.chunks[combine_size:]
         left_chunks = cls._gen_map_chunks(op, left_chunks, out_df, func_infos)
-        if raw_size >= agg_size * len(chunks):
-            # aggregated size is less than 1 chunk
-            # use tree aggregation
+        # calculate the coefficient of variation of aggregation sizes,
+        # if the CV is less than 0.2 and the mean of agg_size/raw_size
+        # is less than 0.8, we suppose the single chunk's aggregation size
+        # almost equals to the tileable's, then use tree method
+        # as combine aggregation results won't lead to a rapid expansion.
+        ratios = [
+            agg_size / raw_size for agg_size, raw_size in zip(agg_sizes, raw_sizes)
+        ]
+        cv = variation(agg_sizes)
+        mean_ratio = np.mean(ratios)
+        if mean_ratio <= 1 / len(chunks):
+            # if mean of ratio is less than 0.25, use tree
+            return cls._combine_tree(op, chunks + left_chunks, out_df, func_infos)
+        elif cv <= 0.2 and mean_ratio <= 2 / 3:
+            # check CV and mean of ratio
             return cls._combine_tree(op, chunks + left_chunks, out_df, func_infos)
         else:
             # otherwise, use shuffle

@@ -685,7 +698,7 @@ def tile(cls, op: "DataFrameGroupByAgg"):
         func_infos = cls._compile_funcs(op, in_df)

         if op.method == "auto":
-            if len(in_df.chunks) < op.combine_size:
+            if len(in_df.chunks) <= op.combine_size:
                 return cls._tile_with_tree(op, in_df, out_df, func_infos)
             else:
                 return (yield from cls._tile_auto(op, in_df, out_df, func_infos))
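
Note on the new heuristic: instead of the single raw_size >= agg_size * len(chunks)
test, _tile_auto now inspects the per-chunk agg/raw ratios. Tree combining is chosen
when the mean ratio is at most 1 / len(chunks), or when the aggregated chunk sizes
are uniform (coefficient of variation <= 0.2) and the mean ratio is at most 2/3;
otherwise the shuffle path is used. A self-contained sketch of that decision follows;
the choose_method name and the sample numbers are illustrative, only the thresholds
come from the diff above.

    import numpy as np
    from scipy.stats import variation

    def choose_method(raw_sizes, agg_sizes, n_chunks):
        ratios = [a / r for a, r in zip(agg_sizes, raw_sizes)]
        mean_ratio = np.mean(ratios)
        cv = variation(agg_sizes)  # std / mean of aggregated chunk sizes
        if mean_ratio <= 1 / n_chunks:
            # aggregation shrinks the data to roughly one chunk or less
            return "tree"
        if cv <= 0.2 and mean_ratio <= 2 / 3:
            # chunk results are similar in size and clearly smaller than the
            # raw data, so combining them will not expand memory rapidly
            return "tree"
        return "shuffle"

    # low-cardinality keys: tiny, uniform aggregated chunks -> tree
    print(choose_method([1000] * 4, [10, 12, 9, 11], n_chunks=4))        # tree
    # high-cardinality keys: aggregation barely shrinks the data -> shuffle
    print(choose_method([1000] * 4, [950, 960, 940, 955], n_chunks=4))   # shuffle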

mars/dataframe/groupby/tests/test_groupby_execution.py

Lines changed: 19 additions & 2 deletions
@@ -581,14 +581,31 @@ def _disallow_reduce(ctx, op):
     pd.testing.assert_frame_equal(result.sort_index(), raw.groupby("c2").agg("sum"))

     def _disallow_combine_and_agg(ctx, op):
-        assert op.stage not in (OperandStage.combine, OperandStage.agg)
+        assert op.stage != OperandStage.combine
         op.execute(ctx, op)

-    r = mdf.groupby("c1").agg("sum")
+    r = mdf.groupby("c3").agg("sum")
     operand_executors = {DataFrameGroupByAgg: _disallow_combine_and_agg}
     result = r.execute(
         extra_config={"operand_executors": operand_executors, "check_all": False}
     ).fetch()
+    pd.testing.assert_frame_equal(result.sort_index(), raw.groupby("c3").agg("sum"))
+
+    rs = np.random.RandomState(0)
+    raw = pd.DataFrame(
+        {
+            "c1": list(range(4)) * 12,
+            "c2": rs.choice(["a", "b", "c"], (48,)),
+            "c3": rs.rand(48),
+        }
+    )
+
+    mdf = md.DataFrame(raw, chunk_size=8)
+    r = mdf.groupby("c1").agg("sum")
+    operand_executors = {DataFrameGroupByAgg: _disallow_reduce}
+    result = r.execute(
+        extra_config={"operand_executors": operand_executors, "check_all": False}
+    ).fetch()
     pd.testing.assert_frame_equal(result.sort_index(), raw.groupby("c1").agg("sum"))

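Note on the test hook: these tests drive the heuristic end to end through the
operand_executors entry of extra_config. Each executor receives (ctx, op), asserts
which OperandStage values may appear for DataFrameGroupByAgg, and then delegates to
op.execute. The new case builds a 48-row frame with only four distinct "c1" keys, so
each chunk aggregates to a small, uniformly sized result and tiling is expected to
stay on the tree path; _disallow_reduce (whose body lies outside this hunk)
presumably fails the test if a shuffle reduce stage is ever executed. A plausible
shape, shown only as an illustration (the import paths and the assertion are
assumptions, not taken from this diff):

    # hedged sketch of the stage-asserting executor pattern used above
    from mars.core.operand import OperandStage  # assumed import path
    from mars.dataframe.groupby.aggregation import DataFrameGroupByAgg

    def _disallow_reduce(ctx, op):
        # fail if the tiled graph falls back to the shuffle (reduce) path
        assert op.stage != OperandStage.reduce
        op.execute(ctx, op)

    operand_executors = {DataFrameGroupByAgg: _disallow_reduce}
    # r.execute(extra_config={"operand_executors": operand_executors,
    #                         "check_all": False}).fetch()
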
mars/services/subtask/worker/tests/subtask_processor.py

Lines changed: 8 additions & 2 deletions
@@ -24,7 +24,10 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

         check_options = dict()
-        kwargs = self.subtask.extra_config or dict()
+        if self.subtask.extra_config:
+            kwargs = self.subtask.extra_config.copy()
+        else:
+            kwargs = dict()
         self._operand_executors = operand_executors = kwargs.pop(
             "operand_executors", dict()
         )

@@ -50,4 +53,7 @@ def _execute_operand(self, ctx: Dict[str, Any], op: OperandType):
     async def done(self):
         await super().done()
         for op in self._operand_executors:
-            op.unregister_executor()
+            try:
+                op.unregister_executor()
+            except KeyError:
+                pass
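
Note on the processor changes: extra_config is now copied before kwargs.pop(...) so
the subtask's shared config dict is never mutated, and unregister_executor() now
tolerates a KeyError for executors that were already removed. A tiny illustration of
why the copy matters (plain dicts with hypothetical keys, not the Mars objects
themselves):

    extra_config = {"operand_executors": {"agg": object()}, "check_all": False}

    kwargs = extra_config          # old behaviour: both names share one dict
    kwargs.pop("operand_executors")
    print(extra_config)            # {'check_all': False} -- shared config mutated

    extra_config = {"operand_executors": {"agg": object()}, "check_all": False}
    kwargs = extra_config.copy()   # new behaviour: shallow copy
    kwargs.pop("operand_executors")
    print(extra_config)            # original keys preserved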
