@@ -68,7 +68,8 @@ def __init__(self, properties: QueryProperties, **parameters):
6868        self .column_map , self .aggregate_functions  =  build_aggregations (self .aggregates )
6969
7070        self .buffer  =  []
71-         self .max_buffer_size  =  50   # Process in chunks to avoid excessive memory usage 
71+         self .max_buffer_size  =  100   # Process in chunks to avoid excessive memory usage 
72+         self ._partial_aggregated  =  False   # Track if we've done a partial aggregation 
7273
7374    @property  
7475    def  config (self ):  # pragma: no cover 
@@ -86,38 +87,122 @@ def execute(self, morsel: pyarrow.Table, **kwargs):
8687                yield  EOS 
8788                return 
8889
89-             # If we have partial results in buffer, do final aggregation 
90-             if  len (self .buffer ) >  0 :
91-                 table  =  pyarrow .concat_tables (
92-                     self .buffer ,
93-                     promote_options = "permissive" ,
94-                 )
90+             # Do final aggregation if we have buffered data 
91+             table  =  pyarrow .concat_tables (
92+                 self .buffer ,
93+                 promote_options = "permissive" ,
94+             )
95+             # Only combine chunks if we haven't done partial aggregation yet 
96+             # combine_chunks can fail after partial aggregation due to buffer structure 
97+             if  not  self ._partial_aggregated :
9598                table  =  table .combine_chunks ()
99+ 
100+             # If we've done partial aggregations, the aggregate functions need adjusting 
101+             # because columns like "*" have been renamed to "*_count" 
102+             if  self ._partial_aggregated :
103+                 # Build new aggregate functions for re-aggregating partial results 
104+                 adjusted_aggs  =  []
105+                 adjusted_column_map  =  {}
106+ 
107+                 for  field_name , function , _count_options  in  self .aggregate_functions :
108+                     # For COUNT aggregates, the column is now named "*_count" and we need to SUM it 
109+                     if  function  ==  "count" :
110+                         renamed_field = f"{field_name}_count"
111+                         adjusted_aggs .append ((renamed_field , "sum" , None ))
112+                         # The final column will be named "*_count_sum", need to track for renaming 
113+                         for  orig_name , mapped_name  in  self .column_map .items ():
114+                             if mapped_name == f"{field_name}_count":
115+                                 adjusted_column_map[orig_name] = f"{renamed_field}_sum"
116+                     # For other aggregates, we can re-aggregate with the same function 
117+                     else :
118+                         renamed_field = f"{field_name}_{function}".replace("_hash_", "_")
119+                         # Some aggregates can be re-aggregated (sum, max, min) 
120+                         if  function  in  ("sum" , "max" , "min" , "hash_one" , "all" , "any" ):
121+                             adjusted_aggs .append ((renamed_field , function , None ))
122+                             # Track the mapping: original -> intermediate -> final 
123+                             for  orig_name , mapped_name  in  self .column_map .items ():
124+                                 if  mapped_name  ==  renamed_field :
125+                                     # sum->sum, max->max, etc. means same name 
126+                                     adjusted_column_map [orig_name ] =  (
127+                                         f"{renamed_field}_{function}".replace("_hash_", "_")
128+                                     )
129+                         elif  function  ==  "mean" :
130+                             # For mean, just take one of the existing values (not ideal) 
131+                             adjusted_aggs .append ((renamed_field , "hash_one" , None ))
132+                             for  orig_name , mapped_name  in  self .column_map .items ():
133+                                 if  mapped_name  ==  renamed_field :
134+                                     adjusted_column_map[orig_name] = f"{renamed_field}_one"
135+                         elif  function  ==  "hash_list" :
136+                             # For ARRAY_AGG, we need to flatten lists 
137+                             adjusted_aggs .append ((renamed_field , "hash_list" , None ))
138+                             for  orig_name , mapped_name  in  self .column_map .items ():
139+                                 if  mapped_name  ==  renamed_field :
140+                                     adjusted_column_map[orig_name] = f"{renamed_field}_list"
141+                         else :
142+                             # For other aggregates, take one value 
143+                             adjusted_aggs .append ((renamed_field , "hash_one" , None ))
144+                             for  orig_name , mapped_name  in  self .column_map .items ():
145+                                 if  mapped_name  ==  renamed_field :
146+                                     adjusted_column_map[orig_name] = f"{renamed_field}_one"
147+ 
148+                 groups  =  table .group_by (self .group_by_columns )
149+                 groups  =  groups .aggregate (adjusted_aggs )
150+ 
151+                 # Use the adjusted column map for selecting/renaming 
152+                 groups  =  groups .select (list (adjusted_column_map .values ()) +  self .group_by_columns )
153+                 groups  =  groups .rename_columns (
154+                     list (adjusted_column_map .keys ()) +  self .group_by_columns 
155+                 )
156+             else :
96157                groups  =  table .group_by (self .group_by_columns )
97158                groups  =  groups .aggregate (self .aggregate_functions )
98-                 self .buffer  =  [groups ]  # Replace buffer with final result 
99- 
100-             # Now buffer has the final aggregated result 
101-             groups  =  self .buffer [0 ]
102- 
103-             # do the secondary activities for ARRAY_AGG 
104-             for  node  in  get_all_nodes_of_type (self .aggregates , select_nodes = (NodeType .AGGREGATOR ,)):
105-                 if  node .value  ==  "ARRAY_AGG"  and  node .order  or  node .limit :
106-                     # rip the column out of the table 
107-                     column_name  =  self .column_map [node .schema_column .identity ]
108-                     column_def  =  groups .field (column_name )  # this is used 
109-                     column  =  groups .column (column_name ).to_pylist ()
110-                     groups  =  groups .drop ([column_name ])
159+ 
160+                 # project to the desired column names from the pyarrow names 
161+                 groups  =  groups .select (list (self .column_map .values ()) +  self .group_by_columns )
162+                 groups  =  groups .rename_columns (list (self .column_map .keys ()) +  self .group_by_columns )
163+ 
164+             # do the secondary activities for ARRAY_AGG (order and limit) 
165+             array_agg_nodes  =  [
166+                 node 
167+                 for  node  in  get_all_nodes_of_type (
168+                     self .aggregates , select_nodes = (NodeType .AGGREGATOR ,)
169+                 )
170+                 if  node .value  ==  "ARRAY_AGG"  and  (node .order  or  node .limit )
171+             ]
172+ 
173+             if  array_agg_nodes :
174+                 # Process all ARRAY_AGG columns that need ordering/limiting 
175+                 arrays_to_update  =  {}
176+                 field_defs  =  {}
177+ 
178+                 for  node  in  array_agg_nodes :
179+                     column_name  =  node .schema_column .identity 
180+ 
181+                     # Store field definition before we drop the column 
182+                     field_defs [column_name ] =  groups .field (column_name )
183+ 
184+                     # Extract and process the data 
185+                     column_data  =  groups .column (column_name ).to_pylist ()
186+ 
187+                     # Apply ordering if specified 
111188                    if  node .order :
112-                         column  =  [sorted (c , reverse = bool (node .order [0 ][1 ])) for  c  in  column ]
189+                         column_data  =  [
190+                             sorted (c , reverse = bool (node .order [0 ][1 ])) for  c  in  column_data 
191+                         ]
192+ 
193+                     # Apply limit if specified 
113194                    if  node .limit :
114-                         column  =  [c [: node .limit ] for  c  in  column ]
115-                     # put the new column into the table 
116-                     groups  =  groups .append_column (column_def , [column ])
195+                         column_data  =  [c [: node .limit ] for  c  in  column_data ]
196+ 
197+                     arrays_to_update [column_name ] =  column_data 
198+ 
199+                 # Drop all columns we're updating 
200+                 columns_to_drop  =  list (arrays_to_update .keys ())
201+                 groups  =  groups .drop (columns_to_drop )
117202
118-             # project to the desired column names from the pyarrow names 
119-             groups   =   groups . select ( list ( self . column_map . values ())  +   self . group_by_columns ) 
120-             groups  =  groups .rename_columns ( list ( self . column_map . keys ())  +   self . group_by_columns )
203+                  # Append all updated columns back 
204+                  for   column_name ,  column_data   in   arrays_to_update . items (): 
205+                      groups  =  groups .append_column ( field_defs [ column_name ], [ column_data ] )
121206
122207            num_rows  =  groups .num_rows 
123208            for  start  in  range (0 , num_rows , CHUNK_SIZE ):
@@ -128,9 +213,10 @@ def execute(self, morsel: pyarrow.Table, **kwargs):
128213
129214        morsel  =  project (morsel , self .all_identifiers )
130215        # Add a "*" column, this is an int because when a bool it miscounts 
216+         # FIX: Use int8 as the comment states (bool can miscount) 
131217        if  "*"  not  in morsel .column_names :
132218            morsel  =  morsel .append_column (
133-                 "*" , [numpy .ones (shape = morsel .num_rows , dtype = numpy .bool_ )]
219+                 "*" , [numpy .ones (shape = morsel .num_rows , dtype = numpy .int8 )]
134220            )
135221        if  self .evaluatable_nodes :
136222            morsel  =  evaluate_and_append (self .evaluatable_nodes , morsel )
@@ -144,9 +230,11 @@ def execute(self, morsel: pyarrow.Table, **kwargs):
144230                self .buffer ,
145231                promote_options = "permissive" ,
146232            )
233+             # Only combine chunks once before aggregation 
147234            table  =  table .combine_chunks ()
148235            groups  =  table .group_by (self .group_by_columns )
149236            groups  =  groups .aggregate (self .aggregate_functions )
150237            self .buffer  =  [groups ]  # Replace buffer with partial result 
238+             self ._partial_aggregated  =  True   # Mark that we've done a partial aggregation 
151239
152240        yield  None 
0 commit comments