diff --git a/Cargo.lock b/Cargo.lock index 00881dd..39d46a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1111,6 +1111,7 @@ dependencies = [ "arrow", "arrow-flight", "async-trait", + "bytes", "chrono", "dashmap", "datafusion", diff --git a/Cargo.toml b/Cargo.toml index 0381b34..9282e54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ dashmap = "6.1.0" prost = "0.13.5" rand = "0.8.5" object_store = "0.12.3" +bytes = "1.10.1" # integration_tests deps insta = { version = "1.43.1", features = ["filters"], optional = true } diff --git a/src/distributed_physical_optimizer_rule.rs b/src/distributed_physical_optimizer_rule.rs index bff7a8c..1f1829f 100644 --- a/src/distributed_physical_optimizer_rule.rs +++ b/src/distributed_physical_optimizer_rule.rs @@ -282,7 +282,6 @@ impl DistributedPhysicalOptimizerRule { distributed.data = Arc::new(CoalescePartitionsExec::new(distributed.data)); } - let inputs = inputs.into_iter().map(Arc::new).collect(); let mut stage = StageExec::new(query_id, *num, distributed.data, inputs, n_tasks); *num += 1; diff --git a/src/execution_plans/mod.rs b/src/execution_plans/mod.rs index 37e13e3..551b36e 100644 --- a/src/execution_plans/mod.rs +++ b/src/execution_plans/mod.rs @@ -8,5 +8,6 @@ mod stage; pub use network_coalesce::{NetworkCoalesceExec, NetworkCoalesceReady}; pub use network_shuffle::{NetworkShuffleExec, NetworkShuffleReadyExec}; pub use partition_isolator::PartitionIsolatorExec; +pub(crate) use stage::InputStage; pub use stage::display_plan_graphviz; pub use stage::{DistributedTaskContext, ExecutionTask, StageExec}; diff --git a/src/execution_plans/network_coalesce.rs b/src/execution_plans/network_coalesce.rs index 5b2fe62..5b5ca6d 100644 --- a/src/execution_plans/network_coalesce.rs +++ b/src/execution_plans/network_coalesce.rs @@ -6,7 +6,7 @@ use crate::distributed_physical_optimizer_rule::{NetworkBoundary, limit_tasks_er use crate::execution_plans::{DistributedTaskContext, StageExec}; use crate::flight_service::DoGet; use 
crate::metrics::proto::MetricsSetProto; -use crate::protobuf::{DistributedCodec, StageKey, proto_from_stage}; +use crate::protobuf::{DistributedCodec, StageKey, proto_from_input_stage}; use crate::protobuf::{map_flight_to_datafusion_error, map_status_to_datafusion_error}; use arrow_flight::Ticket; use arrow_flight::decode::FlightRecordBatchStream; @@ -235,12 +235,11 @@ impl ExecutionPlan for NetworkCoalesceExec { // the `NetworkCoalesceExec` node can only be executed in the context of a `StageExec` let stage = StageExec::from_ctx(&context)?; - // of our child stages find the one that matches the one we are supposed to be - // reading from - let child_stage = stage.child_stage(self_ready.stage_num)?; + // of our input stages find the one that we are supposed to be reading from + let input_stage = stage.input_stage(self_ready.stage_num)?; let codec = DistributedCodec::new_combined_with_user(context.session_config()); - let child_stage_proto = proto_from_stage(child_stage, &codec).map_err(|e| { + let input_stage_proto = proto_from_input_stage(input_stage, &codec).map_err(|e| { internal_datafusion_err!("NetworkCoalesceExec: failed to convert stage to proto: {e}") })?; @@ -251,7 +250,7 @@ impl ExecutionPlan for NetworkCoalesceExec { } let partitions_per_task = - self.properties().partitioning.partition_count() / child_stage.tasks.len(); + self.properties().partitioning.partition_count() / input_stage.tasks().len(); let target_task = partition / partitions_per_task; let target_partition = partition % partitions_per_task; @@ -261,11 +260,11 @@ impl ExecutionPlan for NetworkCoalesceExec { Extensions::default(), Ticket { ticket: DoGet { - stage_proto: Some(child_stage_proto.clone()), + stage_proto: input_stage_proto, target_partition: target_partition as u64, stage_key: Some(StageKey { query_id: stage.query_id.to_string(), - stage_id: child_stage.num as u64, + stage_id: input_stage.num() as u64, task_number: target_task as u64, }), target_task_index: target_task as u64, @@ 
-275,7 +274,7 @@ impl ExecutionPlan for NetworkCoalesceExec { }, ); - let Some(task) = child_stage.tasks.get(target_task) else { + let Some(task) = input_stage.tasks().get(target_task) else { return internal_err!("ProgrammingError: Task {target_task} not found"); }; diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs index eed04b7..c8adf68 100644 --- a/src/execution_plans/network_shuffle.rs +++ b/src/execution_plans/network_shuffle.rs @@ -6,7 +6,7 @@ use crate::distributed_physical_optimizer_rule::NetworkBoundary; use crate::execution_plans::{DistributedTaskContext, StageExec}; use crate::flight_service::DoGet; use crate::metrics::proto::MetricsSetProto; -use crate::protobuf::{DistributedCodec, StageKey, proto_from_stage}; +use crate::protobuf::{DistributedCodec, StageKey, proto_from_input_stage}; use crate::protobuf::{map_flight_to_datafusion_error, map_status_to_datafusion_error}; use arrow_flight::Ticket; use arrow_flight::decode::FlightRecordBatchStream; @@ -293,22 +293,22 @@ impl ExecutionPlan for NetworkShuffleExec { // of our child stages find the one that matches the one we are supposed to be // reading from - let child_stage = stage.child_stage(self_ready.stage_num)?; + let input_stage = stage.input_stage(self_ready.stage_num)?; let codec = DistributedCodec::new_combined_with_user(context.session_config()); - let child_stage_proto = proto_from_stage(child_stage, &codec).map_err(|e| { + let input_stage_proto = proto_from_input_stage(input_stage, &codec).map_err(|e| { internal_datafusion_err!("NetworkShuffleExec: failed to convert stage to proto: {e}") })?; - let child_stage_tasks = child_stage.tasks.clone(); - let child_stage_num = child_stage.num as u64; + let input_stage_tasks = input_stage.tasks().to_vec(); + let input_stage_num = input_stage.num() as u64; let query_id = stage.query_id.to_string(); let context_headers = ContextGrpcMetadata::headers_from_ctx(&context); let task_context = 
DistributedTaskContext::from_ctx(&context); let off = self_ready.properties.partitioning.partition_count() * task_context.task_index; - let stream = child_stage_tasks.into_iter().enumerate().map(|(i, task)| { + let stream = input_stage_tasks.into_iter().enumerate().map(|(i, task)| { let channel_resolver = Arc::clone(&channel_resolver); let ticket = Request::from_parts( @@ -316,11 +316,11 @@ impl ExecutionPlan for NetworkShuffleExec { Extensions::default(), Ticket { ticket: DoGet { - stage_proto: Some(child_stage_proto.clone()), + stage_proto: input_stage_proto.clone(), target_partition: (off + partition) as u64, stage_key: Some(StageKey { query_id: query_id.clone(), - stage_id: child_stage_num, + stage_id: input_stage_num, task_number: i as u64, }), target_task_index: i as u64, diff --git a/src/execution_plans/stage.rs b/src/execution_plans/stage.rs index d1b7782..effd075 100644 --- a/src/execution_plans/stage.rs +++ b/src/execution_plans/stage.rs @@ -1,7 +1,7 @@ use crate::channel_resolver_ext::get_distributed_channel_resolver; use crate::execution_plans::NetworkCoalesceExec; use crate::{ChannelResolver, NetworkShuffleExec, PartitionIsolatorExec}; -use datafusion::common::{exec_err, internal_datafusion_err}; +use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err}; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::TaskContext; use datafusion::physical_plan::{ @@ -85,7 +85,7 @@ pub struct StageExec { /// The physical execution plan that this stage will execute. pub plan: Arc, /// The input stages to this stage - pub inputs: Vec>, + pub inputs: Vec, /// Our tasks which tell us how finely grained to execute the partitions in /// the plan pub tasks: Vec, @@ -93,6 +93,52 @@ pub struct StageExec { pub depth: usize, } +/// A [StageExec] that is the input of another [StageExec]. +/// +/// It can be either: +/// - Decoded: the inner [StageExec] is stored as-is. 
+/// - Encoded: the inner [StageExec] is stored as protobuf [Bytes]. Storing it this way allows us +/// to thread it through the project and eventually send it through gRPC in a zero copy manner. +#[derive(Debug, Clone)] +pub enum InputStage { + /// The decoded [StageExec]. Unfortunately, this cannot be an `Arc<StageExec>`, because at + /// some point we need to upcast `&Arc<StageExec>` to `&Arc<dyn ExecutionPlan>`, and the Rust + /// compiler does not allow it. + /// + /// This is very annoying because it forces us to store it like an `Arc<dyn ExecutionPlan>` + /// here even though we know this can only be `Arc<StageExec>`. For this reason + /// [StageExec::from_dyn] was introduced for casting it back to [StageExec]. + Decoded(Arc), + /// A protobuf encoded version of the [InputStage]. The inner [Bytes] represent the full + /// input [StageExec] encoded in protobuf format. + /// + /// By keeping it encoded, we avoid encoding/decoding it unnecessarily in parts of the project + /// that do not need it. Only the Stage num and the [ExecutionTask]s are left decoded, + /// as typically those are the only things needed by the network boundaries. The [Bytes] can be + /// just passed through in a zero copy manner. + Encoded { + num: usize, + tasks: Vec, + proto: Bytes, + }, +} + +impl InputStage { + pub fn num(&self) -> usize { + match self { + InputStage::Decoded(v) => StageExec::from_dyn(v).num, + InputStage::Encoded { num, .. } => *num, + } + } + + pub fn tasks(&self) -> &[ExecutionTask] { + match self { + InputStage::Decoded(v) => &StageExec::from_dyn(v).tasks, + InputStage::Encoded { tasks, .. } => tasks, + } + } +} + #[derive(Debug, Clone)] pub struct ExecutionTask { /// The url of the worker that will execute this task. A None value is interpreted as @@ -118,13 +164,21 @@ impl DistributedTaskContext { } impl StageExec { + /// Dangerous way of accessing a [StageExec] out of an `&Arc<dyn ExecutionPlan>`. + /// See [InputStage::Decoded] docs for more details about why panicking here is preferred. 
+ pub(crate) fn from_dyn(plan: &Arc) -> &Self { + plan.as_any() + .downcast_ref() + .expect("Programming Error: expected Arc to be of type StageExec") + } + /// Creates a new `ExecutionStage` with the given plan and inputs. One task will be created /// responsible for partitions in the plan. - pub fn new( + pub(crate) fn new( query_id: Uuid, num: usize, plan: Arc, - inputs: Vec>, + inputs: Vec, n_tasks: usize, ) -> Self { StageExec { @@ -134,7 +188,7 @@ impl StageExec { plan, inputs: inputs .into_iter() - .map(|s| s as Arc) + .map(|s| InputStage::Decoded(Arc::new(s))) .collect(), tasks: vec![ExecutionTask { url: None }; n_tasks], depth: 0, @@ -146,39 +200,22 @@ impl StageExec { format!("Stage {:<3}", self.num) } - /// Returns an iterator over the child stages of this stage cast as &ExecutionStage + /// Returns an iterator over the input stages of this stage cast as &ExecutionStage /// which can be useful - pub fn child_stages_iter(&self) -> impl Iterator { - self.inputs - .iter() - .filter_map(|s| s.as_any().downcast_ref::()) - } - - /// Returns the name of this stage including child stage numbers if any. 
- pub fn name_with_children(&self) -> String { - let child_str = if self.inputs.is_empty() { - "".to_string() - } else { - format!( - " Child Stages:[{}] ", - self.child_stages_iter() - .map(|s| format!("{}", s.num)) - .collect::>() - .join(", ") - ) - }; - format!("Stage {:<3}{}", self.num, child_str) + pub fn input_stages_iter(&self) -> impl Iterator { + self.inputs.iter() } fn try_assign_urls(&self, urls: &[Url]) -> Result { - let assigned_children = self - .child_stages_iter() - .map(|child| { - child - .clone() // TODO: avoid cloning if possible - .try_assign_urls(urls) - .map(|c| Arc::new(c) as Arc) + let assigned_input_stages = self + .input_stages_iter() + .map(|input_stage| { + let InputStage::Decoded(input_stage) = input_stage else { + return exec_err!("Cannot assign URLs to the tasks in an encoded stage"); + }; + StageExec::from_dyn(input_stage).try_assign_urls(urls) }) + .map_ok(|v| InputStage::Decoded(Arc::new(v))) .collect::>>()?; // pick a random starting position @@ -199,7 +236,7 @@ impl StageExec { num: self.num, name: self.name.clone(), plan: self.plan.clone(), - inputs: assigned_children, + inputs: assigned_input_stages, tasks: assigned_tasks, depth: self.depth, }; @@ -215,10 +252,22 @@ impl StageExec { )) } - pub fn child_stage(&self, i: usize) -> Result<&StageExec, DataFusionError> { - self.child_stages_iter() - .find(|s| s.num == i) - .ok_or(internal_datafusion_err!("no child stage with num {i}")) + pub fn input_stage(&self, stage_num: usize) -> Result<&InputStage, DataFusionError> { + for input_stage in self.input_stages_iter() { + match input_stage { + InputStage::Decoded(v) => { + if StageExec::from_dyn(v).num == stage_num { + return Ok(input_stage); + }; + } + InputStage::Encoded { num, .. 
} => { + if *num == stage_num { + return Ok(input_stage); + } + } + } + } + internal_err!("no child stage with num {stage_num}") } } @@ -232,22 +281,20 @@ impl ExecutionPlan for StageExec { } fn children(&self) -> Vec<&Arc> { - self.inputs.iter().collect() + self.inputs + .iter() + .filter_map(|v| match v { + InputStage::Decoded(v) => Some(v), + InputStage::Encoded { .. } => None, + }) + .collect() } fn with_new_children( self: Arc, - children: Vec>, + _children: Vec>, ) -> Result> { - Ok(Arc::new(StageExec { - query_id: self.query_id, - num: self.num, - name: self.name.clone(), - plan: self.plan.clone(), - inputs: children, - tasks: self.tasks.clone(), - depth: self.depth, - })) + plan_err!("with_new_children() not supported for StageExec") } fn properties(&self) -> &datafusion::physical_plan::PlanProperties { @@ -297,6 +344,7 @@ impl ExecutionPlan for StageExec { } } +use bytes::Bytes; use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::physical_expr::Partitioning; /// Be able to display a nice tree for stages. 
@@ -328,13 +376,12 @@ impl StageExec { if let Some(NetworkShuffleExec::Ready(ready)) = plan.as_any().downcast_ref::() { - let Some(input_stage) = &self.child_stages_iter().find(|v| v.num == ready.stage_num) - else { + let Ok(input_stage) = &self.input_stage(ready.stage_num) else { writeln!(f, "Wrong partition number {}", ready.stage_num)?; return Ok(()); }; let n_tasks = self.tasks.len(); - let input_tasks = input_stage.tasks.len(); + let input_tasks = input_stage.tasks().len(); let partitions = plan.output_partitioning().partition_count(); let stage = ready.stage_num; write!( @@ -346,12 +393,11 @@ impl StageExec { if let Some(NetworkCoalesceExec::Ready(ready)) = plan.as_any().downcast_ref::() { - let Some(input_stage) = &self.child_stages_iter().find(|v| v.num == ready.stage_num) - else { + let Ok(input_stage) = &self.input_stage(ready.stage_num) else { writeln!(f, "Wrong partition number {}", ready.stage_num)?; return Ok(()); }; - let tasks = input_stage.tasks.len(); + let tasks = input_stage.tasks().len(); let partitions = plan.output_partitioning().partition_count(); let stage = ready.stage_num; write!( @@ -507,15 +553,19 @@ pub fn display_plan_graphviz(plan: Arc) -> Result { .downcast_ref::() .expect("Expected StageExec"); - for child_stage in stage.child_stages_iter() { + for input_stage in stage.input_stages_iter() { + let InputStage::Decoded(input_stage) = input_stage else { + continue; + }; + let input_stage = StageExec::from_dyn(input_stage); for task_i in 0..stage.tasks.len() { - for child_task_i in 0..child_stage.tasks.len() { + for input_task_i in 0..input_stage.tasks.len() { let edges = - display_inter_task_edges(stage, task_i, child_stage, child_task_i)?; + display_inter_task_edges(stage, task_i, input_stage, input_task_i)?; writeln!( f, "// edges from child stage {} task {} to stage {} task {}\n {}", - child_stage.num, child_task_i, stage.num, task_i, edges + input_stage.num, input_task_i, stage.num, task_i, edges )?; } } @@ -779,8 +829,8 @@ pub fn 
display_single_plan( fn display_inter_task_edges( stage: &StageExec, task_i: usize, - child_stage: &StageExec, - child_task_i: usize, + input_stage: &StageExec, + input_task_i: usize, ) -> Result { let mut f = String::new(); let partition_group = @@ -797,7 +847,7 @@ fn display_inter_task_edges( let NetworkShuffleExec::Ready(node) = node else { continue; }; - if node.stage_num != child_stage.num { + if node.stage_num != input_stage.num { continue; } // draw the edges to this node pulling data up from its child @@ -811,9 +861,9 @@ fn display_inter_task_edges( writeln!( f, " {}_{}_{}_{}:t{}:n -> {}_{}_{}_{}:b{}:s {} [color={}]", - child_stage.plan.name(), - child_stage.num, - child_task_i, + input_stage.plan.name(), + input_stage.num, + input_task_i, 1, // the repartition exec is always the first node in the plan p + (task_i * output_partitions), plan.name(), @@ -829,12 +879,12 @@ fn display_inter_task_edges( let NetworkCoalesceExec::Ready(node) = node else { continue; }; - if node.stage_num != child_stage.num { + if node.stage_num != input_stage.num { continue; } // draw the edges to this node pulling data up from its child let output_partitions = plan.output_partitioning().partition_count(); - let input_partitions_per_task = output_partitions / child_stage.tasks.len(); + let input_partitions_per_task = output_partitions / input_stage.tasks.len(); for p in 0..input_partitions_per_task { let mut style = ""; if found_isolator && !partition_group.contains(&p) { @@ -844,16 +894,16 @@ fn display_inter_task_edges( writeln!( f, " {}_{}_{}_{}:t{}:n -> {}_{}_{}_{}:b{}:s {} [color={}]", - child_stage.plan.name(), - child_stage.num, - child_task_i, + input_stage.plan.name(), + input_stage.num, + input_task_i, 1, // the repartition exec is always the first node in the plan p, plan.name(), stage.num, task_i, index, - p + (child_task_i * input_partitions_per_task), + p + (input_task_i * input_partitions_per_task), style, p % NUM_COLORS + 1 )?; diff --git 
a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index b9701ff..91d493d 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -3,12 +3,13 @@ use crate::execution_plans::{DistributedTaskContext, StageExec}; use crate::flight_service::service::ArrowFlightEndpoint; use crate::flight_service::session_builder::DistributedSessionBuilderContext; use crate::protobuf::{ - DistributedCodec, StageExecProto, StageKey, datafusion_error_to_tonic_status, stage_from_proto, + DistributedCodec, StageKey, datafusion_error_to_tonic_status, stage_from_proto, }; use arrow_flight::Ticket; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; use arrow_flight::flight_service_server::FlightService; +use bytes::Bytes; use datafusion::common::exec_datafusion_err; use datafusion::execution::SendableRecordBatchStream; use futures::TryStreamExt; @@ -19,9 +20,9 @@ use tonic::{Request, Response, Status}; #[derive(Clone, PartialEq, ::prost::Message)] pub struct DoGet { - /// The ExecutionStage that we are going to execute - #[prost(message, optional, tag = "1")] - pub stage_proto: Option, + /// The [StageExec] we are going to execute encoded as protobuf bytes. 
+ #[prost(bytes, tag = "1")] + pub stage_proto: Bytes, /// The index to the task within the stage that we want to execute #[prost(uint64, tag = "2")] pub target_task_index: u64, @@ -79,7 +80,7 @@ impl ArrowFlightEndpoint { let stage_data = once .get_or_try_init(|| async { - let stage_proto = doget.stage_proto.ok_or_else(missing("stage_proto"))?; + let stage_proto = doget.stage_proto; let stage = stage_from_proto(stage_proto, &session_state, &self.runtime, &codec) .map_err(|err| { Status::invalid_argument(format!("Cannot decode stage proto: {err}")) @@ -219,7 +220,7 @@ mod tests { let stage_proto = stage_proto_for_closure.clone(); // Create DoGet message let doget = DoGet { - stage_proto: Some(stage_proto), + stage_proto: stage_proto.encode_to_vec().into(), target_task_index: task_number, target_partition: partition, stage_key: Some(stage_key), diff --git a/src/protobuf/mod.rs b/src/protobuf/mod.rs index ada13de..15a8c63 100644 --- a/src/protobuf/mod.rs +++ b/src/protobuf/mod.rs @@ -11,7 +11,9 @@ pub(crate) use errors::{ datafusion_error_to_tonic_status, map_flight_to_datafusion_error, map_status_to_datafusion_error, }; -pub(crate) use stage_proto::{StageExecProto, StageKey, proto_from_stage, stage_from_proto}; +#[cfg(test)] +pub(crate) use stage_proto::proto_from_stage; +pub(crate) use stage_proto::{StageKey, proto_from_input_stage, stage_from_proto}; pub(crate) use user_codec::{ get_distributed_user_codecs, set_distributed_user_codec, set_distributed_user_codec_arc, }; diff --git a/src/protobuf/stage_proto.rs b/src/protobuf/stage_proto.rs index 6b03d45..06911f3 100644 --- a/src/protobuf/stage_proto.rs +++ b/src/protobuf/stage_proto.rs @@ -1,16 +1,17 @@ -use crate::execution_plans::{ExecutionTask, StageExec}; +use crate::execution_plans::{ExecutionTask, InputStage, StageExec}; +use bytes::Bytes; +use datafusion::common::exec_err; use datafusion::{ common::internal_datafusion_err, error::{DataFusionError, Result}, execution::{FunctionRegistry, 
runtime_env::RuntimeEnv}, - physical_plan::ExecutionPlan, }; use datafusion_proto::{ physical_plan::{AsExecutionPlan, PhysicalExtensionCodec}, protobuf::PhysicalPlanNode, }; +use prost::Message; use std::fmt::Display; -use std::sync::Arc; use url::Url; /// A key that uniquely identifies a stage in a query @@ -53,13 +54,23 @@ pub struct StageExecProto { plan: Option>, /// The input stages to this stage #[prost(repeated, message, tag = "5")] - inputs: Vec, + inputs: Vec, /// Our tasks which tell us how finely grained to execute the partitions in /// the plan #[prost(message, repeated, tag = "6")] tasks: Vec, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StageInputProto { + #[prost(uint64, tag = "1")] + num: u64, + #[prost(message, repeated, tag = "2")] + tasks: Vec, + #[prost(bytes, tag = "3")] + stage: Bytes, +} + #[derive(Clone, PartialEq, ::prost::Message)] pub struct ExecutionTaskProto { /// The url of the worker that will execute this task. A None value is interpreted as @@ -68,14 +79,64 @@ pub struct ExecutionTaskProto { url_str: Option, } -pub fn proto_from_stage( +fn encode_tasks(tasks: &[ExecutionTask]) -> Vec { + tasks + .iter() + .map(|task| ExecutionTaskProto { + url_str: task.url.as_ref().map(|v| v.to_string()), + }) + .collect() +} + +/// Encodes an [InputStage] as protobuf [Bytes]: +/// - If the input is [InputStage::Decoded], it will serialize the inner plan as protobuf bytes. +/// - If the input is [InputStage::Encoded], it will pass through the [Bytes] in a zero-copy manner. +pub(crate) fn proto_from_input_stage( + input_stage: &InputStage, + codec: &dyn PhysicalExtensionCodec, +) -> Result { + match input_stage { + InputStage::Decoded(v) => { + let stage = StageExec::from_dyn(v); + Ok(proto_from_stage(stage, codec)?.encode_to_vec().into()) + } + InputStage::Encoded { proto, .. } => Ok(proto.clone()), + } +} + +/// Converts a [StageExec] into a [StageExecProto], which makes it suitable to be serialized and +/// sent over the wire. 
+/// +/// If the input [InputStage]s of the provided [StageExec] are already encoded as protobuf [Bytes], +/// they will not be decoded and re-encoded; the [Bytes] are just passed through as-is in a zero copy +/// manner. +pub(crate) fn proto_from_stage( stage: &StageExec, codec: &dyn PhysicalExtensionCodec, ) -> Result { let proto_plan = PhysicalPlanNode::try_from_physical_plan(stage.plan.clone(), codec)?; let inputs = stage - .child_stages_iter() - .map(|s| proto_from_stage(s, codec)) + .input_stages_iter() + .map(|s| match s { + InputStage::Decoded(s) => { + let Some(s) = s.as_any().downcast_ref::() else { + return exec_err!( + "Programming error: StageExec input must always be other StageExec" + ); + }; + + Ok(StageInputProto { + num: s.num as u64, + tasks: encode_tasks(&s.tasks), + stage: proto_from_stage(s, codec)?.encode_to_vec().into(), + }) + } + InputStage::Encoded { num, tasks, proto } => Ok(StageInputProto { + num: *num as u64, + tasks: encode_tasks(tasks), + stage: proto.clone(), + }), + }) .collect::>>()?; Ok(StageExecProto { @@ -84,22 +145,39 @@ pub fn proto_from_stage( name: stage.name(), plan: Some(Box::new(proto_plan)), inputs, - tasks: stage - .tasks - .iter() - .map(|task| ExecutionTaskProto { - url_str: task.url.as_ref().map(|v| v.to_string()), - }) - .collect(), + tasks: encode_tasks(&stage.tasks), }) } -pub fn stage_from_proto( - msg: StageExecProto, +/// Decodes the provided protobuf [Bytes] as a [StageExec]. Rather than recursively decoding all the +/// input [InputStage]s, it performs a shallow decoding of just the first [StageExec] level, leaving +/// all the inputs in [InputStage::Encoded] state. +/// +/// This prevents decoding and then re-encoding the whole plan recursively, and only decodes the +/// things that are strictly needed. 
+pub(crate) fn stage_from_proto( + msg: Bytes, registry: &dyn FunctionRegistry, runtime: &RuntimeEnv, codec: &dyn PhysicalExtensionCodec, ) -> Result { + fn decode_tasks(tasks: Vec) -> Result> { + tasks + .into_iter() + .map(|task| { + Ok(ExecutionTask { + url: task + .url_str + .map(|u| { + Url::parse(&u).map_err(|_| internal_datafusion_err!("Invalid URL: {u}")) + }) + .transpose()?, + }) + }) + .collect() + } + let msg = StageExecProto::decode(msg) + .map_err(|e| internal_datafusion_err!("Cannot decode StageExecProto: {e}"))?; let plan_node = msg.plan.ok_or(internal_datafusion_err!( "ExecutionStageMsg is missing the plan" ))?; @@ -110,8 +188,11 @@ pub fn stage_from_proto( .inputs .into_iter() .map(|s| { - stage_from_proto(s, registry, runtime, codec) - .map(|s| Arc::new(s) as Arc) + Ok(InputStage::Encoded { + num: s.num as usize, + tasks: decode_tasks(s.tasks)?, + proto: s.stage, + }) }) .collect::>>()?; @@ -124,20 +205,7 @@ pub fn stage_from_proto( name: msg.name, plan, inputs, - tasks: msg - .tasks - .into_iter() - .map(|task| { - Ok(ExecutionTask { - url: task - .url_str - .map(|u| { - Url::parse(&u).map_err(|_| internal_datafusion_err!("Invalid URL: {u}")) - }) - .transpose()?, - }) - }) - .collect::>>()?, + tasks: decode_tasks(msg.tasks)?, depth: 0, }) } @@ -148,7 +216,6 @@ mod tests { use std::sync::Arc; use crate::StageExec; - use crate::protobuf::stage_proto::StageExecProto; use crate::protobuf::{proto_from_stage, stage_from_proto}; use datafusion::{ arrow::{ @@ -217,13 +284,9 @@ mod tests { .encode(&mut buf) .map_err(|e| internal_datafusion_err!("couldn't encode {e:#?}"))?; - // Deserialize from bytes - let decoded_msg = StageExecProto::decode(&buf[..]) - .map_err(|e| internal_datafusion_err!("couldn't decode {e:#?}"))?; - // Convert back to ExecutionStage let round_trip_stage = stage_from_proto( - decoded_msg, + buf.into(), &ctx, ctx.runtime_env().as_ref(), &DefaultPhysicalExtensionCodec {},