Commit df1492c

Author: Xuye (Chris) Qin

Auto merge small chunks when df.groupby().apply(func) is doing aggregation (#2708)

Parent: c3e8bdc

4 files changed, +189 -24 lines

mars/dataframe/groupby/apply.py

Lines changed: 28 additions & 21 deletions

@@ -17,11 +17,18 @@

 from ... import opcodes
 from ...core import OutputType
+from ...core.context import get_context
 from ...core.custom_log import redirect_custom_log
-from ...serialization.serializables import TupleField, DictField, FunctionField
+from ...serialization.serializables import (
+    BoolField,
+    TupleField,
+    DictField,
+    FunctionField,
+)
 from ...utils import enter_current_session, quiet_stdio
 from ..operands import DataFrameOperandMixin, DataFrameOperand
 from ..utils import (
+    auto_merge_chunks,
     build_empty_df,
     build_empty_series,
     parse_index,
@@ -35,26 +42,13 @@ class GroupByApply(DataFrameOperand, DataFrameOperandMixin):
     _op_type_ = opcodes.APPLY
     _op_module_ = "dataframe.groupby"

-    _func = FunctionField("func")
-    _args = TupleField("args")
-    _kwds = DictField("kwds")
-
-    def __init__(self, func=None, args=None, kwds=None, output_types=None, **kw):
-        super().__init__(
-            _func=func, _args=args, _kwds=kwds, _output_types=output_types, **kw
-        )
-
-    @property
-    def func(self):
-        return self._func
+    func = FunctionField("func")
+    args = TupleField("args", default_factory=tuple)
+    kwds = DictField("kwds", default_factory=dict)
+    maybe_agg = BoolField("maybe_agg", default=None)

-    @property
-    def args(self):
-        return getattr(self, "_args", None) or ()
-
-    @property
-    def kwds(self):
-        return getattr(self, "_kwds", None) or dict()
+    def __init__(self, output_types=None, **kw):
+        super().__init__(_output_types=output_types, **kw)

     @classmethod
     @redirect_custom_log
@@ -135,7 +129,14 @@ def tile(cls, op):
             kw["nsplits"] = ((np.nan,) * len(chunks), (out_df.shape[1],))
         else:
             kw["nsplits"] = ((np.nan,) * len(chunks),)
-        return new_op.new_tileables([in_groupby], **kw)
+        ret = new_op.new_tileable([in_groupby], **kw)
+        if not op.maybe_agg:
+            return [ret]
+        else:
+            # auto merge small chunks if df.groupby().apply(func)
+            # may be an aggregation operation
+            yield ret.chunks  # trigger execution for chunks
+            return [auto_merge_chunks(get_context(), ret)]

     def _infer_df_func_returns(
         self, in_groupby, in_df, dtypes, dtype=None, name=None, index=None
@@ -147,6 +148,12 @@ def _infer_df_func_returns(
                 self.func, *self.args, **self.kwds
             )

+            if len(infer_df) <= 2:
+                # we create mock df with 4 rows, 2 groups
+                # if return df has 2 rows, we assume that
+                # it's an aggregation operation
+                self.maybe_agg = True
+
         # todo return proper index when sort=True is implemented
         index_value = parse_index(infer_df.index[:0], in_df.key, self.func)
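The detection heuristic added to _infer_df_func_returns is easy to illustrate on its own: Mars probes the user function against a mock input of 4 rows spanning 2 groups, so a result with at most 2 rows (at most one per group) is taken as a probable aggregation. Below is a minimal sketch of that idea in plain pandas; looks_like_aggregation and the inline mock frame are illustrative stand-ins, not Mars's actual build_mock_groupby machinery.

import pandas as pd

def looks_like_aggregation(func) -> bool:
    # Mock input mirroring the heuristic above: 4 rows spanning 2 groups.
    mock = pd.DataFrame({"b": [0, 0, 1, 1], "a": [1.0, 2.0, 3.0, 4.0]})
    result = mock.groupby("b").apply(func)
    # At most one output row per group suggests func aggregates.
    return len(result) <= 2

assert looks_like_aggregation(lambda df: df.a.sum())  # one scalar per group
assert not looks_like_aggregation(lambda df: df + 1)  # row count preserved

Since maybe_agg is only a hint, a false positive merely schedules an extra merge pass over already-computed chunks; the computed result itself is unaffected.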

mars/dataframe/groupby/tests/test_groupby.py

Lines changed: 5 additions & 1 deletion

@@ -199,7 +199,11 @@ def apply_series(s):
     assert applied.chunks[0].shape == (np.nan,)
     assert applied.chunks[0].dtype == df1.a.dtype

-    applied = tile(mdf.groupby("b").apply(lambda df: df.a.sum()))
+    applied = mdf.groupby("b").apply(lambda df: df.a.sum())
+    assert applied.op.maybe_agg is True
+    # force set to pass test
+    applied.op.maybe_agg = None
+    applied = tile(applied)
     assert applied.dtype == df1.a.dtype
     assert applied.shape == (np.nan,)
     assert applied.op._op_type_ == opcodes.APPLY
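At the API level nothing changes: the flag is inferred while the operand probes the function's return type, before any tiling happens, which is exactly what the updated test asserts. A hedged usage sketch follows; the data and variable names are illustrative.

import numpy as np
import pandas as pd

import mars.dataframe as md

pdf = pd.DataFrame({"a": np.random.rand(100), "b": np.random.randint(3, size=100)})
mdf = md.DataFrame(pdf, chunk_size=10)

# The lambda reduces each group to a scalar, so maybe_agg is inferred as True
# and the small per-group result chunks become candidates for auto-merging
# when the graph is tiled.
applied = mdf.groupby("b").apply(lambda df: df.a.sum())
assert applied.op.maybe_agg is True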

mars/dataframe/tests/test_utils.py

Lines changed: 56 additions & 1 deletion

@@ -15,6 +15,7 @@
 import operator
 from collections import OrderedDict
 from numbers import Integral
+from typing import List, Dict

 import numpy as np
 import pandas as pd
@@ -24,7 +25,7 @@
 from ...core import tile
 from ...utils import Timer
 from ..core import IndexValue
-from ..initializer import DataFrame, Index
+from ..initializer import DataFrame, Series, Index
 from ..utils import (
     decide_dataframe_chunk_sizes,
     decide_series_chunk_size,
@@ -39,6 +40,7 @@
     make_dtypes,
     build_concatenated_rows_frame,
     merge_index_value,
+    auto_merge_chunks,
 )


@@ -582,3 +584,56 @@ def test_build_concatenated_rows_frame(setup, columns):
             concatenated.chunks[i].columns_value.to_pandas(), df.columns
         )
     pd.testing.assert_frame_equal(concatenated.execute().fetch(), df)
+
+
+def test_auto_merge_chunks():
+    from ..merge import DataFrameConcat
+
+    pdf = pd.DataFrame(np.random.rand(16, 4), columns=list("abcd"))
+    memory_size = pdf.iloc[:4].memory_usage().sum()
+
+    class FakeContext:
+        def __init__(self, retval=True):
+            self._retval = retval
+
+        def get_chunks_meta(self, data_keys: List[str], **_) -> List[Dict]:
+            if self._retval:
+                return [{"memory_size": memory_size}] * len(data_keys)
+            else:
+                return [None] * len(data_keys)
+
+    df = tile(DataFrame(pdf, chunk_size=4))
+    df2 = auto_merge_chunks(FakeContext(), df, 2 * memory_size)
+    assert len(df2.chunks) == 2
+    assert isinstance(df2.chunks[0].op, DataFrameConcat)
+    assert len(df2.chunks[0].op.inputs) == 2
+    assert isinstance(df2.chunks[1].op, DataFrameConcat)
+    assert len(df2.chunks[1].op.inputs) == 2
+
+    df2 = auto_merge_chunks(FakeContext(), df, 3 * memory_size)
+    assert len(df2.chunks) == 2
+    assert isinstance(df2.chunks[0].op, DataFrameConcat)
+    assert len(df2.chunks[0].op.inputs) == 3
+    assert df2.chunks[1] is df.chunks[-1]
+
+    # mock situation that df not executed
+    df2 = auto_merge_chunks(FakeContext(False), df, 3 * memory_size)
+    assert df2 is df
+
+    # number of chunks on columns > 1
+    df3 = tile(DataFrame(pdf, chunk_size=2))
+    df4 = auto_merge_chunks(FakeContext(), df3, 2 * memory_size)
+    assert df4 is df3
+
+    # test series
+    ps = pdf.loc[:, "a"]
+    memory_size = ps.iloc[:4].memory_usage()
+    s = tile(Series(ps, chunk_size=4))
+    s2 = auto_merge_chunks(FakeContext(), s, 2 * memory_size)
+    assert len(s2.chunks) == 2
+    assert isinstance(s2.chunks[0].op, DataFrameConcat)
+    assert s2.chunks[0].name == "a"
+    assert len(s2.chunks[0].op.inputs) == 2
+    assert isinstance(s2.chunks[1].op, DataFrameConcat)
+    assert s2.chunks[1].name == "a"
+    assert len(s2.chunks[1].op.inputs) == 2
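The expected chunk counts follow from the greedy accumulation inside auto_merge_chunks: with four equally sized chunks of memory footprint m, a limit of 2*m packs them into two concats of two, while a limit of 3*m merges the first three and leaves the final chunk untouched (hence the `df2.chunks[1] is df.chunks[-1]` assertion). A standalone sketch of just that grouping arithmetic; merge_plan is a hypothetical helper, not Mars code.

def merge_plan(sizes, limit):
    """Greedily group consecutive chunk sizes, flushing before a group would exceed limit."""
    groups, current, acc = [], [], 0
    for size in sizes:
        if acc + size > limit and current:
            groups.append(current)  # flush the chunks accumulated so far
            current, acc = [], 0
        current.append(size)
        acc += size
    groups.append(current)  # final (possibly singleton) group
    return groups

m = 1
assert merge_plan([m] * 4, 2 * m) == [[m, m], [m, m]]
assert merge_plan([m] * 4, 3 * m) == [[m, m, m], [m]]

In Mars itself a trailing singleton group is never concatenated; the original chunk object is reused, which is exactly what the identity assertion verifies.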

mars/dataframe/utils.py

Lines changed: 100 additions & 1 deletion

@@ -17,17 +17,28 @@
 import operator
 from contextlib import contextmanager
 from numbers import Integral
+from typing import List, Union

 import numpy as np
 import pandas as pd
 from pandas.api.types import is_string_dtype
 from pandas.api.extensions import ExtensionDtype
 from pandas.core.dtypes.cast import find_common_type

+from ..config import options
 from ..core import Entity, ExecutableTuple
+from ..core.context import Context
 from ..lib.mmh3 import hash as mmh_hash
 from ..tensor.utils import dictify_chunk_size, normalize_chunk_sizes
-from ..utils import tokenize, sbytes, lazy_import, ModulePlaceholder, is_full_slice
+from ..typing import ChunkType, TileableType
+from ..utils import (
+    tokenize,
+    sbytes,
+    lazy_import,
+    ModulePlaceholder,
+    is_full_slice,
+    parse_readable_size,
+)

 try:
     import pyarrow as pa
@@ -1293,3 +1304,91 @@ def is_cudf(x):
     if isinstance(x, (cudf.DataFrame, cudf.Series, cudf.Index)):
         return True
     return False
+
+
+def auto_merge_chunks(
+    ctx: Context,
+    df_or_series: TileableType,
+    merged_file_size: Union[int, float, str] = None,
+) -> TileableType:
+    from .merge import DataFrameConcat
+
+    if df_or_series.ndim == 2 and df_or_series.chunk_shape[1] > 1:
+        # skip auto merge optimization for DataFrame
+        # that has more than 1 chunks on columns axis
+        return df_or_series
+
+    metas = ctx.get_chunks_meta(
+        [c.key for c in df_or_series.chunks], fields=["memory_size"], error="ignore"
+    )
+    memory_sizes = [meta["memory_size"] if meta is not None else None for meta in metas]
+    if any(size is None for size in memory_sizes):
+        # has not been executed before, cannot get accurate memory size, skip auto merge
+        return df_or_series
+
+    def _concat_chunks(merge_chunks: List[ChunkType], output_index: int):
+        chunk_size = sum(c.shape[0] for c in merge_chunks)
+        concat_op = DataFrameConcat(output_types=df_or_series.op.output_types)
+        if df_or_series.ndim == 1:
+            kw = dict(
+                dtype=df_or_series.dtype,
+                index_value=merge_index_value(
+                    {c.index: c.index_value for c in merge_chunks}
+                ),
+                shape=(chunk_size,),
+                index=(output_index,),
+                name=df_or_series.name,
+            )
+        else:
+            kw = dict(
+                dtypes=merge_chunks[0].dtypes,
+                index_value=merge_index_value(
+                    {c.index: c.index_value for c in merge_chunks}
+                ),
+                columns_value=merge_chunks[0].columns_value,
+                shape=(chunk_size, merge_chunks[0].shape[1]),
+                index=(output_index, 0),
+            )
+        return concat_op.new_chunk(merge_chunks, **kw)

+    to_merge_size = (
+        parse_readable_size(merged_file_size)[0]
+        if merged_file_size is not None
+        else options.chunk_store_limit
+    )
+    to_merge_chunks = []
+    acc_memory_size = 0
+    n_split = []
+    out_chunks = []
+    for chunk, chunk_memory_size in zip(df_or_series.chunks, memory_sizes):
+        if acc_memory_size + chunk_memory_size > to_merge_size:
+            # adding current chunk would exceed the maximum,
+            # concat previous chunks
+            merged_chunk = _concat_chunks(to_merge_chunks, len(n_split))
+            out_chunks.append(merged_chunk)
+            n_split.append(merged_chunk.shape[0])
+            # reset
+            acc_memory_size = 0
+            to_merge_chunks = []
+
+        to_merge_chunks.append(chunk)
+        acc_memory_size += chunk_memory_size
+    # process the last chunk
+    if len(to_merge_chunks) > 1:
+        merged_chunk = _concat_chunks(to_merge_chunks, len(n_split))
+        out_chunks.append(merged_chunk)
+        n_split.append(merged_chunk.shape[0])
+    else:
+        assert len(to_merge_chunks) == 1
+        last_chunk = to_merge_chunks[0]
+        out_chunks.append(last_chunk)
+        n_split.append(last_chunk.shape[0])
+
+    new_op = df_or_series.op.copy()
+    params = df_or_series.params.copy()
+    params["chunks"] = out_chunks
+    if df_or_series.ndim == 1:
+        params["nsplits"] = (tuple(n_split),)
+    else:
+        params["nsplits"] = (tuple(n_split), df_or_series.nsplits[1])
+    return new_op.new_tileable(df_or_series.op.inputs, kws=[params])
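Distilling the test's FakeContext pattern, auto_merge_chunks can be exercised directly against a context stub that only serves chunk metadata. A hedged sketch under that assumption follows; MetaOnlyContext is hypothetical, and in a real deployment the context would come from get_context() after the tileable has executed.

import numpy as np
import pandas as pd

from mars.core import tile
from mars.dataframe import DataFrame
from mars.dataframe.utils import auto_merge_chunks

pdf = pd.DataFrame(np.random.rand(16, 4), columns=list("abcd"))
chunk_mem = pdf.iloc[:4].memory_usage().sum()

class MetaOnlyContext:
    # Stub that answers get_chunks_meta with a fixed memory_size per chunk.
    def get_chunks_meta(self, data_keys, **_):
        return [{"memory_size": chunk_mem}] * len(data_keys)

df = tile(DataFrame(pdf, chunk_size=4))  # four 4-row chunks
merged = auto_merge_chunks(MetaOnlyContext(), df, 2 * chunk_mem)
assert len(merged.chunks) == 2  # chunks merged pairwise into concat chunks

Note also that merged_file_size may be a human-readable string such as "64M", since it is resolved through parse_readable_size, and when it is omitted the threshold falls back to options.chunk_store_limit.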
