Skip to content

Commit 269516f

Browse files
util: ensure wait_for can only be used in a non-async runtime
Previously, `wait_for` would deadlock when used in an async runtime (i.e., when any async function calls `wait_for` directly or indirectly). Now, `wait_for` returns an error if called in an async runtime. There's no reason to call it in an async runtime / function because you can simply await the future. This change updates a few places where `wait_for` is called to be async functions and just await the future. Why did the old `Spawner` deadlock? `Handle::current().block_on(f)` and the `std::sync::mpsc::channel` would both block an executor thread in tokio (this is a very bad practice generally). If there are no executor threads available, you can't run anything. Now, the closure just does `f.await` instead of `block_on`, which does not block an executor. The channel is now a `tokio::sync::mpsc::channel` as well with a buffer size of 1, so the async send will not block (I think a buffered std channel would have worked, but it's better to use tokio implementations generally). The receiver of the channel is in a sync runtime, so it does a blocking receive. Testing - unit tests now pass - added a unit test for `wait_for` returning an error in an async runtime - test for nested `wait_for` calls which now return an error Closes https://github.com/datafusion-contrib/datafusion-distributed/issues/63
1 parent 18ed7d3 commit 269516f

File tree

5 files changed

+89
-50
lines changed

5 files changed

+89
-50
lines changed

src/planning.rs

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::{
33
env,
44
sync::{Arc, LazyLock},
55
};
6+
use tokio::sync::OnceCell;
67

78
use anyhow::{anyhow, Context};
89
use arrow_flight::Action;
@@ -89,17 +90,23 @@ impl DDStage {
8990
}
9091
}
9192

