138 changes: 91 additions & 47 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -24,7 +24,6 @@ use std::vec;
use super::AggregateExec;
use super::order::GroupOrdering;
use crate::aggregates::group_values::{GroupByMetrics, GroupValues, new_group_values};
use crate::aggregates::order::GroupOrderingFull;
use crate::aggregates::{
AggregateMode, PhysicalGroupBy, create_schema, evaluate_group_by, evaluate_many,
evaluate_optional,
@@ -68,6 +67,8 @@ pub(crate) enum ExecutionState {
///
/// See "partial aggregation" discussion on [`GroupedHashAggregateStream`]
SkippingAggregation,
/// Special state to emit remaining groups when input is done
EmittingRemaining,
/// All input has been consumed and all groups have been emitted
Done,
}
@@ -802,6 +803,9 @@ impl Stream for GroupedHashAggregateStream {

// Found end from input stream
None => {
if self.spill_state.is_stream_merging {
self.group_ordering.input_done();
}
// inner is done, emit all rows and switch to producing output
self.set_input_done_and_produce_output()?;
}
@@ -844,31 +848,48 @@ impl Stream for GroupedHashAggregateStream {
// slice off a part of the batch, if needed
let output_batch;
let size = self.batch_size;
(self.exec_state, output_batch) = if batch.num_rows() <= size {
(
if self.input_done {
ExecutionState::Done
}
// In Partial aggregation, we also need to check
// if we should trigger partial skipping
else if self.mode == AggregateMode::Partial
&& self.should_skip_aggregation()
{
ExecutionState::SkippingAggregation
} else {
ExecutionState::ReadingInput
},
batch.clone(),
)
} else {
// output first batch_size rows
let size = self.batch_size;

// First handle the simple case where we need to slice
if batch.num_rows() > size {
let num_remaining = batch.num_rows() - size;
let remaining = batch.slice(size, num_remaining);
let output = batch.slice(0, size);
(ExecutionState::ProducingOutput(remaining), output)
output_batch = batch.slice(0, size);
self.exec_state = ExecutionState::ProducingOutput(remaining);

if let Some(reduction_factor) = self.reduction_factor.as_ref() {
reduction_factor.add_part(output_batch.num_rows());
}

debug_assert!(output_batch.num_rows() > 0);
return Poll::Ready(Some(Ok(
output_batch.record_output(&self.baseline_metrics)
)));
}

// The entire batch fits within batch_size
output_batch = batch.clone();

// Now determine the next state
let next_state = if self.input_done {
// We're done with input, check if we have more groups to emit
if !self.group_values.is_empty() {
ExecutionState::EmittingRemaining
} else {
ExecutionState::Done
}
}
// In Partial aggregation, we also need to check
// if we should trigger partial skipping
else if self.mode == AggregateMode::Partial
&& self.should_skip_aggregation()
{
ExecutionState::SkippingAggregation
} else {
ExecutionState::ReadingInput
};

self.exec_state = next_state;

if let Some(reduction_factor) = self.reduction_factor.as_ref() {
reduction_factor.add_part(output_batch.num_rows());
}
@@ -881,8 +902,31 @@ impl Stream for GroupedHashAggregateStream {
)));
}

ExecutionState::EmittingRemaining => {
let remaining = self.group_values.len();

if remaining == 0 {
self.exec_state = ExecutionState::Done;
continue;
}

let to_emit = remaining.min(self.batch_size);

match self.emit(EmitTo::First(to_emit), false)? {
Some(batch) => {
self.exec_state = ExecutionState::ProducingOutput(batch);
continue;
}
None => {
self.exec_state = ExecutionState::Done;
continue;
}
}
}

ExecutionState::Done => {
// Sanity check: all groups should have been emitted by now
self.group_ordering.input_done();
// all groups should have been emitted by now
if !self.group_values.is_empty() {
return Poll::Ready(Some(internal_err!(
"AggregateStream was in Done state with {} groups left in hash table. \
@@ -1032,7 +1076,7 @@ impl GroupedHashAggregateStream {
match self.oom_mode {
OutOfMemoryMode::Spill if !self.group_values.is_empty() => {
self.spill()?;
self.clear_shrink(self.batch_size);
// Force memory update after clearing
self.update_memory_reservation()?;
Ok(None)
}
@@ -1120,6 +1164,10 @@ impl GroupedHashAggregateStream {
/// This process helps in reducing memory pressure by allowing the data to be
/// read back with streaming merge.
fn spill(&mut self) -> Result<()> {
// Check if there's anything to spill
if self.group_values.is_empty() {
return Ok(());
}
// Emit and sort intermediate aggregation state
let Some(emit) = self.emit(EmitTo::All, true)? else {
return Ok(());
@@ -1140,7 +1188,9 @@ impl GroupedHashAggregateStream {
self.spill_state.spills.push(SortedSpillFile {
file: spillfile,
max_record_batch_memory,
})
});
// Clear memory after successful spill
self.clear_all();
}
None => {
return internal_err!(
@@ -1179,26 +1229,26 @@ impl GroupedHashAggregateStream {
/// in case of disk spilling, the SPM stream has been drained.
fn set_input_done_and_produce_output(&mut self) -> Result<()> {
self.input_done = true;
self.group_ordering.input_done();
let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
let timer = elapsed_compute.timer();
self.exec_state = if self.spill_state.spills.is_empty() {
// Input has been entirely processed without spilling to disk.
// Emit the first batch of groups (up to batch_size)
if !self.group_values.is_empty() {
// Start emitting remaining groups
ExecutionState::EmittingRemaining
} else {
ExecutionState::Done
}
} else {
// Spill any remaining data to disk.
self.spill()?; // This should clear memory

// Flush any remaining group values.
let batch = self.emit(EmitTo::All, false)?;
// Reset the group values collectors.
self.clear_all();

// If there are none, we're done; otherwise switch to emitting them
batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput)
} else {
// Spill any remaining data to disk. There is some performance overhead in
// writing out this last chunk of data and reading it back. The benefit of
// doing this is that memory usage for this stream is reduced, and the more
// sophisticated memory handling in `MultiLevelMergeBuilder` can take over
// instead.
// Spilling to disk and reading back also ensures batch size is consistent
// rather than potentially having one significantly larger last batch.
self.spill()?; // TODO: use sort_batch_chunked instead?
// IMPORTANT: release hash-aggregation memory before streaming merge
self.reservation.free();

// Mark that we're switching to stream merging mode.
self.spill_state.is_stream_merging = true;
@@ -1212,21 +1262,15 @@ impl GroupedHashAggregateStream {
.with_batch_size(self.batch_size)
.with_reservation(self.reservation.new_empty())
.build()?;
self.input_done = false;

// Reset the group values collectors.
self.clear_all();

// We can now use `GroupOrdering::Full` since the spill files are sorted
// on the grouping columns.
self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
self.input_done = false;

// Use `OutOfMemoryMode::ReportError` from this point on
// to ensure we don't spill the spilled data to disk again.
self.oom_mode = OutOfMemoryMode::ReportError;

// Update memory reservation for the new streaming merge
self.update_memory_reservation()?;

ExecutionState::ReadingInput
};
timer.done();