From 337418a1048c91459f8cfd605206b593a1818988 Mon Sep 17 00:00:00 2001
From: Nachiket Roy
Date: Sun, 11 Jan 2026 15:02:24 +0000
Subject: [PATCH] fix AggregateExec internal spilling

Rework how GroupedHashAggregateStream drains its hash table once input
is exhausted:

- Add an `EmittingRemaining` execution state that emits leftover groups
  in `batch_size` chunks via `EmitTo::First`, instead of materializing
  them all at once with `EmitTo::All`.
- Make `spill()` a no-op when the hash table is empty, and clear the
  group values after a spill file is successfully written (replacing
  the `clear_shrink` call in the out-of-memory path).
- Call `group_ordering.input_done()` only once the stream-merge input
  is actually exhausted, and in the `Done` state, rather than eagerly
  in `set_input_done_and_produce_output`.
- Free the aggregation memory reservation before handing over to the
  streaming merge, and drop the now-unused `GroupOrdering::Full` reset.
---
 .../physical-plan/src/aggregates/row_hash.rs | 138 ++++++++++++------
 1 file changed, 91 insertions(+), 47 deletions(-)

diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 1ae7202711112..299b231971278 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -24,7 +24,6 @@ use std::vec;
 use super::AggregateExec;
 use super::order::GroupOrdering;
 use crate::aggregates::group_values::{GroupByMetrics, GroupValues, new_group_values};
-use crate::aggregates::order::GroupOrderingFull;
 use crate::aggregates::{
     AggregateMode, PhysicalGroupBy, create_schema, evaluate_group_by, evaluate_many,
     evaluate_optional,
@@ -68,6 +67,8 @@ pub(crate) enum ExecutionState {
     ///
     /// See "partial aggregation" discussion on [`GroupedHashAggregateStream`]
     SkippingAggregation,
+    /// Special state to emit remaining groups when input is done
+    EmittingRemaining,
     /// All input has been consumed and all groups have been emitted
     Done,
 }
@@ -802,6 +803,9 @@ impl Stream for GroupedHashAggregateStream {
 
                     // Found end from input stream
                     None => {
+                        if self.spill_state.is_stream_merging {
+                            self.group_ordering.input_done();
+                        }
                        // inner is done, emit all rows and switch to producing output
                         self.set_input_done_and_produce_output()?;
                     }
@@ -844,31 +848,48 @@
                 // slice off a part of the batch, if needed
                 let output_batch;
                 let size = self.batch_size;
-                (self.exec_state, output_batch) = if batch.num_rows() <= size {
-                    (
-                        if self.input_done {
-                            ExecutionState::Done
-                        }
-                        // In Partial aggregation, we also need to check
-                        // if we should trigger partial skipping
-                        else if self.mode == AggregateMode::Partial
-                            && self.should_skip_aggregation()
-                        {
-                            ExecutionState::SkippingAggregation
-                        } else {
-                            ExecutionState::ReadingInput
-                        },
-                        batch.clone(),
-                    )
-                } else {
-                    // output first batch_size rows
-                    let size = self.batch_size;
+
+                // First handle the simple case where we need to slice
+                if batch.num_rows() > size {
                     let num_remaining = batch.num_rows() - size;
                     let remaining = batch.slice(size, num_remaining);
-                    let output = batch.slice(0, size);
-                    (ExecutionState::ProducingOutput(remaining), output)
+                    output_batch = batch.slice(0, size);
+                    self.exec_state = ExecutionState::ProducingOutput(remaining);
+
+                    if let Some(reduction_factor) = self.reduction_factor.as_ref() {
+                        reduction_factor.add_part(output_batch.num_rows());
+                    }
+
+                    debug_assert!(output_batch.num_rows() > 0);
+                    return Poll::Ready(Some(Ok(
+                        output_batch.record_output(&self.baseline_metrics)
+                    )));
+                }
+
+                // The entire batch fits within batch_size
+                output_batch = batch.clone();
+
+                // Now determine the next state
+                let next_state = if self.input_done {
+                    // We're done with input, check if we have more groups to emit
+                    if !self.group_values.is_empty() {
+                        ExecutionState::EmittingRemaining
+                    } else {
+                        ExecutionState::Done
+                    }
+                }
+                // In Partial aggregation, we also need to check
+                // if we should trigger partial skipping
+                else if self.mode == AggregateMode::Partial
+                    && self.should_skip_aggregation()
+                {
+                    ExecutionState::SkippingAggregation
+                } else {
+                    ExecutionState::ReadingInput
                 };
+                self.exec_state = next_state;
+
 
                 if let Some(reduction_factor) = self.reduction_factor.as_ref() {
                     reduction_factor.add_part(output_batch.num_rows());
                 }
@@ -881,8 +902,31 @@ impl Stream for GroupedHashAggregateStream {
                 )));
             }
 
+            ExecutionState::EmittingRemaining => {
+                let remaining = self.group_values.len();
+
+                if remaining == 0 {
+                    self.exec_state = ExecutionState::Done;
+                    continue;
+                }
+
+                let to_emit = remaining.min(self.batch_size);
+
+                match self.emit(EmitTo::First(to_emit), false)? {
+                    Some(batch) => {
+                        self.exec_state = ExecutionState::ProducingOutput(batch);
+                        continue;
+                    }
+                    None => {
+                        self.exec_state = ExecutionState::Done;
+                        continue;
+                    }
+                }
+            }
+
             ExecutionState::Done => {
-                // Sanity check: all groups should have been emitted by now
+                self.group_ordering.input_done();
+                // all groups should have been emitted by now
                 if !self.group_values.is_empty() {
                     return Poll::Ready(Some(internal_err!(
                         "AggregateStream was in Done state with {} groups left in hash table. \
@@ -1032,7 +1076,7 @@ impl GroupedHashAggregateStream {
         match self.oom_mode {
             OutOfMemoryMode::Spill if !self.group_values.is_empty() => {
                 self.spill()?;
-                self.clear_shrink(self.batch_size);
+                // Force memory update after clearing
                 self.update_memory_reservation()?;
                 Ok(None)
             }
@@ -1120,6 +1164,10 @@ impl GroupedHashAggregateStream {
     /// This process helps in reducing memory pressure by allowing the data to be
     /// read back with streaming merge.
     fn spill(&mut self) -> Result<()> {
+        // Check if there's anything to spill
+        if self.group_values.is_empty() {
+            return Ok(());
+        }
         // Emit and sort intermediate aggregation state
         let Some(emit) = self.emit(EmitTo::All, true)? else {
             return Ok(());
@@ -1140,7 +1188,9 @@
                 self.spill_state.spills.push(SortedSpillFile {
                     file: spillfile,
                     max_record_batch_memory,
-                })
+                });
+                // Clear memory after successful spill
+                self.clear_all();
             }
             None => {
                 return internal_err!(
@@ -1179,26 +1229,26 @@
     /// in case of disk spilling, the SPM stream have been drained.
     fn set_input_done_and_produce_output(&mut self) -> Result<()> {
         self.input_done = true;
-        self.group_ordering.input_done();
         let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
         let timer = elapsed_compute.timer();
         self.exec_state = if self.spill_state.spills.is_empty() {
             // Input has been entirely processed without spilling to disk.
-            // Flush any remaining group values.
-            let batch = self.emit(EmitTo::All, false)?;
-            // If there are none, we're done; otherwise switch to emitting them
-            batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput)
-        } else {
-            // Spill any remaining data to disk. There is some performance overhead in
-            // writing out this last chunk of data and reading it back. The benefit of
-            // doing this is that memory usage for this stream is reduced, and the more
-            // sophisticated memory handling in `MultiLevelMergeBuilder` can take over
-            // instead.
-            // Spilling to disk and reading back also ensures batch size is consistent
-            // rather than potentially having one significantly larger last batch.
-            self.spill()?; // TODO: use sort_batch_chunked instead?
+            // Emit the first batch of groups (up to batch_size)
+            if !self.group_values.is_empty() {
+                // Start emitting remaining groups
+                ExecutionState::EmittingRemaining
+            } else {
+                ExecutionState::Done
+            }
+        } else {
+            // Spill any remaining data to disk.
+            self.spill()?; // This should clear memory
+            // Reset the group values collectors.
+            self.clear_all();
+            // IMPORTANT: release hash-aggregation memory before streaming merge
+            self.reservation.free();
 
             // Mark that we're switching to stream merging mode.
             self.spill_state.is_stream_merging = true;
 
@@ -1210,22 +1260,15 @@ impl GroupedHashAggregateStream {
                 .with_batch_size(self.batch_size)
                 .with_reservation(self.reservation.new_empty())
                 .build()?;
-            self.input_done = false;
-
-            // Reset the group values collectors.
-            self.clear_all();
-
-            // We can now use `GroupOrdering::Full` since the spill files are sorted
-            // on the grouping columns.
-            self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
+            self.input_done = false;
 
             // Use `OutOfMemoryMode::ReportError` from this point on
             // to ensure we don't spill the spilled data to disk again.
             self.oom_mode = OutOfMemoryMode::ReportError;
 
+            // Update memory reservation for the new streaming merge
             self.update_memory_reservation()?;
-
             ExecutionState::ReadingInput
         };
 
         timer.done();
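
-- 
Note (not part of the patch): the snippet below is an illustrative,
self-contained sketch of the state machine this change introduces, not
DataFusion code. `State`, `drain_first`, and `Vec<u64>` are hypothetical
stand-ins for the real `ExecutionState`, `emit(EmitTo::First(n), false)`,
and `RecordBatch`; the real stream also handles spilling, metrics, and
the partial-aggregation states.

    // Hypothetical model of the EmittingRemaining / ProducingOutput loop.
    #[derive(Debug)]
    enum State {
        EmittingRemaining,
        ProducingOutput(Vec<u64>), // stand-in for a RecordBatch of group rows
        Done,
    }

    // Emits at most `n` of the remaining groups, or `None` once the table is
    // empty, mirroring `emit(EmitTo::First(n), ..)` after input is exhausted.
    fn drain_first(groups: &mut Vec<u64>, n: usize) -> Option<Vec<u64>> {
        if groups.is_empty() {
            return None;
        }
        let n = n.min(groups.len());
        Some(groups.drain(..n).collect())
    }

    fn main() {
        let batch_size = 3;
        let mut groups: Vec<u64> = (0..8).collect(); // leftover groups at input end
        let mut state = State::EmittingRemaining;

        loop {
            state = match state {
                // Take at most batch_size groups per transition, so output
                // batches stay bounded instead of one oversized EmitTo::All batch.
                State::EmittingRemaining => match drain_first(&mut groups, batch_size) {
                    Some(batch) => State::ProducingOutput(batch),
                    None => State::Done,
                },
                // Hand the batch out, then loop back to check for more groups,
                // as the patched stream does once `input_done` is set.
                State::ProducingOutput(batch) => {
                    println!("emitted {} rows: {:?}", batch.len(), batch);
                    State::EmittingRemaining
                }
                State::Done => break,
            };
        }
        assert!(groups.is_empty()); // Done is reached only after every group is emitted
    }

Under these assumptions the 8 leftover groups come out as batches of 3, 3,
and 2 rows, which is the bounded-batch behavior the `EmittingRemaining`
state is meant to guarantee.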