 from ..reduction.aggregation import is_funcs_aggregate, normalize_reduction_funcs
 from ..utils import parse_index, build_concatenated_rows_frame, is_cudf
 from .core import DataFrameGroupByOperand
+from .sort import (
+    DataFramePSRSGroupbySample,
+    DataFrameGroupbyConcatPivot,
+    DataFrameGroupbySortShuffle,
+)
 
 cp = lazy_import("cupy", globals=globals(), rename="cp")
 cudf = lazy_import("cudf", globals=globals())
@@ -293,6 +298,117 @@ def __call__(self, groupby):
         else:
             return self._call_series(groupby, df)
 
+    @classmethod
+    def partition_merge_data(
+        cls,
+        op: "DataFrameGroupByAgg",
+        partition_chunks: List[ChunkType],
+        proxy_chunk: ChunkType,
+    ):
+        # stage 4: all ith partitions are gathered and merged by reducer i
+        partition_sort_chunks = []
+        properties = dict(by=op.groupby_params["by"], gpu=op.is_gpu())
+        out_df = op.outputs[0]
+
+        for i, partition_chunk in enumerate(partition_chunks):
+            output_types = (
+                [OutputType.dataframe_groupby]
+                if out_df.ndim == 2
+                else [OutputType.series_groupby]
+            )
+            partition_shuffle_reduce = DataFrameGroupbySortShuffle(
+                stage=OperandStage.reduce,
+                reducer_index=(i, 0),
+                output_types=output_types,
+                **properties,
+            )
+            chunk_shape = list(partition_chunk.shape)
+            chunk_shape[0] = np.nan
+
+            kw = dict(
+                shape=tuple(chunk_shape),
+                index=partition_chunk.index,
+                index_value=partition_chunk.index_value,
+            )
+            if out_df.ndim == 2:
+                kw.update(
+                    dict(
+                        columns_value=partition_chunk.columns_value,
+                        dtypes=partition_chunk.dtypes,
+                    )
+                )
+            else:
+                kw.update(dict(dtype=partition_chunk.dtype, name=partition_chunk.name))
+            cs = partition_shuffle_reduce.new_chunks([proxy_chunk], **kw)
+            partition_sort_chunks.append(cs[0])
+        return partition_sort_chunks
+
+    @classmethod
+    def partition_local_data(
+        cls,
+        op: "DataFrameGroupByAgg",
+        sorted_chunks: List[ChunkType],
+        concat_pivot_chunk: ChunkType,
+        in_df: TileableType,
+    ):
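+        # stage 3: split each locally aggregated chunk by the broadcast pivots;
+        # rows whose keys fall between pivot i-1 and pivot i go to reducer i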
+        out_df = op.outputs[0]
+        map_chunks = []
+        chunk_shape = (in_df.chunk_shape[0], 1)
+        for chunk in sorted_chunks:
+            chunk_inputs = [chunk, concat_pivot_chunk]
+            output_types = (
+                [OutputType.dataframe_groupby]
+                if out_df.ndim == 2
+                else [OutputType.series_groupby]
+            )
+            map_chunk_op = DataFrameGroupbySortShuffle(
+                shuffle_size=chunk_shape[0],
+                stage=OperandStage.map,
+                n_partition=len(sorted_chunks),
+                output_types=output_types,
+            )
+            kw = dict()
+            if out_df.ndim == 2:
+                kw.update(
+                    dict(
+                        columns_value=chunk_inputs[0].columns_value,
+                        dtypes=chunk_inputs[0].dtypes,
+                    )
+                )
+            else:
+                kw.update(dict(dtype=chunk_inputs[0].dtype, name=chunk_inputs[0].name))
+
+            map_chunks.append(
+                map_chunk_op.new_chunk(
+                    chunk_inputs,
+                    shape=chunk_shape,
+                    index=chunk.index,
+                    index_value=chunk_inputs[0].index_value,
+                    **kw,
+                )
+            )
+
+        return map_chunks
+
+    @classmethod
+    def _gen_shuffle_chunks_with_pivot(
+        cls,
+        op: "DataFrameGroupByAgg",
+        in_df: TileableType,
+        chunks: List[ChunkType],
+        pivot: ChunkType,
+    ):
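+        # wire up the sorted shuffle: pivot-partitioned map chunks feed a
+        # shuffle proxy, which fans out to one merge chunk per partition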
+        map_chunks = cls.partition_local_data(op, chunks, pivot, in_df)
+
+        proxy_chunk = DataFrameShuffleProxy(
+            output_types=[OutputType.dataframe]
+        ).new_chunk(map_chunks, shape=())
+
+        partition_sort_chunks = cls.partition_merge_data(op, map_chunks, proxy_chunk)
+
+        return partition_sort_chunks
+
     @classmethod
     def _gen_shuffle_chunks(cls, op, in_df, chunks):
         # generate map chunks
@@ -333,7 +449,6 @@ def _gen_shuffle_chunks(cls, op, in_df, chunks):
                     index_value=None,
                 )
             )
-
         return reduce_chunks
 
     @classmethod
@@ -349,7 +464,7 @@ def _gen_map_chunks(
             chunk_inputs = [chunk]
             map_op = op.copy().reset_key()
             # force as_index=True for map phase
-            map_op.output_types = [OutputType.dataframe]
+            map_op.output_types = op.output_types
             map_op.groupby_params = map_op.groupby_params.copy()
             map_op.groupby_params["as_index"] = True
             if isinstance(map_op.groupby_params["by"], list):
@@ -367,21 +482,25 @@ def _gen_map_chunks(
             map_op.stage = OperandStage.map
             map_op.pre_funcs = func_infos.pre_funcs
             map_op.agg_funcs = func_infos.agg_funcs
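+            # normalize the chunk index to the output's ndim: 2-d for a
+            # DataFrame result, 1-d for a Series result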
-            new_index = chunk.index if len(chunk.index) == 2 else (chunk.index[0], 0)
-            if op.output_types[0] == OutputType.dataframe:
+            new_index = chunk.index if len(chunk.index) == 2 else (chunk.index[0],)
+            if out_df.ndim == 2:
+                new_index = (new_index[0], 0) if len(new_index) == 1 else new_index
                 map_chunk = map_op.new_chunk(
                     chunk_inputs,
                     shape=out_df.shape,
                     index=new_index,
                     index_value=out_df.index_value,
                     columns_value=out_df.columns_value,
+                    dtypes=out_df.dtypes,
                 )
             else:
+                new_index = new_index[:1] if len(new_index) == 2 else new_index
                 map_chunk = map_op.new_chunk(
                     chunk_inputs,
-                    shape=(out_df.shape[0], 1),
+                    shape=(out_df.shape[0],),
                     index=new_index,
                     index_value=out_df.index_value,
+                    dtype=out_df.dtype,
                 )
             map_chunks.append(map_chunk)
         return map_chunks
@@ -422,7 +541,96 @@ def _tile_with_shuffle(
     ):
         # First, perform groupby and aggregation on each chunk.
         agg_chunks = cls._gen_map_chunks(op, in_df.chunks, out_df, func_infos)
-        return cls._perform_shuffle(op, agg_chunks, in_df, out_df, func_infos)
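+        # when sort=True, sample the aggregated chunks and derive pivots so the
+        # shuffle can range-partition by group key instead of hashing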
+        pivot_chunk = None
+        if op.groupby_params["sort"] and len(in_df.chunks) > 1:
+            agg_chunk_len = len(agg_chunks)
+            sample_chunks = cls._sample_chunks(op, agg_chunks)
+            pivot_chunk = cls._gen_pivot_chunk(op, sample_chunks, agg_chunk_len)
+
+        return cls._perform_shuffle(
+            op, agg_chunks, in_df, out_df, func_infos, pivot_chunk
+        )
+
+    @classmethod
+    def _gen_pivot_chunk(
+        cls,
+        op: "DataFrameGroupByAgg",
+        sample_chunks: List[ChunkType],
+        agg_chunk_len: int,
+    ):
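+        # the pivots are broadcast as a plain tensor; dtype=object since group
+        # keys may be strings, tuples or other Python objects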
+        properties = dict(
+            by=op.groupby_params["by"],
+            gpu=op.is_gpu(),
+        )
+
+        # stage 2: gather and merge samples, choose and broadcast p-1 pivots
+        kind = "quicksort"
+        output_types = [OutputType.tensor]
+
+        concat_pivot_op = DataFrameGroupbyConcatPivot(
+            kind=kind,
+            n_partition=agg_chunk_len,
+            output_types=output_types,
+            **properties,
+        )
+
+        concat_pivot_chunk = concat_pivot_op.new_chunk(
+            sample_chunks,
+            shape=(agg_chunk_len,),
+            dtype=object,
+        )
+        return concat_pivot_chunk
+
+    @classmethod
+    def _sample_chunks(
+        cls,
+        op: "DataFrameGroupByAgg",
+        agg_chunks: List[ChunkType],
+    ):
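+        # stage 1: draw regular samples from every aggregated chunk so pivot
+        # selection reflects each chunk's key distribution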
+        chunk_shape = len(agg_chunks)
+        sampled_chunks = []
+
+        properties = dict(
+            by=op.groupby_params["by"],
+            gpu=op.is_gpu(),
+        )
+
+        for i, chunk in enumerate(agg_chunks):
+            kws = []
+            sampled_shape = (
+                (chunk_shape, chunk.shape[1]) if chunk.ndim == 2 else (chunk_shape,)
+            )
+            chunk_index = (i, 0) if chunk.ndim == 2 else (i,)
+            chunk_op = DataFramePSRSGroupbySample(
+                kind="quicksort",
+                n_partition=chunk_shape,
+                output_types=op.output_types,
+                **properties,
+            )
+            if op.output_types[0] == OutputType.dataframe:
+                kws.append(
+                    {
+                        "shape": sampled_shape,
+                        "index_value": chunk.index_value,
+                        "index": chunk_index,
+                        "type": "regular_sampled",
+                    }
+                )
+            else:
+                kws.append(
+                    {
+                        "shape": sampled_shape,
+                        "index_value": chunk.index_value,
+                        "index": chunk_index,
+                        "type": "regular_sampled",
+                        "dtype": chunk.dtype,
+                    }
+                )
+            chunk = chunk_op.new_chunk([chunk], kws=kws)
+            sampled_chunks.append(chunk)
+
+        return sampled_chunks
 
     @classmethod
     def _perform_shuffle(
@@ -432,9 +640,15 @@ def _perform_shuffle(
         in_df: TileableType,
         out_df: TileableType,
         func_infos: ReductionSteps,
+        pivot_chunk: ChunkType,
     ):
         # Shuffle the aggregation chunk.
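+        # a pivot chunk is only produced when sorted output was requested: use
+        # it for a range-partitioned shuffle, else shuffle by hash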
-        reduce_chunks = cls._gen_shuffle_chunks(op, in_df, agg_chunks)
+        if pivot_chunk is not None:
+            reduce_chunks = cls._gen_shuffle_chunks_with_pivot(
+                op, in_df, agg_chunks, pivot_chunk
+            )
+        else:
+            reduce_chunks = cls._gen_shuffle_chunks(op, in_df, agg_chunks)
 
         # Combine groups
         agg_chunks = []
@@ -505,14 +719,17 @@ def _combine_tree(
                 if len(chks) == 1:
                     chk = chks[0]
                 else:
-                    concat_op = DataFrameConcat(output_types=[OutputType.dataframe])
+                    concat_op = DataFrameConcat(output_types=out_df.op.output_types)
                     # Change index for concatenate
                     for j, c in enumerate(chks):
                         c._index = (j, 0)
-                    chk = concat_op.new_chunk(chks, dtypes=chks[0].dtypes)
+                    if out_df.ndim == 2:
+                        chk = concat_op.new_chunk(chks, dtypes=chks[0].dtypes)
+                    else:
+                        chk = concat_op.new_chunk(chks, dtype=chks[0].dtype)
                 chunk_op = op.copy().reset_key()
                 chunk_op.tileable_op_key = None
-                chunk_op.output_types = [OutputType.dataframe]
+                chunk_op.output_types = out_df.op.output_types
                 chunk_op.stage = OperandStage.combine
                 chunk_op.groupby_params = chunk_op.groupby_params.copy()
                 chunk_op.groupby_params.pop("selection", None)
@@ -536,8 +753,11 @@ def _combine_tree(
                 )
             chunks = new_chunks
 
-        concat_op = DataFrameConcat(output_types=[OutputType.dataframe])
-        chk = concat_op.new_chunk(chunks, dtypes=chunks[0].dtypes)
+        concat_op = DataFrameConcat(output_types=out_df.op.output_types)
+        if out_df.ndim == 2:
+            chk = concat_op.new_chunk(chunks, dtypes=chunks[0].dtypes)
+        else:
+            chk = concat_op.new_chunk(chunks, dtype=chunks[0].dtype)
         chunk_op = op.copy().reset_key()
         chunk_op.tileable_op_key = op.key
         chunk_op.stage = OperandStage.agg
@@ -621,9 +841,15 @@ def _tile_auto(
             return cls._combine_tree(op, chunks + left_chunks, out_df, func_infos)
         else:
             # otherwise, use shuffle
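+            # as in _tile_with_shuffle, sample and pick pivots first so the
+            # shuffled result keeps group keys in sorted order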
+            pivot_chunk = None
+            if op.groupby_params["sort"] and len(in_df.chunks) > 1:
+                agg_chunk_len = len(chunks + left_chunks)
+                sample_chunks = cls._sample_chunks(op, chunks + left_chunks)
+                pivot_chunk = cls._gen_pivot_chunk(op, sample_chunks, agg_chunk_len)
+
             logger.debug("Choose shuffle method for groupby operand %s", op)
             return cls._perform_shuffle(
-                op, chunks + left_chunks, in_df, out_df, func_infos
+                op, chunks + left_chunks, in_df, out_df, func_infos, pivot_chunk
             )
 
     @classmethod
@@ -671,8 +897,6 @@ def _get_grouped(cls, op: "DataFrameGroupByAgg", df, ctx, copy=False, grouper=None):
         if op.stage == OperandStage.agg:
             grouped = df.groupby(**params)
         else:
-            # for the intermediate phases, do not sort
-            params["sort"] = False
             grouped = df.groupby(**params)
 
         if selection is not None: