@@ -7,6 +7,7 @@ use crate::metastore::multi_index::MultiPartition;
 use crate::metastore::table::Table;
 use crate::metastore::{Column, ColumnType, IdRow, Index, Partition};
 use crate::queryplanner::filter_by_key_range::FilterByKeyRangeExec;
+use crate::queryplanner::inline_aggregate::sorted_group_values::SortedGroupValues;
 use crate::queryplanner::merge_sort::LastRowByUniqueKeyExec;
 use crate::queryplanner::metadata_cache::{MetadataCacheFactory, NoopParquetMetadataCache};
 use crate::queryplanner::optimizations::{CubeQueryPlanner, PreOptimizeRule};
@@ -50,8 +51,8 @@ use datafusion::dfschema::internal_err;
 use datafusion::dfschema::not_impl_err;
 use datafusion::error::DataFusionError;
 use datafusion::error::Result as DFResult;
-use datafusion::execution::TaskContext;
-use datafusion::logical_expr::{Expr, GroupsAccumulator, LogicalPlan};
+use datafusion::execution::{RecordBatchStream, TaskContext};
+use datafusion::logical_expr::{EmitTo, Expr, GroupsAccumulator, LogicalPlan};
 use datafusion::physical_expr::expressions::Column as DFColumn;
 use datafusion::physical_expr::LexOrdering;
 use datafusion::physical_expr::{self, GroupsAccumulatorAdapter};
@@ -135,6 +136,7 @@ pub(crate) struct InlineAggregateStream {
     input_done: bool,

     accumulators: Vec<Box<dyn GroupsAccumulator>>,
+    group_values: SortedGroupValues,
     current_group_indices: Vec<usize>,
 }

@@ -189,6 +191,7 @@ impl InlineAggregateStream {

         let exec_state = ExecutionState::ReadingInput;
         let current_group_indices = Vec::with_capacity(batch_size);
+        let group_values = SortedGroupValues::try_new(group_schema)?;

         Ok(InlineAggregateStream {
             schema: agg_schema,
@@ -201,6 +204,7 @@ impl InlineAggregateStream {
             exec_state,
             batch_size,
             current_group_indices,
+            group_values,
             input_done: false,
         })
     }
@@ -303,144 +307,68 @@ impl Stream for InlineAggregateStream {
     ) -> Poll<Option<Self::Item>> {
         loop {
             match &self.exec_state {
-                ExecutionState::ReadingInput => 'reading_input: {
+                ExecutionState::ReadingInput => {
                     match ready!(self.input.poll_next_unpin(cx)) {
-                        // New batch to aggregate in partial aggregation operator
-                        Some(Ok(batch)) if self.mode == InlineAggregateMode::Partial => {
-                            /* let timer = elapsed_compute.timer();
-                            let input_rows = batch.num_rows();
-
-                            // Do the grouping
-                            self.group_aggregate_batch(batch)?;
-
-                            self.update_skip_aggregation_probe(input_rows);
-
-                            // If we can begin emitting rows, do so,
-                            // otherwise keep consuming input
-                            assert!(!self.input_done);
-
-                            // If the number of group values equals or exceeds the soft limit,
-                            // emit all groups and switch to producing output
-                            if self.hit_soft_group_limit() {
-                                timer.done();
-                                self.set_input_done_and_produce_output()?;
-                                // make sure the exec_state just set is not overwritten below
-                                break 'reading_input;
-                            }
-
-                            if let Some(to_emit) = self.group_ordering.emit_to() {
-                                timer.done();
-                                if let Some(batch) = self.emit(to_emit, false)? {
-
-                                    ExecutionState::ProducingOutput(batch);
-                                };
-                                // make sure the exec_state just set is not overwritten below
-                                break 'reading_input;
-                            }
-
-                            self.emit_early_if_necessary()?;
-
-                            self.switch_to_skip_aggregation()?;
-
-                            timer.done(); */
-                            todo!()
-                        }
-
-                        // New batch to aggregate in terminal aggregation operator
-                        // (Final/FinalPartitioned/Single/SinglePartitioned)
+                        // New input batch to aggregate
                         Some(Ok(batch)) => {
-                            /* let timer = elapsed_compute.timer();
-
-                            // Make sure we have enough capacity for `batch`, otherwise spill
-                            self.spill_previous_if_necessary(&batch)?;
-
-                            // Do the grouping
-
-
-                            // If we can begin emitting rows, do so,
-                            // otherwise keep consuming input
-                            assert!(!self.input_done);
-
-                            // If the number of group values equals or exceeds the soft limit,
-                            // emit all groups and switch to producing output
-                            if self.hit_soft_group_limit() {
-                                timer.done();
-                                self.set_input_done_and_produce_output()?;
-                                // make sure the exec_state just set is not overwritten below
-                                break 'reading_input;
+                            // Aggregate the batch
+                            if let Err(e) = self.group_aggregate_batch(batch) {
+                                return Poll::Ready(Some(Err(e)));
                             }

-                            if let Some(to_emit) = self.group_ordering.emit_to() {
-                                timer.done();
-                                if let Some(batch) = self.emit(to_emit, false)? {
-                                    self.exec_state =
-                                        ExecutionState::ProducingOutput(batch);
-                                };
-                                // make sure the exec_state just set is not overwritten below
-                                break 'reading_input;
+                            // Try to emit a batch if we have enough groups
+                            match self.emit_early_if_ready() {
+                                Ok(Some(batch)) => {
+                                    self.exec_state = ExecutionState::ProducingOutput(batch);
+                                }
+                                Ok(None) => {
+                                    // Not enough groups yet, continue reading
+                                }
+                                Err(e) => {
+                                    return Poll::Ready(Some(Err(e)));
+                                }
                             }
-
-                            timer.done(); */
-                            todo!()
                         }

-                        // Found error from input stream
+                        // Error from input stream
                         Some(Err(e)) => {
-                            // inner had error, return to caller
                             return Poll::Ready(Some(Err(e)));
                         }

-                        // Found end from input stream
+                        // Input stream exhausted - emit all remaining groups
                         None => {
-                            // inner is done, emit all rows and switch to producing output
-                            //self.set_input_done_and_produce_output()?;
-                            todo!()
+                            self.input_done = true;
+
+                            match self.emit(EmitTo::All) {
+                                Ok(Some(batch)) => {
+                                    self.exec_state = ExecutionState::ProducingOutput(batch);
+                                }
+                                Ok(None) => {
+                                    // No groups to emit, we're done
+                                    self.exec_state = ExecutionState::Done;
+                                }
+                                Err(e) => {
+                                    return Poll::Ready(Some(Err(e)));
+                                }
+                            }
                         }
                     }
                 }

                 ExecutionState::ProducingOutput(batch) => {
-                    // slice off a part of the batch, if needed
-                    /* let output_batch;
-                    let size = self.batch_size;
-                    (self.exec_state, output_batch) = if batch.num_rows() <= size {
-                        (
-                            if self.input_done {
-                                ExecutionState::Done
-                            }
-                            // In Partial aggregation, we also need to check
-                            // if we should trigger partial skipping
-                            else if self.mode == AggregateMode::Partial
-                                && self.should_skip_aggregation()
-                            {
-                                ExecutionState::SkippingAggregation
-                            } else {
-                                ExecutionState::ReadingInput
-                            },
-                            batch.clone(),
-                        )
+                    let batch = batch.clone();
+
+                    // Determine next state
+                    self.exec_state = if self.input_done {
+                        ExecutionState::Done
                     } else {
-                        // output first batch_size rows
-                        let size = self.batch_size;
-                        let num_remaining = batch.num_rows() - size;
-                        let remaining = batch.slice(size, num_remaining);
-                        let output = batch.slice(0, size);
-                        (ExecutionState::ProducingOutput(remaining), output)
+                        ExecutionState::ReadingInput
                     };
-                    // Empty record batches should not be emitted.
-                    // They need to be treated as [`Option<RecordBatch>`]es and handled separately
-                    debug_assert!(output_batch.num_rows() > 0);
-                    return Poll::Ready(Some(Ok(
-                        output_batch.record_output(&self.baseline_metrics)
-                    ))); */
-                    todo!()
+
+                    return Poll::Ready(Some(Ok(batch)));
                 }

                 ExecutionState::Done => {
-                    // release the memory reservation since sending back output batch itself needs
-                    // some memory reservation, so make some room for it.
-                    /* self.clear_all();
-                    let _ = self.update_memory_reservation(); */
                     return Poll::Ready(None);
                 }
             }
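The rewritten `poll_next` above is a three-state machine: read and aggregate input, hand a finished batch to the caller, then either keep reading or stop once the input is exhausted. The `ExecutionState` enum itself is declared earlier in the file and is not part of this diff; judging from its use here, it is presumably along these lines (a sketch, not the actual definition):

```rust
use datafusion::arrow::record_batch::RecordBatch;

// Hypothetical reconstruction based on how the states are used above;
// the real enum lives elsewhere in inline_aggregate.rs.
enum ExecutionState {
    /// Pulling batches from the input and interning them into sorted groups.
    ReadingInput,
    /// A batch of completed groups is ready to be returned to the caller.
    ProducingOutput(RecordBatch),
    /// Input is exhausted and all groups have been emitted.
    Done,
}
```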
@@ -449,58 +377,109 @@ impl Stream for InlineAggregateStream {
 }

 impl InlineAggregateStream {
+    /// Emit groups based on the EmitTo strategy.
+    ///
+    /// Returns None if there are no groups to emit.
+    fn emit(&mut self, emit_to: EmitTo) -> DFResult<Option<RecordBatch>> {
+        if self.group_values.is_empty() {
+            return Ok(None);
+        }
+
+        // Get group values arrays
+        let group_arrays = self.group_values.emit(emit_to)?;
+
+        // Get aggregate arrays based on mode
+        let mut aggr_arrays = vec![];
+        for acc in &mut self.accumulators {
+            match self.mode {
+                InlineAggregateMode::Partial => {
+                    // Emit intermediate state
+                    let state = acc.state(emit_to)?;
+                    aggr_arrays.extend(state);
+                }
+                InlineAggregateMode::Final => {
+                    // Emit final aggregated values
+                    aggr_arrays.push(acc.evaluate(emit_to)?);
+                }
+            }
+        }
+
+        // Combine group columns and aggregate columns
+        let mut columns = group_arrays;
+        columns.extend(aggr_arrays);
+
+        let batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
+
+        Ok(Some(batch))
+    }
+
+    /// Check if we have enough groups to emit a batch, keeping the last (potentially incomplete) group.
+    ///
+    /// For sorted aggregation, we emit batches of size batch_size when we have accumulated
+    /// more than batch_size groups. We always keep the last group as it may continue in the next input batch.
+    fn should_emit_early(&self) -> bool {
+        // Need at least (batch_size + 1) groups to emit batch_size and keep 1
+        self.group_values.len() > self.batch_size
+    }
+
+    /// Emit a batch of groups if we have enough accumulated, keeping the last group.
+    ///
+    /// Returns Some(batch) if emitted, None otherwise.
+    fn emit_early_if_ready(&mut self) -> DFResult<Option<RecordBatch>> {
+        if !self.should_emit_early() {
+            return Ok(None);
+        }
+
+        // Emit exactly batch_size groups, keeping the rest (including the last, incomplete group)
+        self.emit(EmitTo::First(self.batch_size))
+    }
+
     fn group_aggregate_batch(&mut self, batch: RecordBatch) -> DFResult<()> {
         // Evaluate the grouping expressions
-        /* let group_by_values = evaluate_group_by(&self.group_by, &batch)?;
+        let group_by_values = evaluate_group_by(&self.group_by, &batch)?;

         // Evaluate the aggregation expressions.
         let input_values = evaluate_many(&self.aggregate_arguments, &batch)?;

         // Evaluate the filter expressions, if any, against the inputs
         let filter_values = evaluate_optional(&self.filter_expressions, &batch)?;

-        for group_values in &group_by_values {
-            // calculate the group indices for each input row
-            let starting_num_groups = self.group_values.len();
-            self.group_values
-                .intern(group_values, &mut self.current_group_indices)?;
-            let group_indices = &self.current_group_indices;
-
-            // Update ordering information if necessary
-            /* let total_num_groups = self.group_values.len();
-            if total_num_groups > starting_num_groups {
-                self.group_ordering
-                    .new_groups(group_values, group_indices, total_num_groups)?;
-            } */
-
-            // Gather the inputs to call the actual accumulator
-            let t = self
-                .accumulators
-                .iter_mut()
-                .zip(input_values.iter())
-                .zip(filter_values.iter());
-
-            for ((acc, values), opt_filter) in t {
-                let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean());
-
-                // Call the appropriate method on each aggregator with
-                // the entire input row and the relevant group indexes
-                match self.mode {
-                    InlineAggregateMode::Partial => {
-                        acc.update_batch(values, group_indices, opt_filter, total_num_groups)?;
+        assert_eq!(group_by_values.len(), 1, "Exactly 1 group value required");
+        self.group_values
+            .intern(&group_by_values[0], &mut self.current_group_indices)?;
+        let group_indices = &self.current_group_indices;
+
+        let total_num_groups = self.group_values.len();
+        // Gather the inputs to call the actual accumulator
+        let t = self
+            .accumulators
+            .iter_mut()
+            .zip(input_values.iter())
+            .zip(filter_values.iter());
+
+        for ((acc, values), opt_filter) in t {
+            let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean());
+
+            // Call the appropriate method on each aggregator with
+            // the entire input row and the relevant group indexes
+            match self.mode {
+                InlineAggregateMode::Partial => {
+                    acc.update_batch(values, group_indices, opt_filter, total_num_groups)?;
+                }
+                _ => {
+                    if opt_filter.is_some() {
+                        return internal_err!("aggregate filter should be applied in partial stage, there should be no filter in final stage");
                     }
-                    _ => {
-                        if opt_filter.is_some() {
-                            return internal_err!("aggregate filter should be applied in partial stage, there should be no filter in final stage");
-                        }

-                        // if aggregation is over intermediate states,
-                        // use merge
-                        acc.merge_batch(values, group_indices, None, total_num_groups)?;
-                    }
+                    // if aggregation is over intermediate states,
+                    // use merge
+                    acc.merge_batch(values, group_indices, None, total_num_groups)?;
                 }
             }
-        } */
+        }
         Ok(())
     }
 }
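The aggregation path above leans on `SortedGroupValues::intern`, whose implementation lives in `sorted_group_values.rs` and is not shown in this diff. Because the input is sorted on the group-by columns, interning only needs to compare each row with the previous one, and every group except the last is known to be complete once a batch has been processed, which is what makes `EmitTo::First(batch_size)` safe in `emit_early_if_ready`. A simplified, single-column sketch of that contract (illustrative only, not the actual implementation):

```rust
/// Toy interning for a single, already-sorted i32 key column.
/// Illustrates the contract only; the real SortedGroupValues works on Arrow arrays.
fn intern_sorted(keys: &[i32], groups: &mut Vec<i32>, group_indices: &mut Vec<usize>) {
    group_indices.clear();
    for &key in keys {
        match groups.last() {
            // Same key as the previous row: the row belongs to the current (last) group.
            Some(&last) if last == key => group_indices.push(groups.len() - 1),
            // New key: open a new group. Sorted input guarantees this key never
            // reappears later, so only the last group can still grow.
            _ => {
                groups.push(key);
                group_indices.push(groups.len() - 1);
            }
        }
    }
}

// Keys [1, 1, 2, 2, 3] yield group indices [0, 0, 1, 1, 2]; only the last group ("3")
// may continue in the next input batch, which is why emit_early_if_ready keeps it back.
```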
@@ -609,3 +588,9 @@ fn evaluate_group_by(
         })
         .collect()
 }
+
+impl RecordBatchStream for InlineAggregateStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
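With `Stream` and `RecordBatchStream` both implemented, the operator can be driven like any other DataFusion stream. A minimal consumption sketch, assuming the stream has been boxed into a `SendableRecordBatchStream` (the helper below is illustrative and not part of this change):

```rust
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::physical_plan::SendableRecordBatchStream;
use futures::TryStreamExt;

/// Drain an aggregation stream into a Vec of output batches.
async fn collect_aggregated(stream: SendableRecordBatchStream) -> Result<Vec<RecordBatch>> {
    // Each item is one batch of completed, sorted groups; the EmitTo::All flush
    // after input exhaustion produces the final batch.
    stream.try_collect::<Vec<_>>().await
}
```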