Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/channel_resolver_ext.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use async_trait::async_trait;
use datafusion::common::exec_datafusion_err;
use datafusion::error::DataFusionError;
use datafusion::prelude::SessionConfig;
use std::sync::Arc;
Expand All @@ -16,9 +17,10 @@ pub(crate) fn set_distributed_channel_resolver(

pub(crate) fn get_distributed_channel_resolver(
cfg: &SessionConfig,
) -> Option<Arc<dyn ChannelResolver + Send + Sync>> {
) -> Result<Arc<dyn ChannelResolver + Send + Sync>, DataFusionError> {
cfg.get_extension::<ChannelResolverExtension>()
.map(|cm| cm.0.clone())
.ok_or_else(|| exec_datafusion_err!("ChannelResolver not present in the session config"))
}

#[derive(Clone)]
Expand Down
14 changes: 5 additions & 9 deletions src/execution_plans/arrow_flight_read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,18 +147,14 @@ impl ExecutionPlan for ArrowFlightReadExec {
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream, DataFusionError> {
let ArrowFlightReadExec::Ready(this) = self else {
let ArrowFlightReadExec::Ready(self_ready) = self else {
return exec_err!("ArrowFlightReadExec is not ready, was the distributed optimization step performed?");
};

// get the channel manager and current stage from our context
let Some(channel_resolver) = get_distributed_channel_resolver(context.session_config())
else {
return exec_err!(
"ArrowFlightReadExec requires a ChannelResolver in the session config"
);
};
let channel_resolver = get_distributed_channel_resolver(context.session_config())?;

// the `ArrowFlightReadExec` node can only be executed in the context of a `StageExec`
let stage = context
.session_config()
.get_extension::<StageExec>()
Expand All @@ -170,10 +166,10 @@ impl ExecutionPlan for ArrowFlightReadExec {
// reading from
let child_stage = stage
.child_stages_iter()
.find(|s| s.num == this.stage_num)
.find(|s| s.num == self_ready.stage_num)
.ok_or(internal_datafusion_err!(
"ArrowFlightReadExec: no child stage with num {}",
this.stage_num
self_ready.stage_num
))?;

let flight_metadata = context
Expand Down
18 changes: 4 additions & 14 deletions src/execution_plans/stage.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::channel_resolver_ext::get_distributed_channel_resolver;
use crate::{ArrowFlightReadExec, ChannelResolver, PartitionIsolatorExec};
use datafusion::common::{exec_err, internal_err};
use datafusion::common::internal_err;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::TaskContext;
use datafusion::physical_plan::{
Expand Down Expand Up @@ -260,20 +260,10 @@ impl ExecutionPlan for StageExec {
partition: usize,
context: Arc<TaskContext>,
) -> Result<datafusion::execution::SendableRecordBatchStream> {
let stage = self
.as_any()
.downcast_ref::<StageExec>()
.expect("Unwrapping myself should always work");

let Some(channel_resolver) = get_distributed_channel_resolver(context.session_config())
else {
return exec_err!("ChannelManager not found in session config");
};

let urls = channel_resolver.get_urls()?;
let channel_resolver = get_distributed_channel_resolver(context.session_config())?;

let assigned_stage = stage
.try_assign_urls(&urls)
let assigned_stage = self
.try_assign_urls(&channel_resolver.get_urls()?)
.map(Arc::new)
.map_err(|e| DataFusionError::Execution(e.to_string()))?;

Expand Down
151 changes: 78 additions & 73 deletions src/flight_service/do_get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_server::FlightService;
use arrow_flight::Ticket;
use datafusion::execution::SessionState;
use datafusion::execution::{SendableRecordBatchStream, SessionState};
use futures::TryStreamExt;
use http::HeaderMap;
use prost::Message;
use std::fmt::Display;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::OnceCell;
use tonic::metadata::MetadataMap;
use tonic::{Request, Response, Status};

#[derive(Clone, PartialEq, ::prost::Message)]
Expand All @@ -41,9 +42,9 @@ pub struct DoGet {
/// TaskData stores state for a single task being executed by this Endpoint. It may be shared
/// by concurrent requests for the same task which execute separate partitions.
pub struct TaskData {
pub(super) state: SessionState,
pub(super) session_state: SessionState,
pub(super) stage: Arc<StageExec>,
///num_partitions_remaining is initialized to the total number of partitions in the task (not
/// `num_partitions_remaining` is initialized to the total number of partitions in the task (not
/// only tasks in the partition group). This is decremented for each request to the endpoint
/// for this task. Once this count is zero, the task is likely complete. The task may not be
/// complete because it's possible that the same partition was retried and this count was
Expand All @@ -56,98 +57,78 @@ impl ArrowFlightEndpoint {
&self,
request: Request<Ticket>,
) -> Result<Response<<ArrowFlightEndpoint as FlightService>::DoGetStream>, Status> {
let (metadata, _ext, ticket) = request.into_parts();
let Ticket { ticket } = ticket;
let doget = DoGet::decode(ticket).map_err(|err| {
let (metadata, _ext, body) = request.into_parts();
let doget = DoGet::decode(body.ticket).map_err(|err| {
Status::invalid_argument(format!("Cannot decode DoGet message: {err}"))
})?;

// There's only 1 `StageExec` responsible for all requests that share the same `stage_key`,
// so here we either retrieve the existing one or create a new one if it does not exist.
let (mut session_state, stage) = self
.get_state_and_stage(
doget.stage_key.ok_or_else(missing("stage_key"))?,
doget.stage_proto.ok_or_else(missing("stage_proto"))?,
metadata.clone().into_headers(),
)
.await?;

// Find out which partition group we are executing
let partition = doget.partition as usize;
let task_number = doget.task_number as usize;
let task_data = self.get_state_and_stage(doget, metadata).await?;

let stage = task_data.stage;
let mut state = task_data.state;

// find out which partition group we are executing
let task = stage
.tasks
.get(task_number)
.ok_or(Status::invalid_argument(format!(
"Task number {} not found in stage {}",
task_number,
stage.name()
)))?;

let partition_group = PartitionGroup(task.partition_group.clone());
state.config_mut().set_extension(Arc::new(partition_group));

let inner_plan = stage.plan.clone();

let stream = inner_plan
.execute(partition, state.task_ctx())
let task = stage.tasks.get(task_number).ok_or_else(invalid(format!(
"Task number {task_number} not found in stage {}",
stage.num
)))?;

let cfg = session_state.config_mut();
cfg.set_extension(Arc::new(PartitionGroup(task.partition_group.clone())));
cfg.set_extension(Arc::clone(&stage));
cfg.set_extension(Arc::new(ContextGrpcMetadata(metadata.into_headers())));

// Rather than executing the `StageExec` itself, we want to execute the inner plan instead,
// as executing `StageExec` performs some worker assignation that should have already been
// done in the head stage.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice and useful comment

let stream = stage
.plan
.execute(partition, session_state.task_ctx())
.map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;

let flight_data_stream = FlightDataEncoderBuilder::new()
.with_schema(inner_plan.schema().clone())
.build(stream.map_err(|err| {
FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
}));

Ok(Response::new(Box::pin(flight_data_stream.map_err(
|err| match err {
FlightError::Tonic(status) => *status,
_ => Status::internal(format!("Error during flight stream: {err}")),
},
))))
Ok(record_batch_stream_to_response(stream))
}

async fn get_state_and_stage(
&self,
doget: DoGet,
metadata_map: MetadataMap,
) -> Result<TaskData, Status> {
let key = doget
.stage_key
.ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
let once_stage = self
.stages
key: StageKey,
stage_proto: StageExecProto,
headers: HeaderMap,
) -> Result<(SessionState, Arc<StageExec>), Status> {
let once = self
.task_data_entries
.get_or_init(key.clone(), || Arc::new(OnceCell::<TaskData>::new()));

let stage_data = once_stage
let stage_data = once
.get_or_try_init(|| async {
let stage_proto = doget
.stage_proto
.ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?;

let headers = metadata_map.into_headers();
let mut state = self
let session_state = self
.session_builder
.build_session_state(DistributedSessionBuilderContext {
runtime_env: Arc::clone(&self.runtime),
headers: headers.clone(),
headers,
})
.await
.map_err(|err| datafusion_error_to_tonic_status(&err))?;

let codec = DistributedCodec::new_combined_with_user(state.config());
let codec = DistributedCodec::new_combined_with_user(session_state.config());

let stage = stage_from_proto(stage_proto, &state, self.runtime.as_ref(), &codec)
.map(Arc::new)
let stage = stage_from_proto(stage_proto, &session_state, &self.runtime, &codec)
.map_err(|err| {
Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
})?;

// Add the extensions that might be required for ExecutionPlan nodes in the plan
let config = state.config_mut();
config.set_extension(stage.clone());
config.set_extension(Arc::new(ContextGrpcMetadata(headers)));

// Initialize partition count to the number of partitions in the stage
let total_partitions = stage.plan.properties().partitioning.partition_count();
Ok::<_, Status>(TaskData {
state,
stage,
session_state,
stage: Arc::new(stage),
num_partitions_remaining: Arc::new(AtomicUsize::new(total_partitions)),
})
})
Expand All @@ -158,13 +139,37 @@ impl ArrowFlightEndpoint {
.num_partitions_remaining
.fetch_sub(1, Ordering::SeqCst);
if remaining_partitions <= 1 {
self.stages.remove(key.clone());
self.task_data_entries.remove(key);
}

Ok(stage_data.clone())
Ok((stage_data.session_state.clone(), stage_data.stage.clone()))
}
}

fn missing(field: &'static str) -> impl FnOnce() -> Status {
move || Status::invalid_argument(format!("Missing field '{field}'"))
}

fn invalid(msg: impl Display) -> impl FnOnce() -> Status {
move || Status::invalid_argument(msg.to_string())
}

fn record_batch_stream_to_response(
stream: SendableRecordBatchStream,
) -> Response<<ArrowFlightEndpoint as FlightService>::DoGetStream> {
let flight_data_stream =
FlightDataEncoderBuilder::new()
.with_schema(stream.schema().clone())
.build(stream.map_err(|err| {
FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
}));

Response::new(Box::pin(flight_data_stream.map_err(|err| match err {
FlightError::Tonic(status) => *status,
_ => Status::internal(format!("Error during flight stream: {err}")),
})))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -262,28 +267,28 @@ mod tests {
}

// Check that the endpoint has not evicted any task states.
assert_eq!(endpoint.stages.len(), num_tasks);
assert_eq!(endpoint.task_data_entries.len(), num_tasks);

// Run the last partition of task 0. Any partition number works. Verify that the task state
// is evicted because all partitions have been processed.
let result = do_get(1, 0, task_keys[0].clone()).await;
assert!(result.is_ok());
let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
assert_eq!(stored_stage_keys.len(), 2);
assert!(stored_stage_keys.contains(&task_keys[1]));
assert!(stored_stage_keys.contains(&task_keys[2]));

// Run the last partition of task 1.
let result = do_get(1, 1, task_keys[1].clone()).await;
assert!(result.is_ok());
let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
assert_eq!(stored_stage_keys.len(), 1);
assert!(stored_stage_keys.contains(&task_keys[2]));

// Run the last partition of the last task.
let result = do_get(1, 2, task_keys[2].clone()).await;
assert!(result.is_ok());
let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
assert_eq!(stored_stage_keys.len(), 0);
}

Expand Down
5 changes: 2 additions & 3 deletions src/flight_service/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ pub struct StageKey {

pub struct ArrowFlightEndpoint {
pub(super) runtime: Arc<RuntimeEnv>,
#[allow(clippy::type_complexity)]
pub(super) stages: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
pub(super) task_data_entries: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
}

Expand All @@ -42,7 +41,7 @@ impl ArrowFlightEndpoint {
let ttl_map = TTLMap::try_new(TTLMapConfig::default())?;
Ok(Self {
runtime: Arc::new(RuntimeEnv::default()),
stages: ttl_map,
task_data_entries: ttl_map,
session_builder: Arc::new(session_builder),
})
}
Expand Down