diff --git a/src/execution_plans/common.rs b/src/execution_plans/common.rs index f085b3e..3f46714 100644 --- a/src/execution_plans/common.rs +++ b/src/execution_plans/common.rs @@ -1,8 +1,13 @@ +use arrow::array::RecordBatch; +use datafusion::common::runtime::SpawnedTask; use datafusion::common::{DataFusionError, plan_err}; +use datafusion::execution::memory_pool::{MemoryConsumer, MemoryPool}; use datafusion::physical_expr::Partitioning; use datafusion::physical_plan::{ExecutionPlan, PlanProperties}; +use futures::{Stream, StreamExt}; use std::borrow::Borrow; use std::sync::Arc; +use tokio_stream::wrappers::UnboundedReceiverStream; pub(super) fn require_one_child( children: L, @@ -40,3 +45,104 @@ pub(super) fn scale_partitioning( Partitioning::UnknownPartitioning(p) => Partitioning::UnknownPartitioning(f(*p)), } } + +/// Consumes all the provided streams in parallel sending their produced messages to a single +/// queue in random order. The resulting queue is returned as a stream. +// FIXME: It should not be necessary to do this, it should be fine to just consume +// all the messages with a normal tokio::stream::select_all, however, that has the chance +// of deadlocking the stream on the server side (https://github.com/datafusion-contrib/datafusion-distributed/issues/228). +// Even having these channels bounded would result in deadlocks (learned it the hard way). +// Until we figure out what's wrong there, this is a good enough solution. +pub(super) fn spawn_select_all( + inner: Vec, + pool: Arc, +) -> impl Stream> +where + T: Stream> + Send + Unpin + 'static, + El: MemoryFootPrint + Send + 'static, + Err: Send + 'static, +{ + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + + let mut tasks = vec![]; + for mut t in inner { + let tx = tx.clone(); + let pool = Arc::clone(&pool); + let consumer = MemoryConsumer::new("NetworkBoundary"); + + tasks.push(SpawnedTask::spawn(async move { + while let Some(msg) = t.next().await { + let mut reservation = consumer.clone_with_new_id().register(&pool); + if let Ok(msg) = &msg { + reservation.grow(msg.get_memory_size()); + } + + if tx.send((msg, reservation)).is_err() { + return; + }; + } + })) + } + + UnboundedReceiverStream::new(rx).map(move |(msg, _reservation)| { + // keep the tasks alive as long as the stream lives + let _ = &tasks; + msg + }) +} + +pub(super) trait MemoryFootPrint { + fn get_memory_size(&self) -> usize; +} + +impl MemoryFootPrint for RecordBatch { + fn get_memory_size(&self) -> usize { + self.get_array_memory_size() + } +} + +#[cfg(test)] +mod tests { + use crate::execution_plans::common::{MemoryFootPrint, spawn_select_all}; + use datafusion::execution::memory_pool::{MemoryPool, UnboundedMemoryPool}; + use std::error::Error; + use std::sync::Arc; + use tokio_stream::StreamExt; + + #[tokio::test] + async fn memory_reservation() -> Result<(), Box> { + let pool: Arc = Arc::new(UnboundedMemoryPool::default()); + + let mut stream = spawn_select_all( + vec![ + futures::stream::iter(vec![Ok::<_, String>(1), Ok(2), Ok(3)]), + futures::stream::iter(vec![Ok(4), Ok(5)]), + ], + Arc::clone(&pool), + ); + tokio::time::sleep(tokio::time::Duration::from_millis(1)).await; + let reserved = pool.reserved(); + assert_eq!(reserved, 15); + + for i in [1, 2, 3] { + let n = stream.next().await.unwrap()?; + assert_eq!(i, n) + } + + let reserved = pool.reserved(); + assert_eq!(reserved, 9); + + drop(stream); + + let reserved = pool.reserved(); + assert_eq!(reserved, 0); + + Ok(()) + } + + impl MemoryFootPrint for usize { + fn get_memory_size(&self) -> usize { + *self + } + } +} diff --git a/src/execution_plans/network_coalesce.rs b/src/execution_plans/network_coalesce.rs index f6653f2..07379b4 100644 --- a/src/execution_plans/network_coalesce.rs +++ b/src/execution_plans/network_coalesce.rs @@ -1,7 +1,9 @@ use crate::channel_resolver_ext::get_distributed_channel_resolver; use crate::config_extension_ext::ContextGrpcMetadata; use crate::distributed_planner::{InputStageInfo, NetworkBoundary, limit_tasks_err}; -use crate::execution_plans::common::{require_one_child, scale_partitioning_props}; +use crate::execution_plans::common::{ + require_one_child, scale_partitioning_props, spawn_select_all, +}; use crate::flight_service::DoGet; use crate::metrics::MetricsCollectingStream; use crate::metrics::proto::MetricsSetProto; @@ -18,7 +20,7 @@ use datafusion::error::DataFusionError; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; -use futures::{TryFutureExt, TryStreamExt}; +use futures::{StreamExt, TryFutureExt, TryStreamExt}; use http::Extensions; use prost::Message; use std::any::Any; @@ -319,11 +321,12 @@ impl ExecutionPlan for NetworkCoalesceExec { .map_err(map_flight_to_datafusion_error), ) } - .try_flatten_stream(); + .try_flatten_stream() + .boxed(); Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), - stream, + spawn_select_all(vec![stream], Arc::clone(context.memory_pool())), ))) } } diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs index 1e2f313..89f87d3 100644 --- a/src/execution_plans/network_shuffle.rs +++ b/src/execution_plans/network_shuffle.rs @@ -1,6 +1,6 @@ use crate::channel_resolver_ext::get_distributed_channel_resolver; use crate::config_extension_ext::ContextGrpcMetadata; -use crate::execution_plans::common::{require_one_child, scale_partitioning}; +use crate::execution_plans::common::{require_one_child, scale_partitioning, spawn_select_all}; use crate::flight_service::DoGet; use crate::metrics::MetricsCollectingStream; use crate::metrics::proto::MetricsSetProto; @@ -372,7 +372,7 @@ impl ExecutionPlan for NetworkShuffleExec { Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), - futures::stream::select_all(stream), + spawn_select_all(stream.collect(), Arc::clone(context.memory_pool())), ))) } }