diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs index 2e0d668a29559..fabc47edf63d5 100644 --- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -259,9 +259,10 @@ impl MultiLevelMergeBuilder { // as we are not holding the memory for them let mut sorted_streams = mem::take(&mut self.sorted_streams); - let (sorted_spill_files, buffer_size) = self + let (sorted_spill_files, _) = self .get_sorted_spill_files_to_merge( - 2, + // No read-ahead buffering needed, reserve memory for 1 batch per file + 1, // we must have at least 2 streams to merge 2_usize.saturating_sub(sorted_streams.len()), &mut memory_reservation, @@ -273,7 +274,6 @@ impl MultiLevelMergeBuilder { let stream = self .spill_manager .clone() - .with_batch_read_buffer_capacity(buffer_size) .read_spill_as_stream( spill.file, Some(spill.max_record_batch_memory), diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 89b0276206774..7bd42e8290cdf 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -28,7 +28,7 @@ use std::sync::Arc; use super::{SpillReaderStream, in_progress_spill_file::InProgressSpillFile}; use crate::coop::cooperative; -use crate::{common::spawn_buffered, metrics::SpillMetrics}; +use crate::metrics::SpillMetrics; /// The `SpillManager` is responsible for the following tasks: /// - Reading and writing `RecordBatch`es to raw files based on the provided configurations. 
@@ -41,8 +41,6 @@ pub struct SpillManager { env: Arc<RuntimeEnv>, pub(crate) metrics: SpillMetrics, schema: SchemaRef, - /// Number of batches to buffer in memory during disk reads - batch_read_buffer_capacity: usize, /// general-purpose compression options pub(crate) compression: SpillCompression, } @@ -53,18 +51,10 @@ impl SpillManager { env, metrics, schema, - batch_read_buffer_capacity: 2, compression: SpillCompression::default(), } } - pub fn with_batch_read_buffer_capacity( - mut self, - batch_read_buffer_capacity: usize, - ) -> Self { - self.batch_read_buffer_capacity = batch_read_buffer_capacity; - self - } pub fn with_compression_type(mut self, spill_compression: SpillCompression) -> Self { self.compression = spill_compression; @@ -186,7 +176,7 @@ impl SpillManager { max_record_batch_memory, ))); - Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) + Ok(stream) } } diff --git a/datafusion/physical-plan/src/spill/spill_pool.rs b/datafusion/physical-plan/src/spill/spill_pool.rs index 8f7f5212f6c91..09a85a3db360f 100644 --- a/datafusion/physical-plan/src/spill/spill_pool.rs +++ b/datafusion/physical-plan/src/spill/spill_pool.rs @@ -1435,4 +1435,44 @@ mod tests { Ok(()) } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_concurrent_writer_reader_race_condition() -> Result<()> { + // Stress-test the concurrency between the writer and the reader to make sure there is no race condition. + // Runs for 100 iterations with 5 batches per iteration. + const NUM_BATCHES: usize = 5; + const ITERATIONS: usize = 100; + + for iteration in 0..ITERATIONS { + let (writer, mut reader) = create_spill_channel(1024 * 1024); + + let writer_handle = SpawnedTask::spawn(async move { + for i in 0..NUM_BATCHES { + let batch = create_test_batch(i as i32 * 10, 10); + writer.push_batch(&batch).unwrap(); + tokio::task::yield_now().await; + } + }); + + let reader_handle = SpawnedTask::spawn(async move { + let mut batches_read = 0; + while let Some(result) = 
reader.next().await { + let _batch = result.unwrap(); + batches_read += 1; + tokio::task::yield_now().await; + } + batches_read + }); + + writer_handle.join().await.unwrap(); + let batches_read = reader_handle.join().await.unwrap(); + + assert_eq!( + batches_read, NUM_BATCHES, + "Iteration {iteration}: Expected {NUM_BATCHES} got {batches_read}." + ); + } + + Ok(()) + } }