92-
static STATE: LazyLock<Result<SessionState>> = LazyLock::new(|| {
93-
let wait_result = wait_for(make_state(), "make_state");
94-
match wait_result {
95-
Ok(Ok(state)) => Ok(state),
96-
Ok(Err(e)) => Err(anyhow!("Failed to initialize state: {}", e).into()),
97-
Err(e) => Err(anyhow!("Failed to initialize state: {}", e).into()),
98-
}
99-
});
93+
// STATE contains the global SessionState which contains table information and config options.
94+
//
95+
// Note that the OnceCell is thread-safe (it is Sync + Send): https://docs.rs/tokio/latest/tokio/sync/struct.OnceCell.html
96+
static STATE: OnceCell<Result<SessionState>> = OnceCell::const_new();
97+
98+
// get_ctx returns the global SessionContext.
99+
pub async fn get_ctx() -> Result<SessionContext> {
100+
let result = STATE
101+
.get_or_init(|| async {
102+
match make_state().await {
103+
Ok(state) => Ok(state),
104+
Err(e) => Err(anyhow!("Failed to initialize state: {}", e).into()),
105+
}
106+
})
107+
.await;
100108

101-
pub fn get_ctx() -> Result<SessionContext> {
102-
match &*STATE {
109+
match result {
103110
Ok(state) => Ok(SessionContext::new_with_state(state.clone())),
104111
Err(e) => Err(anyhow!("Context initialization failed: {}", e).into()),
105112
}

src/proxy_service.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,9 @@ impl DDProxyHandler {
121121
stage_id: u64,
122122
addrs: Addrs,
123123
) -> Result<Response<crate::flight::DoGetStream>, Status> {
124-
let mut ctx =
125-
get_ctx().map_err(|e| Status::internal(format!("Could not create context {e:?}")))?;
124+
let mut ctx = get_ctx()
125+
.await
126+
.map_err(|e| Status::internal(format!("Could not create context {e:?}")))?;
126127

127128
add_ctx_extentions(&mut ctx, &self.host, &query_id, stage_id, addrs, vec![])
128129
.map_err(|e| Status::internal(format!("Could not add context extensions {e:?}")))?;

src/query_planner.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@ impl QueryPlanner {
9191
/// distributed plan, and distributed stages, but it does not yet contain
9292
/// worker addresses or tasks, as they are filled in later by `distribute_plan()`.
9393
pub async fn prepare(&self, sql: &str) -> Result<QueryPlan> {
94-
let mut ctx = get_ctx().map_err(|e| anyhow!("Could not create context: {e}"))?;
94+
let mut ctx = get_ctx()
95+
.await
96+
.map_err(|e| anyhow!("Could not create context: {e}"))?;
9597
if let Some(customizer) = &self.customizer {
9698
customizer
9799
.customize(&mut ctx)
@@ -119,7 +121,9 @@ impl QueryPlanner {
119121
/// distributed plan, and distributed stages, but it does not yet contain
120122
/// worker addresses or tasks, as they are filled in later by `distribute_plan()`.
121123
pub async fn prepare_substrait(&self, substrait_plan: Plan) -> Result<QueryPlan> {
122-
let mut ctx = get_ctx().map_err(|e| anyhow!("Could not create context: {e}"))?;
124+
let mut ctx = get_ctx()
125+
.await
126+
.map_err(|e| anyhow!("Could not create context: {e}"))?;
123127
if let Some(customizer) = &self.customizer {
124128
customizer
125129
.customize(&mut ctx)

src/util.rs

Lines changed: 62 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ use std::{
99
time::Duration,
1010
};
1111

12-
use anyhow::{anyhow, Context as anyhowctx};
12+
use tokio::time::timeout;
13+
14+
use anyhow::{anyhow, Context as anyhowctx, Error};
1315
use arrow::{
1416
array::RecordBatch,
1517
datatypes::SchemaRef,
@@ -53,6 +55,7 @@ use tokio::{
5355
use tonic::transport::Channel;
5456
use url::Url;
5557

58+
use crate::result::DDError;
5659
use crate::{
5760
logging::{debug, error, trace},
5861
protobuf::StageAddrs,
@@ -76,59 +79,75 @@ impl Spawner {
7679
Self { runtime }
7780
}
7881

82+
// wait_for_future waits for the future f to complete. If it does not complete, this will
83+
// block forever.
7984
fn wait_for_future<F>(&self, f: F, name: &str) -> Result<F::Output>
8085
where
8186
F: Future + Send + 'static,
8287
F::Output: Send,
8388
{
89+
// sanity check that we are not in an async runtime. We don't want the code below to
90+
// block an executor accidentally.
91+
if Handle::try_current().is_ok() {
92+
panic!("cannot call wait_for_future within an async runtime")
93+
}
94+
8495
let name_c = name.to_owned();
85-
trace!("Spawner::wait_for {name_c}");
86-
let (tx, rx) = std::sync::mpsc::channel::<F::Output>();
96+
let (tx, mut rx) = tokio::sync::mpsc::channel::<F::Output>(1);
8797

88-
let func = move || {
98+
let func = async move || {
8999
trace!("spawned fut start {name_c}");
90-
91-
let out = Handle::current().block_on(f);
100+
let out = f.await;
92101
trace!("spawned fut stop {name_c}");
93-
tx.send(out).inspect_err(|e| {
94-
error!("ERROR sending future reesult over channel!!!! {e:?}");
95-
})
102+
let result = tx.send(out).await;
103+
104+
// This should never happen. An error occurs if the channel was closed or the receiver
105+
// was dropped. Neither happens before this line in this function.
106+
if let Err(e) = result {
107+
error!("ERROR sending future {name_c} result over channel! {e:?}");
108+
}
109+
// tx is dropped, channel is closed.
96110
};
97111

112+
// Spawn the task in the runtime.
98113
{
99114
let _guard = self.runtime.enter();
100-
let handle = Handle::current();
101-
102-
trace!("Spawner spawning {name}");
103-
handle.spawn_blocking(func);
104-
trace!("Spawner spawned {name}");
115+
trace!("Spawner spawning {name} (sync)");
116+
tokio::spawn(func());
117+
trace!("Spawner spawned {name} (sync)");
105118
}
106119

107-
let out = rx
108-
.recv_timeout(Duration::from_secs(5))
109-
.inspect_err(|e| {
110-
error!("Spawner::wait_for {name} timed out waiting for future result: {e:?}");
111-
})
112-
.context("Spawner::wait_for failed to receive future result")?;
113-
114-
debug!("Spawner::wait_for {name} returning");
115-
Ok(out)
120+
match rx.blocking_recv() {
121+
// Channel was closed without any messages.
122+
None => {
123+
error!("Spawner::wait_for {name} timed out waiting for future result");
124+
Err(DDError::Other(anyhow!("future {} did not complete", name)))
125+
}
126+
Some(result) => Ok(result),
127+
}
116128
}
117129
}
118130

131+
// SPAWNER is used to run futures in a synchronous runtime.
119132
static SPAWNER: OnceLock<Spawner> = OnceLock::new();
120133

134+
// wait_for blocks on the future and returns when the future is complete. It will return an error
135+
// if called in an async runtime, since the async runtime should simply await f.
121136
pub fn wait_for<F>(f: F, name: &str) -> Result<F::Output>
122137
where
123138
F: Future + Send + 'static,
124139
F::Output: Send,
125140
{
126-
let spawner = SPAWNER.get_or_init(Spawner::new);
141+
if Handle::try_current().is_ok() {
142+
return Err(DDError::Other(anyhow!(
143+
"cannot call wait_for in async runtime. consider awaitiing the future {} instead",
144+
name
145+
)));
146+
}
127147

128-
trace!("waiting for future: {name}");
148+
let spawner = SPAWNER.get_or_init(Spawner::new);
129149
let name = name.to_owned();
130150
let out = spawner.wait_for_future(f, &name);
131-
trace!("done waiting for future: {name}");
132151
out
133152
}
134153

@@ -653,19 +672,27 @@ mod test {
653672
}
654673

655674
#[test]
656-
fn test_wait_for_nested() {
657-
println!("test_wait_for_nested");
675+
fn test_wait_for_nested_error() {
658676
let fut = async || {
659-
println!("in outter fut");
660-
let fut5 = async || {
661-
println!("in inner fut");
662-
5
663-
};
677+
let fut5 = async || 5;
664678
wait_for(fut5(), "inner").unwrap()
665679
};
666680

667-
let out = wait_for(fut(), "outer").unwrap();
668-
assert_eq!(out, 5);
681+
// Return an error because the nested wait_for is called in an async runtime.
682+
let out = wait_for(fut(), "outer");
683+
assert!(out.is_err());
684+
}
685+
686+
#[tokio::test]
687+
async fn test_wait_for_errors_in_async_runtime() {
688+
let fut5 = async || {
689+
println!("in inner fut");
690+
5
691+
};
692+
693+
// Return an error because wait_for is called directly inside an async runtime.
694+
let out = wait_for(fut5(), "fut5");
695+
assert!(out.is_err());
669696
}
670697

671698
#[test(tokio::test)]

src/worker_service.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ impl DDWorkerHandler {
261261
stage_addrs: Addrs,
262262
partition_group: Vec<u64>,
263263
) -> Result<SessionContext> {
264-
let mut ctx = get_ctx()?;
264+
let mut ctx = get_ctx().await?;
265265
let host = Host {
266266
addr: self.addr.clone(),
267267
name: self.name.clone(),

0 commit comments

Comments
 (0)