diff --git a/src/channel_resolver_ext.rs b/src/channel_resolver_ext.rs
index 1a6df43..33cf739 100644
--- a/src/channel_resolver_ext.rs
+++ b/src/channel_resolver_ext.rs
@@ -1,4 +1,5 @@
 use async_trait::async_trait;
+use datafusion::common::exec_datafusion_err;
 use datafusion::error::DataFusionError;
 use datafusion::prelude::SessionConfig;
 use std::sync::Arc;
@@ -16,9 +17,10 @@ pub(crate) fn set_distributed_channel_resolver(
 
 pub(crate) fn get_distributed_channel_resolver(
     cfg: &SessionConfig,
-) -> Option<Arc<dyn ChannelResolver + Send + Sync>> {
+) -> Result<Arc<dyn ChannelResolver + Send + Sync>, DataFusionError> {
     cfg.get_extension::<ChannelResolverExtension>()
         .map(|cm| cm.0.clone())
+        .ok_or_else(|| exec_datafusion_err!("ChannelResolver not present in the session config"))
 }
 
 #[derive(Clone)]
diff --git a/src/execution_plans/arrow_flight_read.rs b/src/execution_plans/arrow_flight_read.rs
index 234cbc8..bee5819 100644
--- a/src/execution_plans/arrow_flight_read.rs
+++ b/src/execution_plans/arrow_flight_read.rs
@@ -147,18 +147,14 @@ impl ExecutionPlan for ArrowFlightReadExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        let ArrowFlightReadExec::Ready(this) = self else {
+        let ArrowFlightReadExec::Ready(self_ready) = self else {
             return exec_err!("ArrowFlightReadExec is not ready, was the distributed optimization step performed?");
         };
 
         // get the channel manager and current stage from our context
-        let Some(channel_resolver) = get_distributed_channel_resolver(context.session_config())
-        else {
-            return exec_err!(
-                "ArrowFlightReadExec requires a ChannelResolver in the session config"
-            );
-        };
+        let channel_resolver = get_distributed_channel_resolver(context.session_config())?;
 
+        // the `ArrowFlightReadExec` node can only be executed in the context of a `StageExec`
         let stage = context
             .session_config()
             .get_extension::<StageExec>()
@@ -170,10 +166,10 @@ impl ExecutionPlan for ArrowFlightReadExec {
         // reading from
         let child_stage = stage
             .child_stages_iter()
-            .find(|s| s.num == this.stage_num)
+            .find(|s| s.num == self_ready.stage_num)
             .ok_or(internal_datafusion_err!(
                 "ArrowFlightReadExec: no child stage with num {}",
-                this.stage_num
+                self_ready.stage_num
             ))?;
 
         let flight_metadata = context
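For context on the accessor change above: DataFusion's `SessionConfig` keeps extensions in a type-keyed map, and `get_extension::<T>()` returns `Option<Arc<T>>`. Converting that `Option` into a `Result` once, inside the accessor, gives every call site a single consistent error message and lets them use `?` instead of a `let Some(..) = .. else { .. }` block. A minimal self-contained sketch of the same pattern, using a hypothetical `MyExt` type in place of `ChannelResolverExtension`:

```rust
use std::sync::Arc;

use datafusion::common::exec_datafusion_err;
use datafusion::error::DataFusionError;
use datafusion::prelude::SessionConfig;

// Hypothetical extension type, standing in for `ChannelResolverExtension`.
struct MyExt(String);

// The `Option` is converted to a `Result` once, here, so call sites don't
// each have to invent their own error message.
fn get_my_ext(cfg: &SessionConfig) -> Result<Arc<MyExt>, DataFusionError> {
    cfg.get_extension::<MyExt>()
        .ok_or_else(|| exec_datafusion_err!("MyExt not present in the session config"))
}

fn main() -> Result<(), DataFusionError> {
    let mut cfg = SessionConfig::new();
    cfg.set_extension(Arc::new(MyExt("hello".into())));
    // Previously: let Some(ext) = get_my_ext(&cfg) else { return exec_err!(...) };
    let ext = get_my_ext(&cfg)?;
    println!("{}", ext.0);
    Ok(())
}
```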
diff --git a/src/execution_plans/stage.rs b/src/execution_plans/stage.rs
index ff42349..fc1e45a 100644
--- a/src/execution_plans/stage.rs
+++ b/src/execution_plans/stage.rs
@@ -1,6 +1,6 @@
 use crate::channel_resolver_ext::get_distributed_channel_resolver;
 use crate::{ArrowFlightReadExec, ChannelResolver, PartitionIsolatorExec};
-use datafusion::common::{exec_err, internal_err};
+use datafusion::common::internal_err;
 use datafusion::error::{DataFusionError, Result};
 use datafusion::execution::TaskContext;
 use datafusion::physical_plan::{
@@ -260,20 +260,10 @@ impl ExecutionPlan for StageExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        let stage = self
-            .as_any()
-            .downcast_ref::<StageExec>()
-            .expect("Unwrapping myself should always work");
-
-        let Some(channel_resolver) = get_distributed_channel_resolver(context.session_config())
-        else {
-            return exec_err!("ChannelManager not found in session config");
-        };
-
-        let urls = channel_resolver.get_urls()?;
+        let channel_resolver = get_distributed_channel_resolver(context.session_config())?;
 
-        let assigned_stage = stage
-            .try_assign_urls(&urls)
+        let assigned_stage = self
+            .try_assign_urls(&channel_resolver.get_urls()?)
             .map(Arc::new)
             .map_err(|e| DataFusionError::Execution(e.to_string()))?;
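A note on the `stage.rs` hunk above: this code lives inside `impl ExecutionPlan for StageExec`, so `self` is already `&StageExec`, and the deleted `as_any().downcast_ref::<StageExec>()` was a self-downcast that could never fail. A toy illustration of why the round-trip adds nothing (trait and type names here are hypothetical, not from this codebase):

```rust
use std::any::Any;

trait Node {
    fn as_any(&self) -> &dyn Any;
    fn execute(&self) -> String;
}

struct Stage {
    num: usize,
}

impl Node for Stage {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn execute(&self) -> String {
        // Inside `impl Node for Stage`, `self` already has the concrete type,
        // so this downcast can only succeed; using `self.num` directly is
        // equivalent, which is what the diff switches to.
        let me = self.as_any().downcast_ref::<Stage>().expect("always Some");
        format!("stage {}", me.num)
    }
}

fn main() {
    println!("{}", Stage { num: 1 }.execute());
}
```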
diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs
index c316525..e979667 100644
--- a/src/flight_service/do_get.rs
+++ b/src/flight_service/do_get.rs
@@ -9,13 +9,14 @@
 use arrow_flight::encode::FlightDataEncoderBuilder;
 use arrow_flight::error::FlightError;
 use arrow_flight::flight_service_server::FlightService;
 use arrow_flight::Ticket;
-use datafusion::execution::SessionState;
+use datafusion::execution::{SendableRecordBatchStream, SessionState};
 use futures::TryStreamExt;
+use http::HeaderMap;
 use prost::Message;
+use std::fmt::Display;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
 use tokio::sync::OnceCell;
-use tonic::metadata::MetadataMap;
 use tonic::{Request, Response, Status};
 
 #[derive(Clone, PartialEq, ::prost::Message)]
@@ -41,9 +42,9 @@ pub struct DoGet {
 /// TaskData stores state for a single task being executed by this Endpoint. It may be shared
 /// by concurrent requests for the same task which execute separate partitions.
 pub struct TaskData {
-    pub(super) state: SessionState,
+    pub(super) session_state: SessionState,
     pub(super) stage: Arc<StageExec>,
-    ///num_partitions_remaining is initialized to the total number of partitions in the task (not
+    /// `num_partitions_remaining` is initialized to the total number of partitions in the task (not
     /// only tasks in the partition group). This is decremented for each request to the endpoint
     /// for this task. Once this count is zero, the task is likely complete. The task may not be
    /// complete because it's possible that the same partition was retried and this count was
@@ -56,98 +57,78 @@ impl ArrowFlightEndpoint {
         &self,
         request: Request<Ticket>,
     ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
-        let (metadata, _ext, ticket) = request.into_parts();
-        let Ticket { ticket } = ticket;
-        let doget = DoGet::decode(ticket).map_err(|err| {
+        let (metadata, _ext, body) = request.into_parts();
+        let doget = DoGet::decode(body.ticket).map_err(|err| {
             Status::invalid_argument(format!("Cannot decode DoGet message: {err}"))
         })?;
 
+        // There's only one `StageExec` responsible for all requests that share the same
+        // `stage_key`, so here we either retrieve the existing one or create a new one if it
+        // does not exist.
+        let (mut session_state, stage) = self
+            .get_state_and_stage(
+                doget.stage_key.ok_or_else(missing("stage_key"))?,
+                doget.stage_proto.ok_or_else(missing("stage_proto"))?,
+                metadata.clone().into_headers(),
+            )
+            .await?;
+
+        // Find out which partition group we are executing
         let partition = doget.partition as usize;
         let task_number = doget.task_number as usize;
-        let task_data = self.get_state_and_stage(doget, metadata).await?;
-
-        let stage = task_data.stage;
-        let mut state = task_data.state;
-
-        // find out which partition group we are executing
-        let task = stage
-            .tasks
-            .get(task_number)
-            .ok_or(Status::invalid_argument(format!(
-                "Task number {} not found in stage {}",
-                task_number,
-                stage.name()
-            )))?;
-
-        let partition_group = PartitionGroup(task.partition_group.clone());
-        state.config_mut().set_extension(Arc::new(partition_group));
-
-        let inner_plan = stage.plan.clone();
-
-        let stream = inner_plan
-            .execute(partition, state.task_ctx())
+        let task = stage.tasks.get(task_number).ok_or_else(invalid(format!(
+            "Task number {task_number} not found in stage {}",
+            stage.num
+        )))?;
+
+        let cfg = session_state.config_mut();
+        cfg.set_extension(Arc::new(PartitionGroup(task.partition_group.clone())));
+        cfg.set_extension(Arc::clone(&stage));
+        cfg.set_extension(Arc::new(ContextGrpcMetadata(metadata.into_headers())));
+
+        // Rather than executing the `StageExec` itself, we want to execute the inner plan
+        // instead, as executing `StageExec` performs some worker assignment that should have
+        // already been done in the head stage.
+        let stream = stage
+            .plan
+            .execute(partition, session_state.task_ctx())
             .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
 
-        let flight_data_stream = FlightDataEncoderBuilder::new()
-            .with_schema(inner_plan.schema().clone())
-            .build(stream.map_err(|err| {
-                FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
-            }));
-
-        Ok(Response::new(Box::pin(flight_data_stream.map_err(
-            |err| match err {
-                FlightError::Tonic(status) => *status,
-                _ => Status::internal(format!("Error during flight stream: {err}")),
-            },
-        ))))
+        Ok(record_batch_stream_to_response(stream))
     }
 
     async fn get_state_and_stage(
         &self,
-        doget: DoGet,
-        metadata_map: MetadataMap,
-    ) -> Result<TaskData, Status> {
-        let key = doget
-            .stage_key
-            .ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
-        let once_stage = self
-            .stages
+        key: StageKey,
+        stage_proto: StageExecProto,
+        headers: HeaderMap,
+    ) -> Result<(SessionState, Arc<StageExec>), Status> {
+        let once = self
+            .task_data_entries
             .get_or_init(key.clone(), || Arc::new(OnceCell::<TaskData>::new()));
 
-        let stage_data = once_stage
+        let stage_data = once
             .get_or_try_init(|| async {
-                let stage_proto = doget
-                    .stage_proto
-                    .ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?;
-
-                let headers = metadata_map.into_headers();
-                let mut state = self
+                let session_state = self
                     .session_builder
                     .build_session_state(DistributedSessionBuilderContext {
                         runtime_env: Arc::clone(&self.runtime),
-                        headers: headers.clone(),
+                        headers,
                     })
                     .await
                     .map_err(|err| datafusion_error_to_tonic_status(&err))?;
 
-                let codec = DistributedCodec::new_combined_with_user(state.config());
+                let codec = DistributedCodec::new_combined_with_user(session_state.config());
 
-                let stage = stage_from_proto(stage_proto, &state, self.runtime.as_ref(), &codec)
-                    .map(Arc::new)
+                let stage = stage_from_proto(stage_proto, &session_state, &self.runtime, &codec)
                     .map_err(|err| {
                         Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
                     })?;
 
-                // Add the extensions that might be required for ExecutionPlan nodes in the plan
-                let config = state.config_mut();
-                config.set_extension(stage.clone());
-                config.set_extension(Arc::new(ContextGrpcMetadata(headers)));
-
                 // Initialize partition count to the number of partitions in the stage
                 let total_partitions = stage.plan.properties().partitioning.partition_count();
                 Ok::<_, Status>(TaskData {
-                    state,
-                    stage,
+                    session_state,
+                    stage: Arc::new(stage),
                     num_partitions_remaining: Arc::new(AtomicUsize::new(total_partitions)),
                })
            })
@@ -158,13 +139,37 @@ impl ArrowFlightEndpoint {
             .num_partitions_remaining
             .fetch_sub(1, Ordering::SeqCst);
         if remaining_partitions <= 1 {
-            self.stages.remove(key.clone());
+            self.task_data_entries.remove(key);
         }
 
-        Ok(stage_data.clone())
+        Ok((stage_data.session_state.clone(), stage_data.stage.clone()))
     }
 }
 
+fn missing(field: &'static str) -> impl FnOnce() -> Status {
+    move || Status::invalid_argument(format!("Missing field '{field}'"))
+}
+
+fn invalid(msg: impl Display) -> impl FnOnce() -> Status {
+    move || Status::invalid_argument(msg.to_string())
+}
+
+fn record_batch_stream_to_response(
+    stream: SendableRecordBatchStream,
+) -> Response<<ArrowFlightEndpoint as FlightService>::DoGetStream> {
+    let flight_data_stream =
+        FlightDataEncoderBuilder::new()
+            .with_schema(stream.schema().clone())
+            .build(stream.map_err(|err| {
+                FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
+            }));
+
+    Response::new(Box::pin(flight_data_stream.map_err(|err| match err {
+        FlightError::Tonic(status) => *status,
+        _ => Status::internal(format!("Error during flight stream: {err}")),
+    })))
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -262,13 +267,13 @@ mod tests {
         }
 
         // Check that the endpoint has not evicted any task states.
-        assert_eq!(endpoint.stages.len(), num_tasks);
+        assert_eq!(endpoint.task_data_entries.len(), num_tasks);
 
         // Run the last partition of task 0. Any partition number works. Verify that the task state
         // is evicted because all partitions have been processed.
         let result = do_get(1, 0, task_keys[0].clone()).await;
         assert!(result.is_ok());
-        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<_>>();
+        let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<_>>();
         assert_eq!(stored_stage_keys.len(), 2);
         assert!(stored_stage_keys.contains(&task_keys[1]));
         assert!(stored_stage_keys.contains(&task_keys[2]));
 
@@ -276,14 +281,14 @@ mod tests {
         // Run the last partition of task 1.
         let result = do_get(1, 1, task_keys[1].clone()).await;
         assert!(result.is_ok());
-        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<_>>();
+        let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<_>>();
         assert_eq!(stored_stage_keys.len(), 1);
         assert!(stored_stage_keys.contains(&task_keys[2]));
 
         // Run the last partition of the last task.
         let result = do_get(1, 2, task_keys[2].clone()).await;
         assert!(result.is_ok());
-        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<_>>();
+        let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<_>>();
         assert_eq!(stored_stage_keys.len(), 0);
     }
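The `do_get.rs` rewrite keeps the endpoint's two core concurrency mechanisms intact: `tokio::sync::OnceCell::get_or_try_init` guarantees that concurrent partition requests for the same `stage_key` build the session state and decode the stage only once, and an `AtomicUsize` countdown lets the request that serves the last partition evict the entry, exactly as the tests above exercise. A standalone sketch of that pattern, with a plain `Mutex<HashMap>` standing in for the project's `TTLMap` and simplified types throughout:

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use tokio::sync::OnceCell;

struct Task {
    payload: String,
    remaining: AtomicUsize,
}

#[derive(Default)]
struct Endpoint {
    // Stand-in for the `TTLMap<StageKey, Arc<OnceCell<TaskData>>>` in the diff.
    entries: Mutex<HashMap<String, Arc<OnceCell<Arc<Task>>>>>,
}

impl Endpoint {
    async fn get(&self, key: &str, partitions: usize) -> Result<Arc<Task>, String> {
        // Grab (or create) the cell while holding the lock, then drop the lock
        // before awaiting so other requests are not blocked.
        let cell = {
            let mut map = self.entries.lock().unwrap();
            Arc::clone(map.entry(key.to_string()).or_default())
        };
        // Concurrent callers race here, but the initializer runs at most once.
        let task = cell
            .get_or_try_init(|| async {
                Ok::<_, String>(Arc::new(Task {
                    payload: format!("state for {key}"),
                    remaining: AtomicUsize::new(partitions),
                }))
            })
            .await?
            .clone();
        // `fetch_sub` returns the previous value: the caller that observes 1
        // served the last outstanding partition, so it evicts the entry.
        if task.remaining.fetch_sub(1, Ordering::SeqCst) <= 1 {
            self.entries.lock().unwrap().remove(key);
        }
        Ok(task)
    }
}

#[tokio::main]
async fn main() {
    let ep = Endpoint::default();
    for _ in 0..2 {
        let t = ep.get("stage-1/task-0", 2).await.unwrap();
        println!("{}", t.payload);
    }
    // Both partitions were served, so the entry is gone.
    assert!(ep.entries.lock().unwrap().is_empty());
}
```

As in the real code, eviction is best-effort: a retried partition can drive the counter down early, which is why the doc comment on `num_partitions_remaining` says the task is only "likely" complete when it reaches zero.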
diff --git a/src/flight_service/service.rs b/src/flight_service/service.rs
index 3b02cb3..c50ff1d 100644
--- a/src/flight_service/service.rs
+++ b/src/flight_service/service.rs
@@ -30,8 +30,7 @@ pub struct StageKey {
 
 pub struct ArrowFlightEndpoint {
     pub(super) runtime: Arc<RuntimeEnv>,
-    #[allow(clippy::type_complexity)]
-    pub(super) stages: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
+    pub(super) task_data_entries: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
     pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
 }
 
@@ -42,7 +41,7 @@ impl ArrowFlightEndpoint {
         let ttl_map = TTLMap::try_new(TTLMapConfig::default())?;
         Ok(Self {
             runtime: Arc::new(RuntimeEnv::default()),
-            stages: ttl_map,
+            task_data_entries: ttl_map,
             session_builder: Arc::new(session_builder),
         })
     }
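Finally, the new `record_batch_stream_to_response` helper extracted in `do_get.rs` is the reusable bridge from a DataFusion record batch stream to a Flight `DoGet` response. A sketch of the same shape, with the `DoGetStream` associated type written out as a `BoxStream` and the project's `datafusion_error_to_tonic_status` helper swapped for arrow-flight's generic `FlightError::ExternalError` so the example stays self-contained (this assumes an arrow-flight version with the boxed `FlightError::Tonic` variant, which the diff itself also uses):

```rust
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_flight::FlightData;
use datafusion::execution::SendableRecordBatchStream;
use futures::stream::BoxStream;
use futures::TryStreamExt;
use tonic::{Response, Status};

fn to_do_get_response(
    stream: SendableRecordBatchStream,
) -> Response<BoxStream<'static, Result<FlightData, Status>>> {
    let flight_data_stream = FlightDataEncoderBuilder::new()
        // Send the schema up front so clients can decode even an empty stream.
        .with_schema(stream.schema())
        // The encoder expects `FlightError` items, so DataFusion errors are
        // wrapped; the real helper converts them to a tonic `Status` instead.
        .build(stream.map_err(|err| FlightError::ExternalError(Box::new(err))));

    // Unwrap tonic statuses so the client sees the original error code, and
    // fold everything else into an `internal` status.
    Response::new(Box::pin(flight_data_stream.map_err(|err| match err {
        FlightError::Tonic(status) => *status,
        other => Status::internal(format!("Error during flight stream: {other}")),
    })))
}
```

Keeping this conversion in one helper, rather than inlined in `get()`, is what allows the endpoint body to shrink to `Ok(record_batch_stream_to_response(stream))`.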