1+ use datafusion:: logical_expr:: EmitTo ;
12use datafusion:: physical_plan:: aggregates:: group_values:: multi_group_by:: GroupColumn ;
23
34use std:: mem:: { self , size_of} ;
@@ -6,15 +7,15 @@ use datafusion::arrow::array::{Array, ArrayRef, RecordBatch};
67use datafusion:: arrow:: compute:: cast;
78use datafusion:: arrow:: datatypes:: {
89 BinaryType , BinaryViewType , DataType , Date32Type , Date64Type , Decimal128Type , Float32Type ,
9- Float64Type , Int16Type , Int32Type , Int64Type , Int8Type , LargeBinaryType , LargeUtf8Type ,
10- Schema , SchemaRef , StringViewType , Time32MillisecondType , Time32SecondType ,
11- Time64MicrosecondType , Time64NanosecondType , TimeUnit , TimestampMicrosecondType ,
12- TimestampMillisecondType , TimestampNanosecondType , TimestampSecondType , UInt16Type ,
13- UInt32Type , UInt64Type , UInt8Type , Utf8Type ,
10+ Float64Type , Int16Type , Int32Type , Int64Type , Int8Type , LargeBinaryType , LargeUtf8Type , Schema ,
11+ SchemaRef , StringViewType , Time32MillisecondType , Time32SecondType , Time64MicrosecondType ,
12+ Time64NanosecondType , TimeUnit , TimestampMicrosecondType , TimestampMillisecondType ,
13+ TimestampNanosecondType , TimestampSecondType , UInt16Type , UInt32Type , UInt64Type , UInt8Type ,
14+ Utf8Type ,
1415} ;
1516use datafusion:: dfschema:: internal_err;
1617use datafusion:: dfschema:: not_impl_err;
17- use datafusion:: error:: Result as DFResult ;
18+ use datafusion:: error:: { DataFusionError , Result as DFResult } ;
1819use datafusion:: physical_expr:: binary_map:: OutputType ;
1920use datafusion:: physical_plan:: aggregates:: group_values:: multi_group_by:: {
2021 ByteGroupValueBuilder , ByteViewGroupValueBuilder , PrimitiveGroupValueBuilder ,
@@ -129,41 +130,83 @@ impl SortedGroupValues {
129130 & DataType :: Time32 ( t) => match t {
130131 TimeUnit :: Second => {
131132 instantiate_primitive ! ( v, nullable, Time32SecondType , data_type) ;
132- instantiate_primitive_comparator ! ( comparators, nullable, Time32SecondType ) ;
133+ instantiate_primitive_comparator ! (
134+ comparators,
135+ nullable,
136+ Time32SecondType
137+ ) ;
133138 }
134139 TimeUnit :: Millisecond => {
135140 instantiate_primitive ! ( v, nullable, Time32MillisecondType , data_type) ;
136- instantiate_primitive_comparator ! ( comparators, nullable, Time32MillisecondType ) ;
141+ instantiate_primitive_comparator ! (
142+ comparators,
143+ nullable,
144+ Time32MillisecondType
145+ ) ;
137146 }
138147 _ => { }
139148 } ,
140149 & DataType :: Time64 ( t) => match t {
141150 TimeUnit :: Microsecond => {
142151 instantiate_primitive ! ( v, nullable, Time64MicrosecondType , data_type) ;
143- instantiate_primitive_comparator ! ( comparators, nullable, Time64MicrosecondType ) ;
152+ instantiate_primitive_comparator ! (
153+ comparators,
154+ nullable,
155+ Time64MicrosecondType
156+ ) ;
144157 }
145158 TimeUnit :: Nanosecond => {
146159 instantiate_primitive ! ( v, nullable, Time64NanosecondType , data_type) ;
147- instantiate_primitive_comparator ! ( comparators, nullable, Time64NanosecondType ) ;
160+ instantiate_primitive_comparator ! (
161+ comparators,
162+ nullable,
163+ Time64NanosecondType
164+ ) ;
148165 }
149166 _ => { }
150167 } ,
151168 & DataType :: Timestamp ( t, _) => match t {
152169 TimeUnit :: Second => {
153170 instantiate_primitive ! ( v, nullable, TimestampSecondType , data_type) ;
154- instantiate_primitive_comparator ! ( comparators, nullable, TimestampSecondType ) ;
171+ instantiate_primitive_comparator ! (
172+ comparators,
173+ nullable,
174+ TimestampSecondType
175+ ) ;
155176 }
156177 TimeUnit :: Millisecond => {
157- instantiate_primitive ! ( v, nullable, TimestampMillisecondType , data_type) ;
158- instantiate_primitive_comparator ! ( comparators, nullable, TimestampMillisecondType ) ;
178+ instantiate_primitive ! (
179+ v,
180+ nullable,
181+ TimestampMillisecondType ,
182+ data_type
183+ ) ;
184+ instantiate_primitive_comparator ! (
185+ comparators,
186+ nullable,
187+ TimestampMillisecondType
188+ ) ;
159189 }
160190 TimeUnit :: Microsecond => {
161- instantiate_primitive ! ( v, nullable, TimestampMicrosecondType , data_type) ;
162- instantiate_primitive_comparator ! ( comparators, nullable, TimestampMicrosecondType ) ;
191+ instantiate_primitive ! (
192+ v,
193+ nullable,
194+ TimestampMicrosecondType ,
195+ data_type
196+ ) ;
197+ instantiate_primitive_comparator ! (
198+ comparators,
199+ nullable,
200+ TimestampMicrosecondType
201+ ) ;
163202 }
164203 TimeUnit :: Nanosecond => {
165204 instantiate_primitive ! ( v, nullable, TimestampNanosecondType , data_type) ;
166- instantiate_primitive_comparator ! ( comparators, nullable, TimestampNanosecondType ) ;
205+ instantiate_primitive_comparator ! (
206+ comparators,
207+ nullable,
208+ TimestampNanosecondType
209+ ) ;
167210 }
168211 } ,
169212 & DataType :: Decimal128 ( _, _) => {
@@ -231,8 +274,8 @@ impl SortedGroupValues {
231274 self . group_values [ 0 ] . len ( )
232275 }
233276
234- pub fn emit ( & mut self ) -> DFResult < Vec < ArrayRef > > {
235- /* let mut output = match emit_to {
277+ fn emit ( & mut self , emit_to : EmitTo ) -> DFResult < Vec < ArrayRef > > {
278+ let mut output = match emit_to {
236279 EmitTo :: All => {
237280 let group_values = mem:: take ( & mut self . group_values ) ;
238281 debug_assert ! ( self . group_values. is_empty( ) ) ;
@@ -253,7 +296,6 @@ impl SortedGroupValues {
253296 }
254297 } ;
255298
256- // TODO: Materialize dictionaries in group keys (#7647)
257299 for ( field, array) in self . schema . fields . iter ( ) . zip ( & mut output) {
258300 let expected = field. data_type ( ) ;
259301 if let DataType :: Dictionary ( _, v) = expected {
@@ -267,78 +309,82 @@ impl SortedGroupValues {
267309 }
268310 }
269311
270- Ok(output) */
271- todo ! ( )
312+ Ok ( output)
272313 }
273314
274315 fn clear_shrink ( & mut self , batch : & RecordBatch ) {
275316 self . group_values . clear ( ) ;
317+ self . comparators . clear ( ) ;
276318 self . rows_inds . clear ( ) ;
277319 self . equal_to_results . clear ( ) ;
278320 }
279321
280322 fn intern_impl ( & mut self , cols : & [ ArrayRef ] , groups : & mut Vec < usize > ) -> DFResult < ( ) > {
281- /* let n_rows = cols[0].len();
323+ let n_rows = cols[ 0 ] . len ( ) ;
282324 groups. clear ( ) ;
283325
284326 if n_rows == 0 {
285327 return Ok ( ( ) ) ;
286328 }
287329
330+ // Handle first row - compare with last group or create new group
288331 let first_group_idx = self . make_new_group_if_needed ( cols, 0 ) ;
289332 groups. push ( first_group_idx) ;
290333
291334 if n_rows == 1 {
292335 return Ok ( ( ) ) ;
293336 }
294337
295- if self.rows_inds.len() < n_rows {
296- let old_len = self.rows_inds.len();
297- self.rows_inds.extend(old_len..n_rows);
298- }
299-
300- self.equal_to_results.fill(true);
338+ // Prepare buffer for vectorized comparison
301339 self . equal_to_results . resize ( n_rows - 1 , true ) ;
340+ self . equal_to_results [ ..n_rows - 1 ] . fill ( true ) ;
302341
303- let lhs_rows = &self.rows_inds[0..n_rows - 1];
304- let rhs_rows = &self.rows_inds[1..n_rows];
305- for (col_idx, group_col) in self.group_values.iter().enumerate() {
306- cols[col_idx].vectorized_equal_to(
307- lhs_rows,
308- &cols[col_idx],
309- rhs_rows,
310- &mut self.equal_to_results,
311- );
342+ // Vectorized comparison: compare row[i] with row[i+1] for all columns
343+ for ( col, comparator) in cols. iter ( ) . zip ( & self . comparators ) {
344+ comparator. compare_adjacent ( col, & mut self . equal_to_results [ ..n_rows - 1 ] ) ;
312345 }
313- println!("!!!!! AAAAAAAAAA");
346+
347+ // Build groups based on comparison results
314348 let mut current_group_idx = first_group_idx;
315349 for i in 0 ..n_rows - 1 {
316350 if !self . equal_to_results [ i] {
351+ // Group boundary detected - add new group
317352 for ( col_idx, group_value) in self . group_values . iter_mut ( ) . enumerate ( ) {
318353 group_value. append_val ( & cols[ col_idx] , i + 1 ) ;
319354 }
320355 current_group_idx = self . group_values [ 0 ] . len ( ) - 1 ;
321356 }
322357 groups. push ( current_group_idx) ;
323358 }
324- println!("!!!!! BBBBBBB");
325- Ok(()) */
359+
326360 Ok ( ( ) )
327361 }
328362
363+ /// Compare the specified row with the last group and create a new group if different.
364+ ///
365+ /// This is used to handle the first row of a batch, which needs to be compared
366+ /// with the last group from the previous batch to detect group boundaries across batches.
367+ ///
368+ /// Returns the group index for this row.
329369 fn make_new_group_if_needed ( & mut self , cols : & [ ArrayRef ] , row : usize ) -> usize {
330370 let new_group_needed = if self . group_values [ 0 ] . len ( ) == 0 {
371+ // No groups yet - always create first group
331372 true
332373 } else {
374+ // Compare with last group - if any column differs, need new group
333375 self . group_values . iter ( ) . enumerate ( ) . any ( |( i, group_val) | {
334376 !group_val. equal_to ( self . group_values [ 0 ] . len ( ) - 1 , & cols[ i] , row)
335377 } )
336378 } ;
379+
337380 if new_group_needed {
381+ // Add new group with values from this row
338382 for ( i, group_value) in self . group_values . iter_mut ( ) . enumerate ( ) {
339383 group_value. append_val ( & cols[ i] , row) ;
340384 }
341385 }
386+
387+ // Return index of the group (either newly created or existing last group)
342388 self . group_values [ 0 ] . len ( ) - 1
343389 }
344390}
0 commit comments