Commit 3adefad

sak2002, sakshamkumar-byt, and 继盛 authored

Support sort=True for Groupby (#2959)

Co-authored-by: Saksham Kumar <[email protected]>
Co-authored-by: 继盛 <[email protected]>
1 parent be82602 commit 3adefad
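
At the user level, this change makes sort=True (the pandas default) take effect when a groupby aggregation is tiled across multiple chunks. A minimal usage sketch, assuming a local Mars session and illustrative data (the md namespace mirrors the pandas API):

import mars
import mars.dataframe as md
import pandas as pd

mars.new_session()

# Small chunks force the multi-chunk (shuffle) code path touched by this commit.
raw = pd.DataFrame({"key": [3, 1, 2, 1, 3], "val": [10, 20, 30, 40, 50]})
df = md.DataFrame(raw, chunk_size=2)

# With sort=True, group keys are expected to come back in sorted order
# even though each chunk is aggregated independently.
print(df.groupby("key", sort=True).sum().execute())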

File tree

6 files changed: +534 -17 lines changed

mars/dataframe/groupby/aggregation.py

Lines changed: 239 additions & 15 deletions
@@ -55,6 +55,11 @@
 from ..reduction.aggregation import is_funcs_aggregate, normalize_reduction_funcs
 from ..utils import parse_index, build_concatenated_rows_frame, is_cudf
 from .core import DataFrameGroupByOperand
+from .sort import (
+    DataFramePSRSGroupbySample,
+    DataFrameGroupbyConcatPivot,
+    DataFrameGroupbySortShuffle,
+)
 
 cp = lazy_import("cupy", globals=globals(), rename="cp")
 cudf = lazy_import("cudf", globals=globals())
@@ -293,6 +298,117 @@ def __call__(self, groupby):
         else:
             return self._call_series(groupby, df)
 
+    @classmethod
+    def partition_merge_data(
+        cls,
+        op: "DataFrameGroupByAgg",
+        partition_chunks: List[ChunkType],
+        proxy_chunk: ChunkType,
+    ):
+        # stage 4: all *ith* classes are gathered and merged
+        partition_sort_chunks = []
+        properties = dict(by=op.groupby_params["by"], gpu=op.is_gpu())
+        out_df = op.outputs[0]
+
+        for i, partition_chunk in enumerate(partition_chunks):
+            output_types = (
+                [OutputType.dataframe_groupby]
+                if out_df.ndim == 2
+                else [OutputType.series_groupby]
+            )
+            partition_shuffle_reduce = DataFrameGroupbySortShuffle(
+                stage=OperandStage.reduce,
+                reducer_index=(i, 0),
+                output_types=output_types,
+                **properties,
+            )
+            chunk_shape = list(partition_chunk.shape)
+            chunk_shape[0] = np.nan
+
+            kw = dict(
+                shape=tuple(chunk_shape),
+                index=partition_chunk.index,
+                index_value=partition_chunk.index_value,
+            )
+            if op.outputs[0].ndim == 2:
+                kw.update(
+                    dict(
+                        columns_value=partition_chunk.columns_value,
+                        dtypes=partition_chunk.dtypes,
+                    )
+                )
+            else:
+                kw.update(dict(dtype=partition_chunk.dtype, name=partition_chunk.name))
+            cs = partition_shuffle_reduce.new_chunks([proxy_chunk], **kw)
+            partition_sort_chunks.append(cs[0])
+        return partition_sort_chunks
+
+    @classmethod
+    def partition_local_data(
+        cls,
+        op: "DataFrameGroupByAgg",
+        sorted_chunks: List[ChunkType],
+        concat_pivot_chunk: ChunkType,
+        in_df: TileableType,
+    ):
+        # properties = dict(by=op.groupby_params["by"], gpu=op.is_gpu())
+        out_df = op.outputs[0]
+        map_chunks = []
+        chunk_shape = (in_df.chunk_shape[0], 1)
+        for chunk in sorted_chunks:
+            chunk_inputs = [chunk, concat_pivot_chunk]
+            output_types = (
+                [OutputType.dataframe_groupby]
+                if out_df.ndim == 2
+                else [OutputType.series_groupby]
+            )
+            map_chunk_op = DataFrameGroupbySortShuffle(
+                shuffle_size=chunk_shape[0],
+                stage=OperandStage.map,
+                n_partition=len(sorted_chunks),
+                output_types=output_types,
+            )
+            kw = dict()
+            if out_df.ndim == 2:
+                kw.update(
+                    dict(
+                        columns_value=chunk_inputs[0].columns_value,
+                        dtypes=chunk_inputs[0].dtypes,
+                    )
+                )
+            else:
+                kw.update(dict(dtype=chunk_inputs[0].dtype, name=chunk_inputs[0].name))
+
+            map_chunks.append(
+                map_chunk_op.new_chunk(
+                    chunk_inputs,
+                    shape=chunk_shape,
+                    index=chunk.index,
+                    index_value=chunk_inputs[0].index_value,
+                    # **kw
+                )
+            )
+
+        return map_chunks
+
+    @classmethod
+    def _gen_shuffle_chunks_with_pivot(
+        cls,
+        op: "DataFrameGroupByAgg",
+        in_df: TileableType,
+        chunks: List[ChunkType],
+        pivot: ChunkType,
+    ):
+        map_chunks = cls.partition_local_data(op, chunks, pivot, in_df)
+
+        proxy_chunk = DataFrameShuffleProxy(
+            output_types=[OutputType.dataframe]
+        ).new_chunk(map_chunks, shape=())
+
+        partition_sort_chunks = cls.partition_merge_data(op, map_chunks, proxy_chunk)
+
+        return partition_sort_chunks
+
     @classmethod
     def _gen_shuffle_chunks(cls, op, in_df, chunks):
         # generate map chunks
@@ -333,7 +449,6 @@ def _gen_shuffle_chunks(cls, op, in_df, chunks):
                     index_value=None,
                 )
             )
-
         return reduce_chunks
 
     @classmethod
@@ -349,7 +464,7 @@ def _gen_map_chunks(
             chunk_inputs = [chunk]
             map_op = op.copy().reset_key()
             # force as_index=True for map phase
-            map_op.output_types = [OutputType.dataframe]
+            map_op.output_types = op.output_types
             map_op.groupby_params = map_op.groupby_params.copy()
             map_op.groupby_params["as_index"] = True
             if isinstance(map_op.groupby_params["by"], list):
@@ -367,21 +482,25 @@ def _gen_map_chunks(
             map_op.stage = OperandStage.map
             map_op.pre_funcs = func_infos.pre_funcs
             map_op.agg_funcs = func_infos.agg_funcs
-            new_index = chunk.index if len(chunk.index) == 2 else (chunk.index[0], 0)
-            if op.output_types[0] == OutputType.dataframe:
+            new_index = chunk.index if len(chunk.index) == 2 else (chunk.index[0],)
+            if out_df.ndim == 2:
+                new_index = (new_index[0], 0) if len(new_index) == 1 else new_index
                 map_chunk = map_op.new_chunk(
                     chunk_inputs,
                     shape=out_df.shape,
                     index=new_index,
                     index_value=out_df.index_value,
                     columns_value=out_df.columns_value,
+                    dtypes=out_df.dtypes,
                 )
             else:
+                new_index = new_index[:1] if len(new_index) == 2 else new_index
                 map_chunk = map_op.new_chunk(
                     chunk_inputs,
-                    shape=(out_df.shape[0], 1),
+                    shape=(out_df.shape[0],),
                     index=new_index,
                     index_value=out_df.index_value,
+                    dtype=out_df.dtype,
                 )
             map_chunks.append(map_chunk)
         return map_chunks
@@ -422,7 +541,96 @@ def _tile_with_shuffle(
     ):
         # First, perform groupby and aggregation on each chunk.
        agg_chunks = cls._gen_map_chunks(op, in_df.chunks, out_df, func_infos)
-        return cls._perform_shuffle(op, agg_chunks, in_df, out_df, func_infos)
+        pivot_chunk = None
+        if op.groupby_params["sort"] and len(in_df.chunks) > 1:
+            agg_chunk_len = len(agg_chunks)
+            sample_chunks = cls._sample_chunks(op, agg_chunks)
+            pivot_chunk = cls._gen_pivot_chunk(op, sample_chunks, agg_chunk_len)
+
+        return cls._perform_shuffle(
+            op, agg_chunks, in_df, out_df, func_infos, pivot_chunk
+        )
+
+    @classmethod
+    def _gen_pivot_chunk(
+        cls,
+        op: "DataFrameGroupByAgg",
+        sample_chunks: List[ChunkType],
+        agg_chunk_len: int,
+    ):
+
+        properties = dict(
+            by=op.groupby_params["by"],
+            gpu=op.is_gpu(),
+        )
+
+        # stage 2: gather and merge samples, choose and broadcast p-1 pivots
+        kind = "quicksort"
+        output_types = [OutputType.tensor]
+
+        concat_pivot_op = DataFrameGroupbyConcatPivot(
+            kind=kind,
+            n_partition=agg_chunk_len,
+            output_types=output_types,
+            **properties,
+        )
+
+        concat_pivot_chunk = concat_pivot_op.new_chunk(
+            sample_chunks,
+            shape=(agg_chunk_len,),
+            dtype=object,
+        )
+        return concat_pivot_chunk
+
+    @classmethod
+    def _sample_chunks(
+        cls,
+        op: "DataFrameGroupByAgg",
+        agg_chunks: List[ChunkType],
+    ):
+        chunk_shape = len(agg_chunks)
+        sampled_chunks = []
+
+        properties = dict(
+            by=op.groupby_params["by"],
+            gpu=op.is_gpu(),
+        )
+
+        for i, chunk in enumerate(agg_chunks):
+            kws = []
+            sampled_shape = (
+                (chunk_shape, chunk.shape[1]) if chunk.ndim == 2 else (chunk_shape,)
+            )
+            chunk_index = (i, 0) if chunk.ndim == 2 else (i,)
+            chunk_op = DataFramePSRSGroupbySample(
+                kind="quicksort",
+                n_partition=chunk_shape,
+                output_types=op.output_types,
+                **properties,
+            )
+            if op.output_types[0] == OutputType.dataframe:
+                kws.append(
+                    {
+                        "shape": sampled_shape,
+                        "index_value": chunk.index_value,
+                        "index": chunk_index,
+                        "type": "regular_sampled",
+                    }
+                )
+            else:
+                kws.append(
+                    {
+                        "shape": sampled_shape,
+                        "index_value": chunk.index_value,
+                        "index": chunk_index,
+                        "type": "regular_sampled",
+                        "dtype": chunk.dtype,
+                    }
+                )
+            chunk = chunk_op.new_chunk([chunk], kws=kws)
+            sampled_chunks.append(chunk)
+
+        return sampled_chunks
 
     @classmethod
     def _perform_shuffle(
@@ -432,9 +640,15 @@ def _perform_shuffle(
         in_df: TileableType,
         out_df: TileableType,
         func_infos: ReductionSteps,
+        pivot_chunk: ChunkType,
     ):
         # Shuffle the aggregation chunk.
-        reduce_chunks = cls._gen_shuffle_chunks(op, in_df, agg_chunks)
+        if pivot_chunk is not None:
+            reduce_chunks = cls._gen_shuffle_chunks_with_pivot(
+                op, in_df, agg_chunks, pivot_chunk
+            )
+        else:
+            reduce_chunks = cls._gen_shuffle_chunks(op, in_df, agg_chunks)
 
         # Combine groups
         agg_chunks = []
@@ -505,14 +719,17 @@ def _combine_tree(
             if len(chks) == 1:
                 chk = chks[0]
             else:
-                concat_op = DataFrameConcat(output_types=[OutputType.dataframe])
+                concat_op = DataFrameConcat(output_types=out_df.op.output_types)
                 # Change index for concatenate
                 for j, c in enumerate(chks):
                     c._index = (j, 0)
-                chk = concat_op.new_chunk(chks, dtypes=chks[0].dtypes)
+                if out_df.ndim == 2:
+                    chk = concat_op.new_chunk(chks, dtypes=chks[0].dtypes)
+                else:
+                    chk = concat_op.new_chunk(chks, dtype=chunks[0].dtype)
             chunk_op = op.copy().reset_key()
             chunk_op.tileable_op_key = None
-            chunk_op.output_types = [OutputType.dataframe]
+            chunk_op.output_types = out_df.op.output_types
             chunk_op.stage = OperandStage.combine
             chunk_op.groupby_params = chunk_op.groupby_params.copy()
             chunk_op.groupby_params.pop("selection", None)
@@ -536,8 +753,11 @@ def _combine_tree(
            )
            chunks = new_chunks
 
-        concat_op = DataFrameConcat(output_types=[OutputType.dataframe])
-        chk = concat_op.new_chunk(chunks, dtypes=chunks[0].dtypes)
+        concat_op = DataFrameConcat(output_types=out_df.op.output_types)
+        if out_df.ndim == 2:
+            chk = concat_op.new_chunk(chunks, dtypes=chunks[0].dtypes)
+        else:
+            chk = concat_op.new_chunk(chunks, dtype=chunks[0].dtype)
         chunk_op = op.copy().reset_key()
         chunk_op.tileable_op_key = op.key
         chunk_op.stage = OperandStage.agg
@@ -621,9 +841,15 @@ def _tile_auto(
             return cls._combine_tree(op, chunks + left_chunks, out_df, func_infos)
         else:
             # otherwise, use shuffle
+            pivot_chunk = None
+            if op.groupby_params["sort"] and len(in_df.chunks) > 1:
+                agg_chunk_len = len(chunks + left_chunks)
+                sample_chunks = cls._sample_chunks(op, chunks + left_chunks)
+                pivot_chunk = cls._gen_pivot_chunk(op, sample_chunks, agg_chunk_len)
+
             logger.debug("Choose shuffle method for groupby operand %s", op)
             return cls._perform_shuffle(
-                op, chunks + left_chunks, in_df, out_df, func_infos
+                op, chunks + left_chunks, in_df, out_df, func_infos, pivot_chunk
             )
 
     @classmethod
@@ -671,8 +897,6 @@ def _get_grouped(cls, op: "DataFrameGroupByAgg", df, ctx, copy=False, grouper=None):
         if op.stage == OperandStage.agg:
             grouped = df.groupby(**params)
         else:
-            # for the intermediate phases, do not sort
-            params["sort"] = False
            grouped = df.groupby(**params)
 
         if selection is not None:
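
The staged comments above ("stage 2", "stage 4") follow a parallel-sorting-by-regular-sampling (PSRS) scheme: each aggregated chunk is sampled, the samples are merged to choose p-1 pivots, every chunk is partitioned by the broadcast pivots, and reducer i merges the ith partition from every chunk. A self-contained sketch of the same idea on plain NumPy arrays of group keys; the function names here are illustrative, not Mars APIs:

import numpy as np

def sample_keys(chunk_keys, n_partition):
    # each chunk sorts its keys and takes n_partition regular samples
    sorted_keys = np.sort(chunk_keys, kind="quicksort")
    idx = np.linspace(0, len(sorted_keys) - 1, n_partition).astype(int)
    return sorted_keys[idx]

def choose_pivots(samples, n_partition):
    # stage 2: gather and merge samples, choose and broadcast p-1 pivots
    merged = np.sort(np.concatenate(samples), kind="quicksort")
    idx = np.linspace(0, len(merged) - 1, n_partition + 1).astype(int)[1:-1]
    return merged[idx]

def partition_keys(chunk_keys, pivots):
    # split each chunk's sorted keys at the broadcast pivots;
    # slice i of every chunk is shuffled to reducer i
    sorted_keys = np.sort(chunk_keys, kind="quicksort")
    return np.split(sorted_keys, np.searchsorted(sorted_keys, pivots, side="right"))

chunks = [np.array([3, 1, 2]), np.array([5, 4, 1]), np.array([2, 5, 3])]
p = len(chunks)
pivots = choose_pivots([sample_keys(c, p) for c in chunks], p)
parts = [partition_keys(c, pivots) for c in chunks]
# stage 4: reducer i gathers and merges the ith slice from every chunk,
# so reducer boundaries respect global key order
reducers = [np.sort(np.concatenate([parts[j][i] for j in range(p)])) for i in range(p)]
print(pivots, reducers)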
