@@ -695,7 +695,28 @@ def _pack_inputs(agg_funcs: List[ReductionAggStep], in_data):
695695 return out_dict
696696
697697 @staticmethod
698- def _do_custom_agg (op , custom_reduction , * input_objs ):
698+ def _do_custom_agg_single (op , custom_reduction , input_obj ):
699+ if op .stage == OperandStage .map :
700+ if custom_reduction .pre_with_agg :
701+ apply_fun = custom_reduction .pre
702+ else :
703+
704+ def apply_fun (obj ):
705+ return custom_reduction .agg (custom_reduction .pre (obj ))
706+
707+ elif op .stage == OperandStage .agg :
708+
709+ def apply_fun (obj ):
710+ return custom_reduction .post (custom_reduction .agg (obj ))
711+
712+ else :
713+ apply_fun = custom_reduction .agg
714+
715+ res = input_obj .apply (apply_fun )
716+ return (res ,)
717+
718+ @staticmethod
719+ def _do_custom_agg_multiple (op , custom_reduction , * input_objs ):
699720 xdf = cudf if op .gpu else pd
700721 results = []
701722 out = op .outputs [0 ]
@@ -763,6 +784,13 @@ def _do_custom_agg(op, custom_reduction, *input_objs):
763784 concat_result = tuple (xdf .concat (parts ) for parts in zip (* results ))
764785 return concat_result
765786
787+ @classmethod
788+ def _do_custom_agg (cls , op , custom_reduction , * input_objs , output_limit : int = 1 ):
789+ if output_limit == 1 :
790+ return cls ._do_custom_agg_single (op , custom_reduction , input_objs [0 ])
791+ else :
792+ return cls ._do_custom_agg_multiple (op , custom_reduction , * input_objs )
793+
766794 @staticmethod
767795 def _do_predefined_agg (input_obj , agg_func , single_func = False , ** kwds ):
768796 ndim = getattr (input_obj , "ndim" , None ) or input_obj .obj .ndim
@@ -857,12 +885,16 @@ def _wrapped_func(col):
857885 _agg_func_name ,
858886 custom_reduction ,
859887 _output_key ,
860- _output_limit ,
888+ output_limit ,
861889 kwds ,
862890 ) in op .agg_funcs :
863891 input_obj = ret_map_groupbys [input_key ]
864892 if map_func_name == "custom_reduction" :
865- agg_dfs .extend (cls ._do_custom_agg (op , custom_reduction , input_obj ))
893+ agg_dfs .extend (
894+ cls ._do_custom_agg (
895+ op , custom_reduction , input_obj , output_limit = output_limit
896+ )
897+ )
866898 else :
867899 single_func = map_func_name == op .raw_func
868900 agg_dfs .append (
@@ -903,12 +935,16 @@ def _execute_combine(cls, ctx, op: "DataFrameGroupByAgg"):
903935 agg_func_name ,
904936 custom_reduction ,
905937 output_key ,
906- _output_limit ,
938+ output_limit ,
907939 kwds ,
908940 ) in op .agg_funcs :
909941 input_obj = in_data_dict [output_key ]
910942 if agg_func_name == "custom_reduction" :
911- combines .extend (cls ._do_custom_agg (op , custom_reduction , * input_obj ))
943+ combines .extend (
944+ cls ._do_custom_agg (
945+ op , custom_reduction , * input_obj , output_limit = output_limit
946+ )
947+ )
912948 else :
913949 combines .append (
914950 cls ._do_predefined_agg (input_obj , agg_func_name , ** kwds )
@@ -943,15 +979,15 @@ def _execute_agg(cls, ctx, op: "DataFrameGroupByAgg"):
943979 agg_func_name ,
944980 custom_reduction ,
945981 output_key ,
946- _output_limit ,
982+ output_limit ,
947983 kwds ,
948984 ) in op .agg_funcs :
949985 if agg_func_name == "custom_reduction" :
950986 input_obj = tuple (
951987 cls ._get_grouped (op , o , ctx ) for o in in_data_dict [output_key ]
952988 )
953989 in_data_dict [output_key ] = cls ._do_custom_agg (
954- op , custom_reduction , * input_obj
990+ op , custom_reduction , * input_obj , output_limit = output_limit
955991 )[0 ]
956992 else :
957993 input_obj = cls ._get_grouped (op , in_data_dict [output_key ], ctx )
0 commit comments