Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions datafusion/physical-plan/src/sorts/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,24 @@ pub struct BatchBuilder {
/// Maintain a list of [`RecordBatch`] and their corresponding stream
batches: Vec<(usize, RecordBatch)>,

/// Accounts for memory used by buffered batches
/// Accounts for memory used by buffered batches.
///
/// May include pre-reserved bytes (from `sort_spill_reservation_bytes`)
/// that were transferred via [`MemoryReservation::take()`] to prevent
/// starvation when concurrent sort partitions compete for pool memory.
reservation: MemoryReservation,

/// Tracks the actual memory used by buffered batches (not including
/// pre-reserved bytes). This allows [`Self::push_batch`] to skip pool
/// allocation requests when the pre-reserved bytes cover the batch.
batches_mem_used: usize,

/// The initial reservation size at construction time. When the reservation
/// is pre-loaded with `sort_spill_reservation_bytes` (via `take()`), this
/// records that amount so we never shrink below it, maintaining the
/// anti-starvation guarantee throughout the merge.
initial_reservation: usize,

/// The current [`BatchCursor`] for each stream
cursors: Vec<BatchCursor>,

Expand All @@ -59,19 +74,29 @@ impl BatchBuilder {
batch_size: usize,
reservation: MemoryReservation,
) -> Self {
let initial_reservation = reservation.size();
Self {
schema,
batches: Vec::with_capacity(stream_count * 2),
cursors: vec![BatchCursor::default(); stream_count],
indices: Vec::with_capacity(batch_size),
reservation,
batches_mem_used: 0,
initial_reservation,
}
}

/// Append a new batch in `stream_idx`
pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) -> Result<()> {
self.reservation
.try_grow(get_record_batch_memory_size(&batch))?;
let size = get_record_batch_memory_size(&batch);
self.batches_mem_used += size;
// Only request additional memory from the pool when actual batch
// usage exceeds the current reservation (which may include
// pre-reserved bytes from sort_spill_reservation_bytes).
if self.batches_mem_used > self.reservation.size() {
self.reservation
.try_grow(self.batches_mem_used - self.reservation.size())?;
}
let batch_idx = self.batches.len();
self.batches.push((stream_idx, batch));
self.cursors[stream_idx] = BatchCursor {
Expand Down Expand Up @@ -143,11 +168,19 @@ impl BatchBuilder {
stream_cursor.batch_idx = retained;
retained += 1;
} else {
self.reservation.shrink(get_record_batch_memory_size(batch));
self.batches_mem_used -= get_record_batch_memory_size(batch);
}
retain
});

// Release excess memory back to the pool, but never shrink below
// initial_reservation to maintain the anti-starvation guarantee
// for the merge phase.
let target = self.batches_mem_used.max(self.initial_reservation);
if self.reservation.size() > target {
self.reservation.shrink(self.reservation.size() - target);
}

