From 1b8b36e0662ac5d1d0ba606f4a9a3c355ac5ccea Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Sat, 22 Nov 2025 17:02:56 +0100 Subject: [PATCH 1/6] Fix network shuffle deadlock --- src/execution_plans/network_shuffle.rs | 38 ++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs index 1e2f313..7639012 100644 --- a/src/execution_plans/network_shuffle.rs +++ b/src/execution_plans/network_shuffle.rs @@ -13,6 +13,7 @@ use arrow_flight::decode::FlightRecordBatchStream; use arrow_flight::error::FlightError; use bytes::Bytes; use dashmap::DashMap; +use datafusion::common::runtime::SpawnedTask; use datafusion::common::{exec_err, internal_datafusion_err, plan_err}; use datafusion::error::DataFusionError; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; @@ -22,12 +23,13 @@ use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, }; -use futures::{StreamExt, TryFutureExt, TryStreamExt}; +use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt}; use http::Extensions; use prost::Message; use std::any::Any; use std::fmt::Formatter; use std::sync::Arc; +use tokio_stream::wrappers::ReceiverStream; use tonic::Request; use tonic::metadata::MetadataMap; @@ -372,7 +374,39 @@ impl ExecutionPlan for NetworkShuffleExec { Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), - futures::stream::select_all(stream), + spawn_select_all(stream.collect()), ))) } } + +const NETWORK_STREAM_BUFFER_BATCHES: usize = 10; + +/// Consumes all the provided streams in parallel sending their produced messages to a single +/// queue in random order. The resulting queue is returned as a stream. +// FIXME: It should not be necessary to do this, it should be fine to just consume +// all the messages with a normal tokio::stream::select_all, however, that has the chance +// of deadlocking the stream on the server side (https://github.com/datafusion-contrib/datafusion-distributed/issues/228). +// Until we figure out what's wrong there, this is a good enough solution. +fn spawn_select_all(inner: Vec) -> impl Stream +where + T: Stream + Send + Unpin + 'static, + T::Item: Send, +{ + let (tx, rx) = tokio::sync::mpsc::channel(NETWORK_STREAM_BUFFER_BATCHES); + let mut tasks = vec![]; + for mut t in inner { + let tx = tx.clone(); + tasks.push(SpawnedTask::spawn(async move { + while let Some(msg) = t.next().await { + if tx.send(msg).await.is_err() { + return; + }; + } + })) + } + + ReceiverStream::new(rx).inspect(move |_| { + // keep the tasks alive as long as the stream lives + let _ = &tasks; + }) +} From 75b6af3c46a3ff08de57280fb4166bd0910f63d1 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Sat, 22 Nov 2025 17:25:01 +0100 Subject: [PATCH 2/6] Use an unbounded queue --- src/execution_plans/network_shuffle.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs index 7639012..4961f81 100644 --- a/src/execution_plans/network_shuffle.rs +++ b/src/execution_plans/network_shuffle.rs @@ -29,7 +29,7 @@ use prost::Message; use std::any::Any; use std::fmt::Formatter; use std::sync::Arc; -use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::Request; use tonic::metadata::MetadataMap; @@ -379,33 +379,32 @@ impl ExecutionPlan for NetworkShuffleExec { } } -const NETWORK_STREAM_BUFFER_BATCHES: usize = 10; - /// Consumes all the provided streams in parallel sending their produced messages to a single /// queue in random order. The resulting queue is returned as a stream. // FIXME: It should not be necessary to do this, it should be fine to just consume // all the messages with a normal tokio::stream::select_all, however, that has the chance // of deadlocking the stream on the server side (https://github.com/datafusion-contrib/datafusion-distributed/issues/228). +// Even having these channels bounded would result in deadlocks (learned it the hard way). // Until we figure out what's wrong there, this is a good enough solution. fn spawn_select_all(inner: Vec) -> impl Stream where T: Stream + Send + Unpin + 'static, T::Item: Send, { - let (tx, rx) = tokio::sync::mpsc::channel(NETWORK_STREAM_BUFFER_BATCHES); + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let mut tasks = vec![]; for mut t in inner { let tx = tx.clone(); tasks.push(SpawnedTask::spawn(async move { while let Some(msg) = t.next().await { - if tx.send(msg).await.is_err() { + if tx.send(msg).is_err() { return; }; } })) } - ReceiverStream::new(rx).inspect(move |_| { + UnboundedReceiverStream::new(rx).inspect(move |_| { // keep the tasks alive as long as the stream lives let _ = &tasks; }) From c1df5b95af3ff082038ddb7735c05c42c3f92f7e Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Sun, 23 Nov 2025 11:08:08 +0100 Subject: [PATCH 3/6] Integrate with memory accounting --- src/execution_plans/network_shuffle.rs | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs index 4961f81..9722f22 100644 --- a/src/execution_plans/network_shuffle.rs +++ b/src/execution_plans/network_shuffle.rs @@ -8,6 +8,7 @@ use crate::protobuf::StageKey; use crate::protobuf::{map_flight_to_datafusion_error, map_status_to_datafusion_error}; use crate::stage::{MaybeEncodedPlan, Stage}; use crate::{ChannelResolver, DistributedTaskContext, InputStageInfo, NetworkBoundary}; +use arrow::record_batch::RecordBatch; use arrow_flight::Ticket; use arrow_flight::decode::FlightRecordBatchStream; use arrow_flight::error::FlightError; @@ -16,6 +17,7 @@ use dashmap::DashMap; use datafusion::common::runtime::SpawnedTask; use datafusion::common::{exec_err, internal_datafusion_err, plan_err}; use datafusion::error::DataFusionError; +use datafusion::execution::memory_pool::{MemoryConsumer, MemoryPool}; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_expr::Partitioning; use datafusion::physical_plan::repartition::RepartitionExec; @@ -374,7 +376,7 @@ impl ExecutionPlan for NetworkShuffleExec { Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), - spawn_select_all(stream.collect()), + spawn_select_all(stream.collect(), Arc::clone(context.memory_pool())), ))) } } @@ -386,26 +388,38 @@ impl ExecutionPlan for NetworkShuffleExec { // of deadlocking the stream on the server side (https://github.com/datafusion-contrib/datafusion-distributed/issues/228). // Even having these channels bounded would result in deadlocks (learned it the hard way). // Until we figure out what's wrong there, this is a good enough solution. -fn spawn_select_all(inner: Vec) -> impl Stream +fn spawn_select_all( + inner: Vec, + pool: Arc, +) -> impl Stream> where - T: Stream + Send + Unpin + 'static, - T::Item: Send, + T: Stream> + Send + Unpin + 'static, { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let mut tasks = vec![]; for mut t in inner { let tx = tx.clone(); + let pool = Arc::clone(&pool); + let consumer = MemoryConsumer::new("NetworkShuffleExec"); + tasks.push(SpawnedTask::spawn(async move { while let Some(msg) = t.next().await { - if tx.send(msg).is_err() { + let mut reservation = consumer.clone_with_new_id().register(&pool); + if let Ok(msg) = &msg { + reservation.grow(msg.get_array_memory_size()); + } + + if tx.send((msg, reservation)).is_err() { return; }; } })) } - UnboundedReceiverStream::new(rx).inspect(move |_| { + UnboundedReceiverStream::new(rx).map(move |(msg, _reservation)| { // keep the tasks alive as long as the stream lives let _ = &tasks; + msg }) } From b166ee8b6367dccd17bdd511c3a4927ade66db95 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Sun, 23 Nov 2025 17:49:41 +0100 Subject: [PATCH 4/6] Also apply spawn_select_all to network coalesce --- src/execution_plans/common.rs | 48 +++++++++++++++++++++++ src/execution_plans/network_coalesce.rs | 11 ++++-- src/execution_plans/network_shuffle.rs | 51 +------------------------ 3 files changed, 57 insertions(+), 53 deletions(-) diff --git a/src/execution_plans/common.rs b/src/execution_plans/common.rs index f085b3e..5bb1716 100644 --- a/src/execution_plans/common.rs +++ b/src/execution_plans/common.rs @@ -1,8 +1,13 @@ +use arrow::array::RecordBatch; +use datafusion::common::runtime::SpawnedTask; use datafusion::common::{DataFusionError, plan_err}; +use datafusion::execution::memory_pool::{MemoryConsumer, MemoryPool}; use datafusion::physical_expr::Partitioning; use datafusion::physical_plan::{ExecutionPlan, PlanProperties}; +use futures::{Stream, StreamExt}; use std::borrow::Borrow; use std::sync::Arc; +use tokio_stream::wrappers::UnboundedReceiverStream; pub(super) fn require_one_child( children: L, @@ -40,3 +45,46 @@ pub(super) fn scale_partitioning( Partitioning::UnknownPartitioning(p) => Partitioning::UnknownPartitioning(f(*p)), } } + +/// Consumes all the provided streams in parallel sending their produced messages to a single +/// queue in random order. The resulting queue is returned as a stream. +// FIXME: It should not be necessary to do this, it should be fine to just consume +// all the messages with a normal tokio::stream::select_all, however, that has the chance +// of deadlocking the stream on the server side (https://github.com/datafusion-contrib/datafusion-distributed/issues/228). +// Even having these channels bounded would result in deadlocks (learned it the hard way). +// Until we figure out what's wrong there, this is a good enough solution. +pub(super) fn spawn_select_all( + inner: Vec, + pool: Arc, +) -> impl Stream> +where + T: Stream> + Send + Unpin + 'static, +{ + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + + let mut tasks = vec![]; + for mut t in inner { + let tx = tx.clone(); + let pool = Arc::clone(&pool); + let consumer = MemoryConsumer::new("NetworkShuffleExec"); + + tasks.push(SpawnedTask::spawn(async move { + while let Some(msg) = t.next().await { + let mut reservation = consumer.clone_with_new_id().register(&pool); + if let Ok(msg) = &msg { + reservation.grow(msg.get_array_memory_size()); + } + + if tx.send((msg, reservation)).is_err() { + return; + }; + } + })) + } + + UnboundedReceiverStream::new(rx).map(move |(msg, _reservation)| { + // keep the tasks alive as long as the stream lives + let _ = &tasks; + msg + }) +} diff --git a/src/execution_plans/network_coalesce.rs b/src/execution_plans/network_coalesce.rs index f6653f2..07379b4 100644 --- a/src/execution_plans/network_coalesce.rs +++ b/src/execution_plans/network_coalesce.rs @@ -1,7 +1,9 @@ use crate::channel_resolver_ext::get_distributed_channel_resolver; use crate::config_extension_ext::ContextGrpcMetadata; use crate::distributed_planner::{InputStageInfo, NetworkBoundary, limit_tasks_err}; -use crate::execution_plans::common::{require_one_child, scale_partitioning_props}; +use crate::execution_plans::common::{ + require_one_child, scale_partitioning_props, spawn_select_all, +}; use crate::flight_service::DoGet; use crate::metrics::MetricsCollectingStream; use crate::metrics::proto::MetricsSetProto; @@ -18,7 +20,7 @@ use datafusion::error::DataFusionError; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; -use futures::{TryFutureExt, TryStreamExt}; +use futures::{StreamExt, TryFutureExt, TryStreamExt}; use http::Extensions; use prost::Message; use std::any::Any; @@ -319,11 +321,12 @@ impl ExecutionPlan for NetworkCoalesceExec { .map_err(map_flight_to_datafusion_error), ) } - .try_flatten_stream(); + .try_flatten_stream() + .boxed(); Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), - stream, + spawn_select_all(vec![stream], Arc::clone(context.memory_pool())), ))) } } diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs index 9722f22..89f87d3 100644 --- a/src/execution_plans/network_shuffle.rs +++ b/src/execution_plans/network_shuffle.rs @@ -1,6 +1,6 @@ use crate::channel_resolver_ext::get_distributed_channel_resolver; use crate::config_extension_ext::ContextGrpcMetadata; -use crate::execution_plans::common::{require_one_child, scale_partitioning}; +use crate::execution_plans::common::{require_one_child, scale_partitioning, spawn_select_all}; use crate::flight_service::DoGet; use crate::metrics::MetricsCollectingStream; use crate::metrics::proto::MetricsSetProto; @@ -8,16 +8,13 @@ use crate::protobuf::StageKey; use crate::protobuf::{map_flight_to_datafusion_error, map_status_to_datafusion_error}; use crate::stage::{MaybeEncodedPlan, Stage}; use crate::{ChannelResolver, DistributedTaskContext, InputStageInfo, NetworkBoundary}; -use arrow::record_batch::RecordBatch; use arrow_flight::Ticket; use arrow_flight::decode::FlightRecordBatchStream; use arrow_flight::error::FlightError; use bytes::Bytes; use dashmap::DashMap; -use datafusion::common::runtime::SpawnedTask; use datafusion::common::{exec_err, internal_datafusion_err, plan_err}; use datafusion::error::DataFusionError; -use datafusion::execution::memory_pool::{MemoryConsumer, MemoryPool}; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_expr::Partitioning; use datafusion::physical_plan::repartition::RepartitionExec; @@ -25,13 +22,12 @@ use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, }; -use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt}; +use futures::{StreamExt, TryFutureExt, TryStreamExt}; use http::Extensions; use prost::Message; use std::any::Any; use std::fmt::Formatter; use std::sync::Arc; -use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::Request; use tonic::metadata::MetadataMap; @@ -380,46 +376,3 @@ impl ExecutionPlan for NetworkShuffleExec { ))) } } - -/// Consumes all the provided streams in parallel sending their produced messages to a single -/// queue in random order. The resulting queue is returned as a stream. -// FIXME: It should not be necessary to do this, it should be fine to just consume -// all the messages with a normal tokio::stream::select_all, however, that has the chance -// of deadlocking the stream on the server side (https://github.com/datafusion-contrib/datafusion-distributed/issues/228). -// Even having these channels bounded would result in deadlocks (learned it the hard way). -// Until we figure out what's wrong there, this is a good enough solution. -fn spawn_select_all( - inner: Vec, - pool: Arc, -) -> impl Stream> -where - T: Stream> + Send + Unpin + 'static, -{ - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - - let mut tasks = vec![]; - for mut t in inner { - let tx = tx.clone(); - let pool = Arc::clone(&pool); - let consumer = MemoryConsumer::new("NetworkShuffleExec"); - - tasks.push(SpawnedTask::spawn(async move { - while let Some(msg) = t.next().await { - let mut reservation = consumer.clone_with_new_id().register(&pool); - if let Ok(msg) = &msg { - reservation.grow(msg.get_array_memory_size()); - } - - if tx.send((msg, reservation)).is_err() { - return; - }; - } - })) - } - - UnboundedReceiverStream::new(rx).map(move |(msg, _reservation)| { - // keep the tasks alive as long as the stream lives - let _ = &tasks; - msg - }) -} From f5d3ac53550e23ecc6342f4221244949b3314be3 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Mon, 24 Nov 2025 07:56:03 +0100 Subject: [PATCH 5/6] Better name for spawn_select_all memory consumer --- src/execution_plans/common.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/execution_plans/common.rs b/src/execution_plans/common.rs index 5bb1716..647211d 100644 --- a/src/execution_plans/common.rs +++ b/src/execution_plans/common.rs @@ -66,7 +66,7 @@ where for mut t in inner { let tx = tx.clone(); let pool = Arc::clone(&pool); - let consumer = MemoryConsumer::new("NetworkShuffleExec"); + let consumer = MemoryConsumer::new("NetworkBoundary"); tasks.push(SpawnedTask::spawn(async move { while let Some(msg) = t.next().await { From 48dc726324ca8fca8f90e92c2fc231bb7fbd9af0 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Mon, 24 Nov 2025 10:55:42 +0100 Subject: [PATCH 6/6] Add test --- src/execution_plans/common.rs | 66 ++++++++++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 4 deletions(-) diff --git a/src/execution_plans/common.rs b/src/execution_plans/common.rs index 647211d..3f46714 100644 --- a/src/execution_plans/common.rs +++ b/src/execution_plans/common.rs @@ -53,12 +53,14 @@ pub(super) fn scale_partitioning( // of deadlocking the stream on the server side (https://github.com/datafusion-contrib/datafusion-distributed/issues/228). // Even having these channels bounded would result in deadlocks (learned it the hard way). // Until we figure out what's wrong there, this is a good enough solution. -pub(super) fn spawn_select_all( +pub(super) fn spawn_select_all( inner: Vec, pool: Arc, -) -> impl Stream> +) -> impl Stream> where - T: Stream> + Send + Unpin + 'static, + T: Stream> + Send + Unpin + 'static, + El: MemoryFootPrint + Send + 'static, + Err: Send + 'static, { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); @@ -72,7 +74,7 @@ where while let Some(msg) = t.next().await { let mut reservation = consumer.clone_with_new_id().register(&pool); if let Ok(msg) = &msg { - reservation.grow(msg.get_array_memory_size()); + reservation.grow(msg.get_memory_size()); } if tx.send((msg, reservation)).is_err() { @@ -88,3 +90,59 @@ where msg }) } + +pub(super) trait MemoryFootPrint { + fn get_memory_size(&self) -> usize; +} + +impl MemoryFootPrint for RecordBatch { + fn get_memory_size(&self) -> usize { + self.get_array_memory_size() + } +} + +#[cfg(test)] +mod tests { + use crate::execution_plans::common::{MemoryFootPrint, spawn_select_all}; + use datafusion::execution::memory_pool::{MemoryPool, UnboundedMemoryPool}; + use std::error::Error; + use std::sync::Arc; + use tokio_stream::StreamExt; + + #[tokio::test] + async fn memory_reservation() -> Result<(), Box> { + let pool: Arc = Arc::new(UnboundedMemoryPool::default()); + + let mut stream = spawn_select_all( + vec![ + futures::stream::iter(vec![Ok::<_, String>(1), Ok(2), Ok(3)]), + futures::stream::iter(vec![Ok(4), Ok(5)]), + ], + Arc::clone(&pool), + ); + tokio::time::sleep(tokio::time::Duration::from_millis(1)).await; + let reserved = pool.reserved(); + assert_eq!(reserved, 15); + + for i in [1, 2, 3] { + let n = stream.next().await.unwrap()?; + assert_eq!(i, n) + } + + let reserved = pool.reserved(); + assert_eq!(reserved, 9); + + drop(stream); + + let reserved = pool.reserved(); + assert_eq!(reserved, 0); + + Ok(()) + } + + impl MemoryFootPrint for usize { + fn get_memory_size(&self) -> usize { + *self + } + } +}