27 changes: 17 additions & 10 deletions src/planning.rs
@@ -3,6 +3,7 @@ use std::{
env,
sync::{Arc, LazyLock},
};
use tokio::sync::OnceCell;

use anyhow::{anyhow, Context};
use arrow_flight::Action;
@@ -89,17 +90,23 @@ impl DDStage {
}
}

static STATE: LazyLock<Result<SessionState>> = LazyLock::new(|| {
let wait_result = wait_for(make_state(), "make_state");
match wait_result {
Ok(Ok(state)) => Ok(state),
Ok(Err(e)) => Err(anyhow!("Failed to initialize state: {}", e).into()),
Err(e) => Err(anyhow!("Failed to initialize state: {}", e).into()),
}
});
// STATE holds the global SessionState, which carries table information and config options.
//
// Note that the OnceCell is thread safe (it is Sync + Send): https://docs.rs/tokio/latest/tokio/sync/struct.OnceCell.html
static STATE: OnceCell<Result<SessionState>> = OnceCell::const_new();

// get_ctx returns a SessionContext built from the global SessionState.
pub async fn get_ctx() -> Result<SessionContext> {
let result = STATE
.get_or_init(|| async {
match make_state().await {
Ok(state) => Ok(state),
Err(e) => Err(anyhow!("Failed to initialize state: {}", e).into()),
}
})
.await;

pub fn get_ctx() -> Result<SessionContext> {
match &*STATE {
match result {
Ok(state) => Ok(SessionContext::new_with_state(state.clone())),
Err(e) => Err(anyhow!("Context initialization failed: {}", e).into()),
}
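The move from `LazyLock` + `wait_for` to `tokio::sync::OnceCell` means initialization now runs on the caller's own runtime rather than being bridged through a blocking channel. A minimal standalone sketch of the pattern — `AppState`, `expensive_init`, and `get_state` are illustrative stand-ins for the crate's `SessionState`, `make_state`, and `get_ctx`, not its actual API:

```rust
use tokio::sync::OnceCell;

// Stand-in for the crate's SessionState.
#[derive(Clone, Debug)]
struct AppState(u64);

// The cell caches the whole Result, so a failed init is cached too and
// never retried -- the same behavior as the diff's STATE.
static STATE: OnceCell<Result<AppState, String>> = OnceCell::const_new();

// Stand-in for make_state(): an async, fallible constructor.
async fn expensive_init() -> Result<AppState, String> {
    Ok(AppState(42))
}

async fn get_state() -> Result<AppState, String> {
    // get_or_init runs the closure at most once; every other caller
    // awaits the same initialization and then reads the cached value.
    STATE
        .get_or_init(|| async { expensive_init().await })
        .await
        .clone()
}

#[tokio::main]
async fn main() {
    assert_eq!(get_state().await.unwrap().0, 42);
    assert_eq!(get_state().await.unwrap().0, 42); // second call hits the cache
}
```

One consequence worth noting: because an `Err` is cached just like an `Ok`, a transient startup failure becomes permanent until the process restarts.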
5 changes: 3 additions & 2 deletions src/proxy_service.rs
@@ -121,8 +121,9 @@ impl DDProxyHandler {
stage_id: u64,
addrs: Addrs,
) -> Result<Response<crate::flight::DoGetStream>, Status> {
let mut ctx =
get_ctx().map_err(|e| Status::internal(format!("Could not create context {e:?}")))?;
let mut ctx = get_ctx()
.await
.map_err(|e| Status::internal(format!("Could not create context {e:?}")))?;

add_ctx_extentions(&mut ctx, &self.host, &query_id, stage_id, addrs, vec![])
.map_err(|e| Status::internal(format!("Could not add context extensions {e:?}")))?;
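The proxy handler now awaits `get_ctx` and folds any failure into a gRPC status. A hedged sketch of that await-then-map shape in isolation — the async `get_ctx` here is a toy stand-in, not the crate's:

```rust
use tonic::Status;

// Toy stand-in for the crate's fallible async context constructor.
async fn get_ctx() -> Result<u64, String> {
    Ok(7)
}

async fn handler() -> Result<u64, Status> {
    // Await the shared context; an init failure becomes a structured
    // gRPC internal error instead of an opaque transport failure.
    let ctx = get_ctx()
        .await
        .map_err(|e| Status::internal(format!("Could not create context {e:?}")))?;
    Ok(ctx)
}

#[tokio::main]
async fn main() {
    assert_eq!(handler().await.unwrap(), 7);
}
```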
8 changes: 6 additions & 2 deletions src/query_planner.rs
@@ -91,7 +91,9 @@ impl QueryPlanner {
/// distributed plan, and distributed stages, but it does not yet contain
/// worker addresses or tasks, as they are filled in later by `distribute_plan()`.
pub async fn prepare(&self, sql: &str) -> Result<QueryPlan> {
let mut ctx = get_ctx().map_err(|e| anyhow!("Could not create context: {e}"))?;
let mut ctx = get_ctx()
.await
.map_err(|e| anyhow!("Could not create context: {e}"))?;
if let Some(customizer) = &self.customizer {
customizer
.customize(&mut ctx)
@@ -119,7 +121,9 @@ impl QueryPlanner {
/// distributed plan, and distributed stages, but it does not yet contain
/// worker addresses or tasks, as they are filled in later by `distribute_plan()`.
pub async fn prepare_substrait(&self, substrait_plan: Plan) -> Result<QueryPlan> {
let mut ctx = get_ctx().map_err(|e| anyhow!("Could not create context: {e}"))?;
let mut ctx = get_ctx()
.await
.map_err(|e| anyhow!("Could not create context: {e}"))?;
if let Some(customizer) = &self.customizer {
customizer
.customize(&mut ctx)
97 changes: 62 additions & 35 deletions src/util.rs
@@ -9,7 +9,9 @@ use std::{
time::Duration,
};

use anyhow::{anyhow, Context as anyhowctx};
use tokio::time::timeout;

use anyhow::{anyhow, Context as anyhowctx, Error};
use arrow::{
array::RecordBatch,
datatypes::SchemaRef,
@@ -53,6 +55,7 @@ use tokio::{
use tonic::transport::Channel;
use url::Url;

use crate::result::DDError;
use crate::{
logging::{debug, error, trace},
protobuf::StageAddrs,
@@ -76,59 +79,75 @@ impl Spawner {
Self { runtime }
}

// wait_for_future waits for the future f to complete. If it does not complete, this will
// block forever.
fn wait_for_future<F>(&self, f: F, name: &str) -> Result<F::Output>
where
F: Future + Send + 'static,
F::Output: Send,
{
// sanity check that we are not in an async runtime. We don't want the code below to
// block an executor accidentally.
if Handle::try_current().is_ok() {
panic!("cannot call wait_for_future within an async runtime")
}

let name_c = name.to_owned();
trace!("Spawner::wait_for {name_c}");
let (tx, rx) = std::sync::mpsc::channel::<F::Output>();
let (tx, mut rx) = tokio::sync::mpsc::channel::<F::Output>(1);

let func = move || {
let func = async move || {
trace!("spawned fut start {name_c}");

let out = Handle::current().block_on(f);
let out = f.await;
trace!("spawned fut stop {name_c}");
tx.send(out).inspect_err(|e| {
error!("ERROR sending future reesult over channel!!!! {e:?}");
})
let result = tx.send(out).await;

// This should never happen. An error occurs if the channel was closed or the receiver
// was dropped. Neither happens before this line in this function.
if let Err(e) = result {
error!("ERROR sending future {name_c} result over channel! {e:?}");
}
// tx is dropped, channel is closed.
};

// Spawn the task in the runtime.
{
let _guard = self.runtime.enter();
let handle = Handle::current();

trace!("Spawner spawning {name}");
handle.spawn_blocking(func);
trace!("Spawner spawned {name}");
trace!("Spawner spawning {name} (sync)");
tokio::spawn(func());
trace!("Spawner spawned {name} (sync)");
}

let out = rx
.recv_timeout(Duration::from_secs(5))
.inspect_err(|e| {
error!("Spawner::wait_for {name} timed out waiting for future result: {e:?}");
})
.context("Spawner::wait_for failed to receive future result")?;

debug!("Spawner::wait_for {name} returning");
Ok(out)
match rx.blocking_recv() {
// Channel was closed without any messages.
None => {
error!("Spawner::wait_for {name} timed out waiting for future result");
Err(DDError::Other(anyhow!("future {} did not complete", name)))
}
Some(result) => Ok(result),
}
}
}

// SPAWNER owns the runtime used to drive futures on behalf of synchronous callers.
static SPAWNER: OnceLock<Spawner> = OnceLock::new();

// wait_for blocks on the future and returns when the future is complete. It will return an error
// if called in an async runtime, since the async runtime should simply await f.
pub fn wait_for<F>(f: F, name: &str) -> Result<F::Output>
where
F: Future + Send + 'static,
F::Output: Send,
{
let spawner = SPAWNER.get_or_init(Spawner::new);
if Handle::try_current().is_ok() {
return Err(DDError::Other(anyhow!(
"cannot call wait_for in async runtime. consider awaitiing the future {} instead",
name
)));
}

trace!("waiting for future: {name}");
let spawner = SPAWNER.get_or_init(Spawner::new);
let name = name.to_owned();
let out = spawner.wait_for_future(f, &name);
trace!("done waiting for future: {name}");
out
}

@@ -653,19 +672,27 @@ mod test {
}

#[test]
fn test_wait_for_nested() {
println!("test_wait_for_nested");
fn test_wait_for_nested_error() {
let fut = async || {
println!("in outter fut");
let fut5 = async || {
println!("in inner fut");
5
};
let fut5 = async || 5;
wait_for(fut5(), "inner").unwrap()
};

let out = wait_for(fut(), "outer").unwrap();
assert_eq!(out, 5);
// The inner wait_for runs inside the spawner's runtime, so it errors and its unwrap panics; the outer wait_for then returns an error.
let out = wait_for(fut(), "outer");
assert!(out.is_err());
}

#[tokio::test]
async fn test_wait_for_errors_in_async_runtime() {
let fut5 = async || {
println!("in inner fut");
5
};

// Return an error because wait_for is called directly inside the tokio::test async runtime.
let out = wait_for(fut5(), "fut5");
assert!(out.is_err());
}

#[test(tokio::test)]
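`wait_for_future` now parks the calling thread with `blocking_recv` while the future runs as an ordinary task on the spawner's runtime (the old version funneled it through `spawn_blocking` + `block_on` with a 5-second `recv_timeout`). A condensed sketch of the same sync-to-async bridge, assuming a dedicated background runtime; `block_on_future` and `RT` are illustrative names, not the crate's API:

```rust
use std::future::Future;
use std::sync::OnceLock;
use tokio::runtime::{Handle, Runtime};

static RT: OnceLock<Runtime> = OnceLock::new();

// Drive `fut` on the background runtime and block the calling
// (non-async) thread until it resolves.
fn block_on_future<F>(fut: F) -> Result<F::Output, String>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    // Refuse to block inside a runtime; async callers should just .await.
    if Handle::try_current().is_ok() {
        return Err("inside an async runtime; await the future instead".into());
    }

    let rt = RT.get_or_init(|| Runtime::new().expect("failed to build runtime"));
    let (tx, mut rx) = tokio::sync::mpsc::channel(1);

    rt.spawn(async move {
        // Ignoring the send error is safe here: the receiver is only
        // dropped after blocking_recv returns.
        let _ = tx.send(fut.await).await;
    });

    // None means the task was dropped (e.g. it panicked) before sending.
    rx.blocking_recv()
        .ok_or_else(|| "future did not complete".into())
}

fn main() {
    assert_eq!(block_on_future(async { 2 + 3 }).unwrap(), 5);
}
```

Unlike `recv_timeout`, `blocking_recv` waits indefinitely, which is why the updated comment on `wait_for_future` warns that a future that never completes blocks forever.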
2 changes: 1 addition & 1 deletion src/worker_service.rs
@@ -261,7 +261,7 @@ impl DDWorkerHandler {
stage_addrs: Addrs,
partition_group: Vec<u64>,
) -> Result<SessionContext> {
let mut ctx = get_ctx()?;
let mut ctx = get_ctx().await?;
let host = Host {
addr: self.addr.clone(),
name: self.name.clone(),