Ok(Some(RecordBatch::try_new(
Arc::clone(&self.schema),
columns,
Expand Down
91 changes: 55 additions & 36 deletions datafusion/physical-plan/src/sorts/multi_level_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,12 @@ impl MultiLevelMergeBuilder {

// Need to merge multiple streams
(_, _) => {
let mut memory_reservation = self.reservation.new_empty();
// Transfer any pre-reserved bytes (from sort_spill_reservation_bytes)
// to the merge memory reservation. This prevents starvation when
// concurrent sort partitions compete for pool memory: the pre-reserved
// bytes cover spill file buffer reservations without additional pool
// allocation.
let mut memory_reservation = self.reservation.take();

// Don't account for existing streams memory
// as we are not holding the memory for them
Expand Down Expand Up @@ -337,8 +342,10 @@ impl MultiLevelMergeBuilder {
builder = builder.with_bypass_mempool();
} else {
// If we are only merging in-memory streams, we need to use the memory reservation
// because we don't know the maximum size of the batches in the streams
builder = builder.with_reservation(self.reservation.new_empty());
// because we don't know the maximum size of the batches in the streams.
// Use take() to transfer any pre-reserved bytes so the merge can use them
// as its initial budget without additional pool allocation.
builder = builder.with_reservation(self.reservation.take());
}

builder.build()
Expand All @@ -356,45 +363,57 @@ impl MultiLevelMergeBuilder {
) -> Result<(Vec<SortedSpillFile>, usize)> {
assert_ne!(buffer_len, 0, "Buffer length must be greater than 0");
let mut number_of_spills_to_read_for_current_phase = 0;
// Track total memory needed for spill file buffers. When the
// reservation has pre-reserved bytes (from sort_spill_reservation_bytes),
// those bytes cover the first N spill files without additional pool
// allocation, preventing starvation under memory pressure.
let mut total_needed: usize = 0;

for spill in &self.sorted_spill_files {
// For memory pools that are not shared this is good, for other this is not
// and there should be some upper limit to memory reservation so we won't starve the system
match reservation.try_grow(
get_reserved_bytes_for_record_batch_size(
spill.max_record_batch_memory,
// Size will be the same as the sliced size, bc it is a spilled batch.
spill.max_record_batch_memory,
) * buffer_len,
) {
Ok(_) => {
number_of_spills_to_read_for_current_phase += 1;
}
// If we can't grow the reservation, we need to stop
Err(err) => {
// We must have at least 2 streams to merge, so if we don't have enough memory
// fail
if minimum_number_of_required_streams
> number_of_spills_to_read_for_current_phase
{
// Free the memory we reserved for this merge as we either try again or fail
reservation.free();
if buffer_len > 1 {
// Try again with smaller buffer size, it will be slower but at least we can merge
return self.get_sorted_spill_files_to_merge(
buffer_len - 1,
minimum_number_of_required_streams,
reservation,
);
let per_spill = get_reserved_bytes_for_record_batch_size(
spill.max_record_batch_memory,
// Size will be the same as the sliced size, bc it is a spilled batch.
spill.max_record_batch_memory,
) * buffer_len;
total_needed += per_spill;

// Only request additional memory from the pool when total needed
// exceeds what's already reserved (which may include pre-reserved
// bytes from sort_spill_reservation_bytes).
if total_needed > reservation.size() {
match reservation.try_grow(total_needed - reservation.size()) {
Ok(_) => {
number_of_spills_to_read_for_current_phase += 1;
}
// If we can't grow the reservation, we need to stop
Err(err) => {
// We must have at least 2 streams to merge, so if we don't have enough memory
// fail
if minimum_number_of_required_streams
> number_of_spills_to_read_for_current_phase
{
// Free the memory we reserved for this merge as we either try again or fail
reservation.free();
if buffer_len > 1 {
// Try again with smaller buffer size, it will be slower but at least we can merge
return self.get_sorted_spill_files_to_merge(
buffer_len - 1,
minimum_number_of_required_streams,
reservation,
);
}

return Err(err);
}

return Err(err);
// We reached the maximum amount of memory we can use
// for this merge
break;
}

// We reached the maximum amount of memory we can use
// for this merge
break;
}
} else {
// Pre-reserved bytes cover this spill file's buffer
number_of_spills_to_read_for_current_phase += 1;
}
}

Expand Down
146 changes: 145 additions & 1 deletion datafusion/physical-plan/src/sorts/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,13 @@ impl ExternalSorter {
self.sort_and_spill_in_mem_batches().await?;
}

// Transfer the pre-reserved merge memory to the streaming merge
// using `take()` instead of `new_empty()`. This ensures the merge
// stream starts with `sort_spill_reservation_bytes` already
// allocated, preventing starvation when concurrent sort partitions
// compete for pool memory. `take()` moves the bytes atomically
// without releasing them back to the pool, so other partitions
// cannot race to consume the freed memory.
StreamingMergeBuilder::new()
.with_sorted_spill_files(std::mem::take(&mut self.finished_spill_files))
.with_spill_manager(self.spill_manager.clone())
Expand All @@ -363,7 +370,7 @@ impl ExternalSorter {
.with_metrics(self.metrics.baseline.clone())
.with_batch_size(self.batch_size)
.with_fetch(None)
.with_reservation(self.merge_reservation.new_empty())
.with_reservation(self.merge_reservation.take())
.build()
} else {
self.in_mem_sort_stream(self.metrics.baseline.clone())
Expand Down Expand Up @@ -2728,4 +2735,141 @@ mod tests {

Ok(())
}

/// End-to-end test that verifies `ExternalSorter::sort()` atomically
/// transfers the pre-reserved merge bytes to the merge stream via `take()`.
///
/// This test directly exercises the `ExternalSorter` code path:
/// 1. Create a sorter with a tight memory pool and insert enough data
/// to force spilling
/// 2. Call `sort()` to get the merge stream
/// 3. Verify that dropping the sorter does NOT free the pre-reserved
/// bytes back to the pool (they should have been transferred to
/// the merge stream)
/// 4. Simulate contention: a task grabs all available pool memory
/// 5. Verify the merge stream still works (it has its own pre-reserved bytes)
///
/// Before the fix (i.e. on `main`, which used `new_empty()`), step 3 fails:
/// the sorter's drop frees `sort_spill_reservation_bytes` back to the pool,
/// and a concurrent task can steal them, causing the merge stream to starve.
///
/// With the fix (using `take()`), the bytes are atomically transferred
/// to the merge stream. The sorter drop frees 0 bytes, so there's
/// nothing for the task to steal.
#[tokio::test]
async fn test_sort_merge_reservation_transferred_not_freed() -> Result<()> {
use datafusion_execution::memory_pool::{
GreedyMemoryPool, MemoryConsumer, MemoryPool,
};
use futures::TryStreamExt;

let sort_spill_reservation_bytes: usize = 10 * 1024; // 10 KB

// Pool: merge reservation (10KB) + enough room for sort to work.
// The room must accommodate batch data accumulation before spilling.
let sort_working_memory: usize = 40 * 1024; // 40 KB for sort operations
let pool_size = sort_spill_reservation_bytes + sort_working_memory;
// GreedyMemoryPool hands out memory first-come-first-served, which is
// exactly the environment where a freed reservation can be stolen.
let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(pool_size));

let runtime = RuntimeEnvBuilder::new()
.with_memory_pool(Arc::clone(&pool))
.build_arc()?;

let metrics_set = ExecutionPlanMetricsSet::new();
let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));

let mut sorter = ExternalSorter::new(
0,
Arc::clone(&schema),
[PhysicalSortExpr::new_default(Arc::new(Column::new("x", 0)))].into(),
128, // batch_size
sort_spill_reservation_bytes,
usize::MAX, // sort_in_place_threshold_bytes (high to avoid concat path)
SpillCompression::Uncompressed,
&metrics_set,
Arc::clone(&runtime),
)?;

// Insert enough data to force spilling. Each batch is ~400 bytes
// (100 rows × 4 bytes). With 40KB of working memory, we'll spill
// after accumulating ~100 batches worth. 200 batches guarantees
// multiple spill cycles. Values are inserted in descending order so
// the sorter has real work to do per batch.
let num_batches = 200;
for i in 0..num_batches {
let values: Vec<i32> = ((i * 100)..((i + 1) * 100)).rev().collect();
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(Int32Array::from(values))],
)?;
sorter.insert_batch(batch).await?;
}

// Sanity check: without spills, sort() never takes the merge path
// that this test is designed to exercise.
assert!(
sorter.spilled_before(),
"Test requires spilling to exercise the merge path"
);

// Call sort() to get the merge stream. After this:
// - With take() (the fix): merge_reservation = 0, merge stream has R bytes
// - With new_empty() (before fix): merge_reservation = R, merge stream has 0 bytes
let merge_stream = sorter.sort().await?;

// Record pool state before dropping the sorter
let reserved_before_drop = pool.reserved();

// Drop the sorter. This frees merge_reservation:
// - With take() (the fix): frees 0 bytes (already transferred to merge stream)
// - With new_empty() (before fix): frees R bytes back to pool
drop(sorter);

let reserved_after_drop = pool.reserved();

// THE KEY ASSERTION: dropping the sorter should NOT free the
// pre-reserved merge bytes. They must have been transferred to
// the merge stream via take().
assert_eq!(
reserved_after_drop,
reserved_before_drop,
"Dropping the sorter freed {} bytes back to the pool! \
The merge reservation bytes should have been transferred \
to the merge stream (via take()), not freed back to the pool \
(via new_empty()). Freed bytes can be stolen by concurrent \
partitions, causing merge starvation.",
reserved_before_drop - reserved_after_drop
);

// Simulate contention: a task (representing another partition)
// grabs all available pool memory
let task = MemoryConsumer::new("TaskPartition").register(&pool);
let available = pool_size.saturating_sub(pool.reserved());
if available > 0 {
task.try_grow(available).unwrap();
}

// The merge stream should still work because it holds the
// pre-reserved bytes (transferred via take())
let batches: Vec<RecordBatch> = merge_stream.try_collect().await?;
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(
total_rows,
(num_batches * 100) as usize,
"Merge stream should produce all rows even under memory contention"
);

// Verify data is sorted: concatenate all output batches and check
// the single column is non-decreasing end to end.
let merged = concat_batches(&schema, &batches)?;
let col = merged.column(0).as_primitive::<Int32Type>();
for i in 1..col.len() {
assert!(
col.value(i - 1) <= col.value(i),
"Output should be sorted, but found {} > {} at index {}",
col.value(i - 1),
col.value(i),
i
);
}

// Keep the contending consumer alive until after the merge completed,
// then release its reservation.
drop(task);
Ok(())
}
}