Skip to content

Commit 94c5d77

Browse files
authored
Fix network boundary deadlocks (#240)
* Fix network shuffle deadlock * Use an unbounded queue * Integrate with memory accounting * Also apply spawn_select_all to network coalesce * Better name for spawn_select_all memory consumer * Add test
1 parent 3262025 commit 94c5d77

File tree

3 files changed

+115
-6
lines changed

3 files changed

+115
-6
lines changed

src/execution_plans/common.rs

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
use arrow::array::RecordBatch;
2+
use datafusion::common::runtime::SpawnedTask;
13
use datafusion::common::{DataFusionError, plan_err};
4+
use datafusion::execution::memory_pool::{MemoryConsumer, MemoryPool};
25
use datafusion::physical_expr::Partitioning;
36
use datafusion::physical_plan::{ExecutionPlan, PlanProperties};
7+
use futures::{Stream, StreamExt};
48
use std::borrow::Borrow;
59
use std::sync::Arc;
10+
use tokio_stream::wrappers::UnboundedReceiverStream;
611

712
pub(super) fn require_one_child<L, T>(
813
children: L,
@@ -40,3 +45,104 @@ pub(super) fn scale_partitioning(
4045
Partitioning::UnknownPartitioning(p) => Partitioning::UnknownPartitioning(f(*p)),
4146
}
4247
}
48+
49+
/// Consumes all the provided streams in parallel sending their produced messages to a single
50+
/// queue in random order. The resulting queue is returned as a stream.
51+
// FIXME: It should not be necessary to do this, it should be fine to just consume
52+
// all the messages with a normal tokio::stream::select_all, however, that has the chance
53+
// of deadlocking the stream on the server side (https://github.com/datafusion-contrib/datafusion-distributed/issues/228).
54+
// Even having these channels bounded would result in deadlocks (learned it the hard way).
55+
// Until we figure out what's wrong there, this is a good enough solution.
56+
pub(super) fn spawn_select_all<T, El, Err>(
57+
inner: Vec<T>,
58+
pool: Arc<dyn MemoryPool>,
59+
) -> impl Stream<Item = Result<El, Err>>
60+
where
61+
T: Stream<Item = Result<El, Err>> + Send + Unpin + 'static,
62+
El: MemoryFootPrint + Send + 'static,
63+
Err: Send + 'static,
64+
{
65+
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
66+
67+
let mut tasks = vec![];
68+
for mut t in inner {
69+
let tx = tx.clone();
70+
let pool = Arc::clone(&pool);
71+
let consumer = MemoryConsumer::new("NetworkBoundary");
72+
73+
tasks.push(SpawnedTask::spawn(async move {
74+
while let Some(msg) = t.next().await {
75+
let mut reservation = consumer.clone_with_new_id().register(&pool);
76+
if let Ok(msg) = &msg {
77+
reservation.grow(msg.get_memory_size());
78+
}
79+
80+
if tx.send((msg, reservation)).is_err() {
81+
return;
82+
};
83+
}
84+
}))
85+
}
86+
87+
UnboundedReceiverStream::new(rx).map(move |(msg, _reservation)| {
88+
// keep the tasks alive as long as the stream lives
89+
let _ = &tasks;
90+
msg
91+
})
92+
}
93+
94+
pub(super) trait MemoryFootPrint {
95+
fn get_memory_size(&self) -> usize;
96+
}
97+
98+
impl MemoryFootPrint for RecordBatch {
99+
fn get_memory_size(&self) -> usize {
100+
self.get_array_memory_size()
101+
}
102+
}
103+
104+
#[cfg(test)]
mod tests {
    use crate::execution_plans::common::{MemoryFootPrint, spawn_select_all};
    use datafusion::execution::memory_pool::{MemoryPool, UnboundedMemoryPool};
    use std::error::Error;
    use std::sync::Arc;
    use tokio_stream::StreamExt;

    /// Test-only footprint: an integer item reports its own value as its size, which makes
    /// the expected pool totals below easy to compute by hand.
    impl MemoryFootPrint for usize {
        fn get_memory_size(&self) -> usize {
            *self
        }
    }

    /// Checks that queued items are accounted for in the pool, that consuming an item
    /// releases its share, and that dropping the stream releases everything left.
    #[tokio::test]
    async fn memory_reservation() -> Result<(), Box<dyn Error>> {
        let pool: Arc<dyn MemoryPool> = Arc::new(UnboundedMemoryPool::default());

        let first = futures::stream::iter(vec![Ok::<_, String>(1), Ok(2), Ok(3)]);
        let second = futures::stream::iter(vec![Ok(4), Ok(5)]);
        let mut merged = spawn_select_all(vec![first, second], Arc::clone(&pool));

        // Give the spawned tasks a moment to drain both inputs into the queue.
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        // 1 + 2 + 3 + 4 + 5: everything is queued and nothing has been consumed yet.
        assert_eq!(pool.reserved(), 15);

        for expected in [1, 2, 3] {
            let got = merged.next().await.unwrap()?;
            assert_eq!(expected, got)
        }

        // Only 4 and 5 are still sitting in the queue.
        assert_eq!(pool.reserved(), 9);

        drop(merged);

        // Dropping the stream releases whatever was still queued.
        assert_eq!(pool.reserved(), 0);

        Ok(())
    }
}

src/execution_plans/network_coalesce.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use crate::channel_resolver_ext::get_distributed_channel_resolver;
22
use crate::config_extension_ext::ContextGrpcMetadata;
33
use crate::distributed_planner::{InputStageInfo, NetworkBoundary, limit_tasks_err};
4-
use crate::execution_plans::common::{require_one_child, scale_partitioning_props};
4+
use crate::execution_plans::common::{
5+
require_one_child, scale_partitioning_props, spawn_select_all,
6+
};
57
use crate::flight_service::DoGet;
68
use crate::metrics::MetricsCollectingStream;
79
use crate::metrics::proto::MetricsSetProto;
@@ -18,7 +20,7 @@ use datafusion::error::DataFusionError;
1820
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
1921
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
2022
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
21-
use futures::{TryFutureExt, TryStreamExt};
23+
use futures::{StreamExt, TryFutureExt, TryStreamExt};
2224
use http::Extensions;
2325
use prost::Message;
2426
use std::any::Any;
@@ -319,11 +321,12 @@ impl ExecutionPlan for NetworkCoalesceExec {
319321
.map_err(map_flight_to_datafusion_error),
320322
)
321323
}
322-
.try_flatten_stream();
324+
.try_flatten_stream()
325+
.boxed();
323326

324327
Ok(Box::pin(RecordBatchStreamAdapter::new(
325328
self.schema(),
326-
stream,
329+
spawn_select_all(vec![stream], Arc::clone(context.memory_pool())),
327330
)))
328331
}
329332
}

src/execution_plans/network_shuffle.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::channel_resolver_ext::get_distributed_channel_resolver;
22
use crate::config_extension_ext::ContextGrpcMetadata;
3-
use crate::execution_plans::common::{require_one_child, scale_partitioning};
3+
use crate::execution_plans::common::{require_one_child, scale_partitioning, spawn_select_all};
44
use crate::flight_service::DoGet;
55
use crate::metrics::MetricsCollectingStream;
66
use crate::metrics::proto::MetricsSetProto;
@@ -384,7 +384,7 @@ impl ExecutionPlan for NetworkShuffleExec {
384384

385385
Ok(Box::pin(RecordBatchStreamAdapter::new(
386386
self.schema(),
387-
futures::stream::select_all(stream),
387+
spawn_select_all(stream.collect(), Arc::clone(context.memory_pool())),
388388
)))
389389
}
390390
}

0 commit comments

Comments
 (0)