1717
1818from ... import opcodes
1919from ...core import OutputType , get_output_types , recursive_tile
20- from ...serialization .serializables import DictField , Int64Field
20+ from ...serialization .serializables import DictField , Int64Field , BoolField
21+ from ...utils import pd_release_version
2122from ..core import IndexValue
2223from ..operands import DataFrameOperandMixin , DataFrameOperand
2324from ..utils import build_concatenated_rows_frame , parse_index
2425
26+ _pandas_enable_negative = pd_release_version >= (1 , 4 , 0 )
27+
2528
2629class GroupByHead (DataFrameOperand , DataFrameOperandMixin ):
2730 _op_type_ = opcodes .GROUPBY_HEAD
2831 _op_module_ = "dataframe.groupby"
2932
30- _row_count = Int64Field ("row_count" )
31- _groupby_params = DictField ("groupby_params" )
32-
33- def __init__ (self , row_count = None , groupby_params = None , ** kw ):
34- super ().__init__ (_row_count = row_count , _groupby_params = groupby_params , ** kw )
35-
36- @property
37- def row_count (self ) -> int :
38- return self ._row_count
39-
40- @property
41- def groupby_params (self ) -> dict :
42- return self ._groupby_params
33+ row_count = Int64Field ("row_count" )
34+ groupby_params = DictField ("groupby_params" )
35+ enable_negative = BoolField ("enable_negative" )
4336
4437 def __call__ (self , groupby ):
4538 df = groupby
@@ -72,30 +65,32 @@ def tile(cls, op: "GroupByHead"):
7265 groupby_params = op .groupby_params .copy ()
7366 selection = groupby_params .pop ("selection" , None )
7467
68+ enable_negative = _pandas_enable_negative and op .enable_negative
69+
7570 if len (in_df .shape ) > 1 :
7671 in_df = build_concatenated_rows_frame (in_df )
7772 out_df = op .outputs [0 ]
7873
79- # when row_count is not positive or there is only one chunk,
80- # tile with a single chunk
81- if op .row_count <= 0 or len (in_df .chunks ) == 0 :
74+ # when negative head is not enabled and row_count is not positive (the result is empty),
75+ # or the input has at most one chunk, tile with a single chunk
76+ if (not enable_negative and op .row_count <= 0 ) or len (in_df .chunks ) <= 1 :
77+ row_num = 0 if not enable_negative and op .row_count <= 0 else np .nan
78+ new_shape = (row_num ,)
79+ new_nsplits = ((row_num ,),)
80+ if out_df .ndim > 1 :
81+ new_shape += (out_df .shape [1 ],)
82+ new_nsplits += ((out_df .shape [1 ],),)
83+
8284 c = in_df .chunks [0 ]
8385 chunk_op = op .copy ().reset_key ()
84- params = c .params
85- row_num = 0 if op .row_count <= 0 else np .nan
86- params ["shape" ] = (row_num ,) + c .shape [1 :]
87- params ["index_value" ] = out_df .index_value
86+ params = out_df .params
87+ params ["shape" ] = new_shape
88+ params ["index" ] = (0 ,) * out_df .ndim
8889 out_chunk = chunk_op .new_chunk ([c ], ** params )
8990
9091 tileable_op = op .copy ().reset_key ()
91- params = out_df .params
92- params ["shape" ] = (row_num ,) + c .shape [1 :]
93- params ["index_value" ] = out_df .index_value
9492 return tileable_op .new_tileables (
95- [in_df ],
96- nsplits = ((row_num ,),) + in_df .nsplits [1 :],
97- chunks = [out_chunk ],
98- ** params
93+ [in_df ], nsplits = new_nsplits , chunks = [out_chunk ], ** params
9994 )
10095
10196 if in_df .ndim > 1 and selection :
@@ -116,15 +111,19 @@ def tile(cls, op: "GroupByHead"):
116111 in_df = yield from recursive_tile (in_df [pre_selection ])
117112
118113 # generate pre chunks
119- pre_chunks = []
120- for c in in_df .chunks :
121- pre_op = op .copy ().reset_key ()
122- pre_op ._output_types = get_output_types (c )
123- pre_op ._groupby_params = op .groupby_params .copy ()
124- pre_op ._groupby_params .pop ("selection" , None )
125- params = c .params
126- params ["shape" ] = (np .nan ,) + c .shape [1 :]
127- pre_chunks .append (pre_op .new_chunk ([c ], ** params ))
114+ if op .row_count < 0 :
115+ # a negative row count excludes the last |n| rows of each group, which cannot be decided per chunk, so pre-groupby optimization is not possible
116+ pre_chunks = in_df .chunks
117+ else :
118+ pre_chunks = []
119+ for c in in_df .chunks :
120+ pre_op = op .copy ().reset_key ()
121+ pre_op ._output_types = get_output_types (c )
122+ pre_op .groupby_params = op .groupby_params .copy ()
123+ pre_op .groupby_params .pop ("selection" , None )
124+ params = c .params
125+ params ["shape" ] = (np .nan ,) + c .shape [1 :]
126+ pre_chunks .append (pre_op .new_chunk ([c ], ** params ))
128127
129128 new_op = op .copy ().reset_key ()
130129 new_op ._output_types = get_output_types (in_df )
@@ -142,8 +141,8 @@ def tile(cls, op: "GroupByHead"):
142141 post_chunks = []
143142 for c in grouped .chunks :
144143 post_op = op .copy ().reset_key ()
145- post_op ._groupby_params = op .groupby_params .copy ()
146- post_op ._groupby_params .pop ("selection" , None )
144+ post_op .groupby_params = op .groupby_params .copy ()
145+ post_op .groupby_params .pop ("selection" , None )
147146 if op .output_types [0 ] == OutputType .dataframe :
148147 index = c .index
149148 else :
@@ -175,7 +174,10 @@ def execute(cls, ctx, op: "GroupByHead"):
175174 if selection :
176175 grouped = grouped [selection ]
177176
178- ctx [op .outputs [0 ].key ] = grouped .head (op .row_count )
177+ result = grouped .head (op .row_count )
178+ if not op .enable_negative and op .row_count < 0 :
179+ result = result .iloc [:0 ]
180+ ctx [op .outputs [0 ].key ] = result
179181
180182
181183def head (groupby , n = 5 ):
@@ -215,5 +217,9 @@ def head(groupby, n=5):
215217 groupby_params = groupby .op .groupby_params .copy ()
216218 groupby_params .pop ("as_index" , None )
217219
218- op = GroupByHead (row_count = n , groupby_params = groupby_params )
220+ op = GroupByHead (
221+ row_count = n ,
222+ groupby_params = groupby_params ,
223+ enable_negative = _pandas_enable_negative ,
224+ )
219225 return op (groupby )
0 commit comments