@@ -609,14 +609,11 @@ impl Stream for GroupedHashAggregateStream {
609609 match & self . exec_state {
610610 ExecutionState :: ReadingInput => ' reading_input: {
611611 match ready ! ( self . input. poll_next_unpin( cx) ) {
612- // new batch to aggregate
613- Some ( Ok ( batch) ) => {
612+ // New batch to aggregate in partial aggregation operator
613+ Some ( Ok ( batch) ) if self . mode == AggregateMode :: Partial => {
614614 let timer = elapsed_compute. timer ( ) ;
615615 let input_rows = batch. num_rows ( ) ;
616616
617- // Make sure we have enough capacity for `batch`, otherwise spill
618- extract_ok ! ( self . spill_previous_if_necessary( & batch) ) ;
619-
620617 // Do the grouping
621618 extract_ok ! ( self . group_aggregate_batch( batch) ) ;
622619
@@ -649,10 +646,49 @@ impl Stream for GroupedHashAggregateStream {
649646
650647 timer. done ( ) ;
651648 }
649+
650+ // New batch to aggregate in terminal aggregation operator
651+ // (Final/FinalPartitioned/Single/SinglePartitioned)
652+ Some ( Ok ( batch) ) => {
653+ let timer = elapsed_compute. timer ( ) ;
654+
655+ // Make sure we have enough capacity for `batch`, otherwise spill
656+ extract_ok ! ( self . spill_previous_if_necessary( & batch) ) ;
657+
658+ // Do the grouping
659+ extract_ok ! ( self . group_aggregate_batch( batch) ) ;
660+
661+ // If we can begin emitting rows, do so,
662+ // otherwise keep consuming input
663+ assert ! ( !self . input_done) ;
664+
665+ // If the number of group values equals or exceeds the soft limit,
666+ // emit all groups and switch to producing output
667+ if self . hit_soft_group_limit ( ) {
668+ timer. done ( ) ;
669+ extract_ok ! ( self . set_input_done_and_produce_output( ) ) ;
670+ // make sure the exec_state just set is not overwritten below
671+ break ' reading_input;
672+ }
673+
674+ if let Some ( to_emit) = self . group_ordering . emit_to ( ) {
675+ let batch = extract_ok ! ( self . emit( to_emit, false ) ) ;
676+ self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
677+ timer. done ( ) ;
678+ // make sure the exec_state just set is not overwritten below
679+ break ' reading_input;
680+ }
681+
682+ timer. done ( ) ;
683+ }
684+
685+ // Input stream returned an error
652686 Some ( Err ( e) ) => {
653687 // inner had error, return to caller
654688 return Poll :: Ready ( Some ( Err ( e) ) ) ;
655689 }
690+
691+ // Input stream is exhausted
656692 None => {
657693 // inner is done, emit all rows and switch to producing output
658694 extract_ok ! ( self . set_input_done_and_produce_output( ) ) ;
@@ -691,7 +727,12 @@ impl Stream for GroupedHashAggregateStream {
691727 (
692728 if self . input_done {
693729 ExecutionState :: Done
694- } else if self . should_skip_aggregation ( ) {
730+ }
731+ // In Partial aggregation, we also need to check
732+ // if we should trigger partial skipping
733+ else if self . mode == AggregateMode :: Partial
734+ && self . should_skip_aggregation ( )
735+ {
695736 ExecutionState :: SkippingAggregation
696737 } else {
697738 ExecutionState :: ReadingInput
@@ -879,10 +920,10 @@ impl GroupedHashAggregateStream {
879920 if self . group_values . len ( ) > 0
880921 && batch. num_rows ( ) > 0
881922 && matches ! ( self . group_ordering, GroupOrdering :: None )
882- && !matches ! ( self . mode, AggregateMode :: Partial )
883923 && !self . spill_state . is_stream_merging
884924 && self . update_memory_reservation ( ) . is_err ( )
885925 {
926+ assert_ne ! ( self . mode, AggregateMode :: Partial ) ;
886927 // Use input batch (Partial mode) schema for spilling because
887928 // the spilled data will be merged and re-evaluated later.
888929 self . spill_state . spill_schema = batch. schema ( ) ;
@@ -927,9 +968,9 @@ impl GroupedHashAggregateStream {
927968 fn emit_early_if_necessary ( & mut self ) -> Result < ( ) > {
928969 if self . group_values . len ( ) >= self . batch_size
929970 && matches ! ( self . group_ordering, GroupOrdering :: None )
930- && matches ! ( self . mode, AggregateMode :: Partial )
931971 && self . update_memory_reservation ( ) . is_err ( )
932972 {
973+ assert_eq ! ( self . mode, AggregateMode :: Partial ) ;
933974 let n = self . group_values . len ( ) / self . batch_size * self . batch_size ;
934975 let batch = self . emit ( EmitTo :: First ( n) , false ) ?;
935976 self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
@@ -1002,6 +1043,8 @@ impl GroupedHashAggregateStream {
10021043 }
10031044
10041045 /// Updates skip aggregation probe state.
1046+ ///
1047+ /// Note: this should only be called in Partial aggregation mode.
10051048 fn update_skip_aggregation_probe ( & mut self , input_rows : usize ) {
10061049 if let Some ( probe) = self . skip_aggregation_probe . as_mut ( ) {
10071050 // Skip aggregation probe is not supported if stream has any spills,
@@ -1013,6 +1056,8 @@ impl GroupedHashAggregateStream {
10131056
10141057 /// In case the probe indicates that aggregation may be
10151058 /// skipped, forces stream to produce currently accumulated output.
1059+ ///
1060+ /// Note: this should only be called in Partial aggregation mode.
10161061 fn switch_to_skip_aggregation ( & mut self ) -> Result < ( ) > {
10171062 if let Some ( probe) = self . skip_aggregation_probe . as_mut ( ) {
10181063 if probe. should_skip ( ) {
@@ -1026,6 +1071,8 @@ impl GroupedHashAggregateStream {
10261071
10271072 /// Returns true if the aggregation probe indicates that aggregation
10281073 /// should be skipped.
1074+ ///
1075+ /// Note: this should only be called in Partial aggregation mode.
10291076 fn should_skip_aggregation ( & self ) -> bool {
10301077 self . skip_aggregation_probe
10311078 . as_ref ( )
0 commit comments