diff --git a/src/composed_extension_codec.rs b/src/composed_extension_codec.rs
index 33e196c..50ccf1d 100644
--- a/src/composed_extension_codec.rs
+++ b/src/composed_extension_codec.rs
@@ -3,7 +3,6 @@ use datafusion::error::DataFusionError;
 use datafusion::execution::FunctionRegistry;
 use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
 use datafusion::physical_plan::ExecutionPlan;
-use datafusion::prelude::SessionConfig;
 use datafusion_proto::physical_plan::PhysicalExtensionCodec;
 use std::fmt::Debug;
 use std::sync::Arc;
@@ -25,42 +24,10 @@ impl ComposedPhysicalExtensionCodec {
         self.codecs.push(Arc::new(codec));
     }
 
-    /// Adds a new [PhysicalExtensionCodec] from DataFusion's [SessionConfig] extensions.
-    ///
-    /// If users have a custom [PhysicalExtensionCodec] for their own nodes, they should
-    /// populate the config extensions with a [PhysicalExtensionCodec] so that we can use
-    /// it while encoding/decoding plans to/from protobuf.
-    ///
-    /// Example:
-    /// ```rust
-    /// # use std::sync::Arc;
-    /// # use datafusion::execution::FunctionRegistry;
-    /// # use datafusion::physical_plan::ExecutionPlan;
-    /// # use datafusion::prelude::SessionConfig;
-    /// # use datafusion_proto::physical_plan::PhysicalExtensionCodec;
-    ///
-    /// #[derive(Debug)]
-    /// struct CustomUserCodec {}
-    ///
-    /// impl PhysicalExtensionCodec for CustomUserCodec {
-    ///     fn try_decode(&self, buf: &[u8], inputs: &[Arc<dyn ExecutionPlan>], registry: &dyn FunctionRegistry) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
-    ///         todo!()
-    ///     }
-    ///
-    ///     fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> datafusion::common::Result<()> {
-    ///         todo!()
-    ///     }
-    /// }
-    ///
-    /// let mut config = SessionConfig::new();
-    ///
-    /// let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(CustomUserCodec {});
-    /// config.set_extension(Arc::new(codec));
-    /// ```
-    pub(crate) fn push_from_config(&mut self, config: &SessionConfig) {
-        if let Some(user_codec) = config.get_extension::<Arc<dyn PhysicalExtensionCodec>>() {
-            self.codecs.push(user_codec.as_ref().clone());
-        }
+    /// Adds a new [PhysicalExtensionCodec] to the list. These codecs will be tried
+    /// in order until one succeeds.
+    pub(crate) fn push_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>) {
+        self.codecs.push(codec);
     }
 
     fn try_any(
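For context on the new `push_arc` path: the composed codec simply tries each registered codec in order. Below is a minimal sketch of the dispatch that the `try_any` helper (truncated above) is assumed to implement; the exact signature in the crate may differ.

```rust
use std::sync::Arc;

use datafusion::error::{DataFusionError, Result};
use datafusion_proto::physical_plan::PhysicalExtensionCodec;

/// Sketch: try each codec in registration order, returning the first success.
fn try_any<T>(
    codecs: &[Arc<dyn PhysicalExtensionCodec>],
    mut f: impl FnMut(&dyn PhysicalExtensionCodec) -> Result<T>,
) -> Result<T> {
    let mut last_err = None;
    for codec in codecs {
        match f(codec.as_ref()) {
            Ok(value) => return Ok(value),    // first codec that succeeds wins
            Err(err) => last_err = Some(err), // remember the failure, keep trying
        }
    }
    Err(last_err
        .unwrap_or_else(|| DataFusionError::Internal("no codec could handle the node".into())))
}
```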
diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs
index 1631f2a..210934b 100644
--- a/src/flight_service/do_get.rs
+++ b/src/flight_service/do_get.rs
@@ -1,14 +1,15 @@
+use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
 use crate::errors::datafusion_error_to_tonic_status;
 use crate::flight_service::service::ArrowFlightEndpoint;
 use crate::plan::DistributedCodec;
 use crate::stage::{stage_from_proto, ExecutionStageProto};
+use crate::user_provided_codec::get_user_codec;
 use arrow_flight::encode::FlightDataEncoderBuilder;
 use arrow_flight::error::FlightError;
 use arrow_flight::flight_service_server::FlightService;
 use arrow_flight::Ticket;
 use datafusion::execution::SessionStateBuilder;
 use datafusion::optimizer::OptimizerConfig;
-use datafusion_proto::physical_plan::PhysicalExtensionCodec;
 use futures::TryStreamExt;
 use prost::Message;
 use std::sync::Arc;
@@ -48,25 +49,32 @@ impl ArrowFlightEndpoint {
             "FunctionRegistry not present in newly built SessionState",
         ))?;
 
-        let codec = DistributedCodec {};
-        let codec = Arc::new(codec) as Arc<dyn PhysicalExtensionCodec>;
+        let mut combined_codec = ComposedPhysicalExtensionCodec::default();
+        combined_codec.push(DistributedCodec);
+        if let Some(ref user_codec) = get_user_codec(state.config()) {
+            combined_codec.push_arc(Arc::clone(user_codec));
+        }
 
-        let stage = stage_from_proto(stage_msg, function_registry, &self.runtime.as_ref(), codec)
-            .map(Arc::new)
-            .map_err(|err| Status::invalid_argument(format!("Cannot decode stage proto: {err}")))?;
+        let mut stage = stage_from_proto(
+            stage_msg,
+            function_registry,
+            &self.runtime.as_ref(),
+            &combined_codec,
+        )
+        .map_err(|err| Status::invalid_argument(format!("Cannot decode stage proto: {err}")))?;
+        let inner_plan = Arc::clone(&stage.plan);
 
         // Add the extensions that might be required for ExecutionPlan nodes in the plan
         let config = state.config_mut();
         config.set_extension(Arc::clone(&self.channel_manager));
-        config.set_extension(stage.clone());
+        config.set_extension(Arc::new(stage));
 
-        let stream = stage
-            .plan
+        let stream = inner_plan
             .execute(doget.partition as usize, state.task_ctx())
             .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
 
         let flight_data_stream = FlightDataEncoderBuilder::new()
-            .with_schema(stage.plan.schema().clone())
+            .with_schema(inner_plan.schema().clone())
             .build(stream.map_err(|err| {
                 FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
             }));
diff --git a/src/flight_service/session_builder.rs b/src/flight_service/session_builder.rs
index f9950a0..eb6d06f 100644
--- a/src/flight_service/session_builder.rs
+++ b/src/flight_service/session_builder.rs
@@ -16,7 +16,7 @@ pub trait SessionBuilder {
     /// # use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
     /// # use datafusion::physical_plan::ExecutionPlan;
     /// # use datafusion_proto::physical_plan::PhysicalExtensionCodec;
-    /// # use datafusion_distributed::{SessionBuilder};
+    /// # use datafusion_distributed::{with_user_codec, SessionBuilder};
     ///
     /// #[derive(Debug)]
     /// struct CustomExecCodec;
@@ -35,10 +35,8 @@ pub trait SessionBuilder {
     /// struct CustomSessionBuilder;
     /// impl SessionBuilder for CustomSessionBuilder {
     ///     fn on_new_session(&self, mut builder: SessionStateBuilder) -> SessionStateBuilder {
-    ///         let config = builder.config().get_or_insert_default();
-    ///         let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(CustomExecCodec);
-    ///         config.set_extension(Arc::new(codec));
-    ///         builder
+    ///         // Add your UDFs, optimization rules, etc...
+    ///         with_user_codec(builder, CustomExecCodec)
     ///     }
     /// }
     /// ```
diff --git a/src/lib.rs b/src/lib.rs
index 65faa4f..8256e86 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -10,6 +10,9 @@ mod test_utils;
 pub mod physical_optimizer;
 pub mod stage;
 pub mod task;
+mod user_provided_codec;
+
 pub use channel_manager::{BoxCloneSyncChannel, ChannelManager, ChannelResolver};
 pub use flight_service::{ArrowFlightEndpoint, SessionBuilder};
 pub use plan::ArrowFlightReadExec;
+pub use user_provided_codec::{add_user_codec, with_user_codec};
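For reviewers, here is how the two newly exported helpers are meant to be used; `MyCodec` is a hypothetical stand-in for a real `PhysicalExtensionCodec` implementation (the doc examples in `src/user_provided_codec.rs` below show the same thing in more detail).

```rust
use std::sync::Arc;

use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_distributed::{add_user_codec, with_user_codec};
use datafusion_proto::physical_plan::PhysicalExtensionCodec;

#[derive(Debug)]
struct MyCodec; // stand-in for a real user codec

impl PhysicalExtensionCodec for MyCodec {
    fn try_decode(
        &self,
        _buf: &[u8],
        _inputs: &[Arc<dyn ExecutionPlan>],
        _registry: &dyn FunctionRegistry,
    ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
        todo!()
    }

    fn try_encode(
        &self,
        _node: Arc<dyn ExecutionPlan>,
        _buf: &mut Vec<u8>,
    ) -> datafusion::common::Result<()> {
        todo!()
    }
}

fn main() {
    // Builder style: returns the builder with the codec installed.
    let state = with_user_codec(SessionStateBuilder::new(), MyCodec).build();

    // Mutating style: attach the codec to an already-built context.
    let mut ctx = SessionContext::new();
    add_user_codec(&mut ctx, MyCodec);
    let _ = (state, ctx);
}
```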
diff --git a/src/physical_optimizer.rs b/src/physical_optimizer.rs
index 54c460e..932a8ce 100644
--- a/src/physical_optimizer.rs
+++ b/src/physical_optimizer.rs
@@ -1,5 +1,6 @@
 use std::sync::Arc;
 
+use super::stage::ExecutionStage;
 use crate::{plan::PartitionIsolatorExec, ArrowFlightReadExec};
 use datafusion::common::tree_node::TreeNodeRecursion;
 use datafusion::error::DataFusionError;
@@ -15,15 +16,9 @@ use datafusion::{
         displayable, repartition::RepartitionExec, ExecutionPlan, ExecutionPlanProperties,
     },
 };
-use datafusion_proto::physical_plan::PhysicalExtensionCodec;
-
-use super::stage::ExecutionStage;
 
 #[derive(Debug, Default)]
 pub struct DistributedPhysicalOptimizerRule {
-    /// Optional codec to assist in serializing and deserializing any custom
-    /// ExecutionPlan nodes
-    codec: Option<Arc<dyn PhysicalExtensionCodec>>,
     /// maximum number of partitions per task. This is used to determine how many
     /// tasks to create for each stage
     partitions_per_task: Option<usize>,
@@ -32,18 +27,10 @@ pub struct DistributedPhysicalOptimizerRule {
 impl DistributedPhysicalOptimizerRule {
     pub fn new() -> Self {
         DistributedPhysicalOptimizerRule {
-            codec: None,
             partitions_per_task: None,
         }
     }
 
-    /// Set a codec to use to assist in serializing and deserializing
-    /// custom ExecutionPlan nodes.
-    pub fn with_codec(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self {
-        self.codec = Some(codec);
-        self
-    }
-
     /// Set the maximum number of partitions per task. This is used to determine how many
     /// tasks to create for each stage.
     ///
@@ -156,9 +143,6 @@ impl DistributedPhysicalOptimizerRule {
         if let Some(partitions_per_task) = self.partitions_per_task {
             stage = stage.with_maximum_partitions_per_task(partitions_per_task);
         }
-        if let Some(codec) = self.codec.as_ref() {
-            stage = stage.with_codec(codec.clone());
-        }
 
         stage.depth = depth;
         Ok(stage)
diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs
index 73cb070..dd970f3 100644
--- a/src/plan/arrow_flight_read.rs
+++ b/src/plan/arrow_flight_read.rs
@@ -1,8 +1,11 @@
 use super::combined::CombinedRecordBatchStream;
 use crate::channel_manager::ChannelManager;
+use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
 use crate::errors::tonic_status_to_datafusion_error;
 use crate::flight_service::DoGet;
-use crate::stage::{ExecutionStage, ExecutionStageProto};
+use crate::plan::DistributedCodec;
+use crate::stage::{proto_from_stage, ExecutionStage};
+use crate::user_provided_codec::get_user_codec;
 use arrow_flight::decode::FlightRecordBatchStream;
 use arrow_flight::error::FlightError;
 use arrow_flight::flight_service_client::FlightServiceClient;
@@ -170,12 +173,14 @@ impl ExecutionPlan for ArrowFlightReadExec {
             this.stage_num
         ))?;
 
-        let child_stage_tasks = child_stage.tasks.clone();
-        let child_stage_proto = ExecutionStageProto::try_from(child_stage).map_err(|e| {
-            internal_datafusion_err!(
-                "ArrowFlightReadExec: failed to convert stage to proto: {}",
-                e
-            )
+        let mut combined_codec = ComposedPhysicalExtensionCodec::default();
+        combined_codec.push(DistributedCodec {});
+        if let Some(ref user_codec) = get_user_codec(context.session_config()) {
+            combined_codec.push_arc(Arc::clone(user_codec));
+        }
+
+        let child_stage_proto = proto_from_stage(child_stage, &combined_codec).map_err(|e| {
+            internal_datafusion_err!("ArrowFlightReadExec: failed to convert stage to proto: {e}")
         })?;
 
         let ticket_bytes = DoGet {
@@ -191,6 +196,7 @@ impl ExecutionPlan for ArrowFlightReadExec {
 
         let schema = child_stage.plan.schema();
 
+        let child_stage_tasks = child_stage.tasks.clone();
         let stream = async move {
             let futs = child_stage_tasks.iter().enumerate().map(|(i, task)| async {
                 let url = task.url()?.ok_or(internal_datafusion_err!(
- }) - } -} - -impl TryFrom for ExecutionStageProto { - type Error = DataFusionError; +pub fn proto_from_stage( + stage: &ExecutionStage, + codec: &dyn PhysicalExtensionCodec, +) -> Result { + let proto_plan = PhysicalPlanNode::try_from_physical_plan(stage.plan.clone(), codec)?; + let inputs = stage + .child_stages_iter() + .map(|s| Ok(Box::new(proto_from_stage(s, codec)?))) + .collect::>>()?; - fn try_from(stage: ExecutionStage) -> Result { - ExecutionStageProto::try_from(&stage) - } + Ok(ExecutionStageProto { + num: stage.num as u64, + name: stage.name(), + plan: Some(Box::new(proto_plan)), + inputs, + tasks: stage.tasks.clone(), + }) } pub fn stage_from_proto( msg: ExecutionStageProto, registry: &dyn FunctionRegistry, runtime: &RuntimeEnv, - codec: Arc, + codec: &dyn PhysicalExtensionCodec, ) -> Result { let plan_node = msg.plan.ok_or(internal_datafusion_err!( "ExecutionStageMsg is missing the plan" ))?; - let plan = plan_node.try_into_physical_plan(registry, runtime, codec.as_ref())?; + let plan = plan_node.try_into_physical_plan(registry, runtime, codec)?; let inputs = msg .inputs .into_iter() .map(|s| { - stage_from_proto(*s, registry, runtime, codec.clone()) + stage_from_proto(*s, registry, runtime, codec) .map(|s| Arc::new(s) as Arc) }) .collect::>>()?; @@ -93,7 +81,6 @@ pub fn stage_from_proto( plan, inputs, tasks: msg.tasks, - codec: Some(codec), depth: 0, }) } @@ -116,6 +103,7 @@ mod tests { use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; use prost::Message; + use crate::stage::proto::proto_from_stage; use crate::stage::{proto::stage_from_proto, ExecutionStage, ExecutionStageProto}; // create a simple mem table @@ -158,12 +146,11 @@ mod tests { plan: physical_plan, inputs: vec![], tasks: vec![], - codec: Some(Arc::new(DefaultPhysicalExtensionCodec {})), depth: 0, }; // Convert to proto message - let stage_msg = ExecutionStageProto::try_from(&stage)?; + let stage_msg = proto_from_stage(&stage, &DefaultPhysicalExtensionCodec {})?; // Serialize to bytes let mut buf = Vec::new(); @@ -180,7 +167,7 @@ mod tests { decoded_msg, &ctx, ctx.runtime_env().as_ref(), - Arc::new(DefaultPhysicalExtensionCodec {}), + &DefaultPhysicalExtensionCodec {}, )?; // Compare original and round-tripped stages diff --git a/src/stage/stage.rs b/src/stage/stage.rs index b161f03..e17863a 100644 --- a/src/stage/stage.rs +++ b/src/stage/stage.rs @@ -5,7 +5,6 @@ use datafusion::error::{DataFusionError, Result}; use datafusion::execution::TaskContext; use datafusion::physical_plan::{displayable, ExecutionPlan}; use datafusion::prelude::SessionContext; -use datafusion_proto::physical_plan::PhysicalExtensionCodec; use itertools::Itertools; use rand::Rng; @@ -32,8 +31,6 @@ pub struct ExecutionStage { /// Our tasks which tell us how finely grained to execute the partitions in /// the plan pub tasks: Vec, - /// An optional codec to assist in serializing and deserializing this stage - pub codec: Option>, /// tree depth of our location in the stage tree, used for display only pub depth: usize, } @@ -65,7 +62,6 @@ impl ExecutionStage { .map(|s| s as Arc) .collect(), tasks: vec![ExecutionTask::new(partition_group)], - codec: None, depth: 0, } } @@ -93,13 +89,6 @@ impl ExecutionStage { self } - /// Sets the codec for this stage, which is used to serialize and deserialize the plan - /// and its inputs. 
diff --git a/src/stage/stage.rs b/src/stage/stage.rs
index b161f03..e17863a 100644
--- a/src/stage/stage.rs
+++ b/src/stage/stage.rs
@@ -5,7 +5,6 @@ use datafusion::error::{DataFusionError, Result};
 use datafusion::execution::TaskContext;
 use datafusion::physical_plan::{displayable, ExecutionPlan};
 use datafusion::prelude::SessionContext;
-use datafusion_proto::physical_plan::PhysicalExtensionCodec;
 use itertools::Itertools;
 use rand::Rng;
 
@@ -32,8 +31,6 @@ pub struct ExecutionStage {
     /// Our tasks which tell us how finely grained to execute the partitions in
     /// the plan
     pub tasks: Vec<ExecutionTask>,
-    /// An optional codec to assist in serializing and deserializing this stage
-    pub codec: Option<Arc<dyn PhysicalExtensionCodec>>,
     /// tree depth of our location in the stage tree, used for display only
     pub depth: usize,
 }
@@ -65,7 +62,6 @@ impl ExecutionStage {
                 .map(|s| s as Arc<dyn ExecutionPlan>)
                 .collect(),
             tasks: vec![ExecutionTask::new(partition_group)],
-            codec: None,
             depth: 0,
         }
     }
@@ -93,13 +89,6 @@ impl ExecutionStage {
         self
     }
 
-    /// Sets the codec for this stage, which is used to serialize and deserialize the plan
-    /// and its inputs.
-    pub fn with_codec(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self {
-        self.codec = Some(codec);
-        self
-    }
-
     /// Returns the name of this stage
     pub fn name(&self) -> String {
         format!("Stage {:<3}", self.num)
@@ -174,7 +163,6 @@ impl ExecutionStage {
             plan: self.plan.clone(),
             inputs: assigned_children,
             tasks: assigned_tasks,
-            codec: self.codec.clone(),
             depth: self.depth,
         };
 
@@ -205,7 +193,6 @@ impl ExecutionPlan for ExecutionStage {
             plan: self.plan.clone(),
             inputs: children,
             tasks: self.tasks.clone(),
-            codec: self.codec.clone(),
             depth: self.depth,
         }))
     }
diff --git a/src/user_provided_codec.rs b/src/user_provided_codec.rs
new file mode 100644
index 0000000..4881165
--- /dev/null
+++ b/src/user_provided_codec.rs
@@ -0,0 +1,126 @@
+use datafusion::execution::SessionStateBuilder;
+use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion_proto::physical_plan::PhysicalExtensionCodec;
+use std::sync::Arc;
+
+pub struct UserProvidedCodec(Arc<dyn PhysicalExtensionCodec>);
+
+/// Injects a user-defined codec that is capable of encoding/decoding custom execution nodes.
+/// The codec is stored as a config extension in the provided [SessionConfig], [SessionContext]
+/// or [SessionStateBuilder].
+///
+/// Example:
+///
+/// ```
+/// # use std::sync::Arc;
+/// # use datafusion::execution::{SessionState, FunctionRegistry, SessionStateBuilder};
+/// # use datafusion::physical_plan::ExecutionPlan;
+/// # use datafusion_proto::physical_plan::PhysicalExtensionCodec;
+/// # use datafusion_distributed::{add_user_codec};
+///
+/// #[derive(Debug)]
+/// struct CustomExecCodec;
+///
+/// impl PhysicalExtensionCodec for CustomExecCodec {
+///     fn try_decode(&self, buf: &[u8], inputs: &[Arc<dyn ExecutionPlan>], registry: &dyn FunctionRegistry) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
+///         todo!()
+///     }
+///
+///     fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> datafusion::common::Result<()> {
+///         todo!()
+///     }
+/// }
+///
+/// let builder = SessionStateBuilder::new();
+/// let mut state = builder.build();
+/// add_user_codec(state.config_mut(), CustomExecCodec);
+/// ```
+#[allow(private_bounds)]
+pub fn add_user_codec(
+    transport: &mut impl UserCodecTransport,
+    codec: impl PhysicalExtensionCodec + 'static,
+) {
+    transport.set(codec);
+}
+
+/// Adds a user-defined codec that is capable of encoding/decoding custom execution nodes.
+/// It returns the [SessionContext], [SessionConfig] or [SessionStateBuilder] passed as the
+/// first argument with the user-defined codec already placed into the config extensions.
+///
+/// Example:
+///
+/// ```
+/// # use std::sync::Arc;
+/// # use datafusion::execution::{SessionState, FunctionRegistry, SessionStateBuilder};
+/// # use datafusion::physical_plan::ExecutionPlan;
+/// # use datafusion_proto::physical_plan::PhysicalExtensionCodec;
+/// # use datafusion_distributed::with_user_codec;
+///
+/// #[derive(Debug)]
+/// struct CustomExecCodec;
+///
+/// impl PhysicalExtensionCodec for CustomExecCodec {
+///     fn try_decode(&self, buf: &[u8], inputs: &[Arc<dyn ExecutionPlan>], registry: &dyn FunctionRegistry) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
+///         todo!()
+///     }
+///
+///     fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> datafusion::common::Result<()> {
+///         todo!()
+///     }
+/// }
+///
+/// let builder = SessionStateBuilder::new();
+/// let builder = with_user_codec(builder, CustomExecCodec);
+/// let state = builder.build();
+/// ```
+#[allow(private_bounds)]
+pub fn with_user_codec<T: UserCodecTransport>(
+    mut transport: T,
+    codec: impl PhysicalExtensionCodec + 'static,
+) -> T {
+    transport.set(codec);
+    transport
+}
+
+#[allow(private_bounds)]
+pub(crate) fn get_user_codec(
+    transport: &impl UserCodecTransport,
+) -> Option<Arc<dyn PhysicalExtensionCodec>> {
+    transport.get()
+}
+
+trait UserCodecTransport {
+    fn set(&mut self, codec: impl PhysicalExtensionCodec + 'static);
+    fn get(&self) -> Option<Arc<dyn PhysicalExtensionCodec>>;
+}
+
+impl UserCodecTransport for SessionConfig {
+    fn set(&mut self, codec: impl PhysicalExtensionCodec + 'static) {
+        self.set_extension(Arc::new(UserProvidedCodec(Arc::new(codec))));
+    }
+
+    fn get(&self) -> Option<Arc<dyn PhysicalExtensionCodec>> {
+        Some(Arc::clone(&self.get_extension::<UserProvidedCodec>()?.0))
+    }
+}
+
+impl UserCodecTransport for SessionContext {
+    fn set(&mut self, codec: impl PhysicalExtensionCodec + 'static) {
+        self.state_ref().write().config_mut().set(codec)
+    }
+
+    fn get(&self) -> Option<Arc<dyn PhysicalExtensionCodec>> {
+        self.state_ref().read().config().get()
+    }
+}
+
+impl UserCodecTransport for SessionStateBuilder {
+    fn set(&mut self, codec: impl PhysicalExtensionCodec + 'static) {
+        self.config().get_or_insert_default().set(codec);
+    }
+
+    fn get(&self) -> Option<Arc<dyn PhysicalExtensionCodec>> {
+        // Nobody will ever want to retrieve a user codec from a SessionStateBuilder
+        None
+    }
+}
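A note on the `UserProvidedCodec` newtype above: [SessionConfig] extensions are stored and looked up by concrete type, so wrapping the trait object in a crate-controlled type gives it an unambiguous key. A tiny sketch of that typed-extension mechanism, using a hypothetical `MyMarker` type in place of the codec wrapper:

```rust
use std::sync::Arc;

use datafusion::prelude::SessionConfig;

// Extensions are keyed by the concrete type passed to `set_extension`, so a
// dedicated wrapper type is what makes `get_extension` find the right entry.
#[derive(Debug)]
struct MyMarker(u64);

fn main() {
    let mut config = SessionConfig::new();
    config.set_extension(Arc::new(MyMarker(42)));
    let got = config.get_extension::<MyMarker>().unwrap();
    assert_eq!(got.0, 42);
}
```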
diff --git a/tests/custom_extension_codec.rs b/tests/custom_extension_codec.rs
index 012f72c..a69b0ce 100644
--- a/tests/custom_extension_codec.rs
+++ b/tests/custom_extension_codec.rs
@@ -28,7 +28,9 @@ mod tests {
         displayable, execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
     };
     use datafusion_distributed::physical_optimizer::DistributedPhysicalOptimizerRule;
-    use datafusion_distributed::{ArrowFlightReadExec, SessionBuilder};
+    use datafusion_distributed::{
+        add_user_codec, with_user_codec, ArrowFlightReadExec, SessionBuilder,
+    };
     use datafusion_proto::physical_plan::PhysicalExtensionCodec;
    use datafusion_proto::protobuf::proto_error;
     use futures::{stream, TryStreamExt};
@@ -43,22 +45,14 @@ mod tests {
         #[derive(Clone)]
         struct CustomSessionBuilder;
         impl SessionBuilder for CustomSessionBuilder {
-            fn on_new_session(&self, mut builder: SessionStateBuilder) -> SessionStateBuilder {
-                let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(Int64ListExecCodec);
-                let config = builder.config().get_or_insert_default();
-                config.set_extension(Arc::new(codec));
-                builder
+            fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
+                with_user_codec(builder, Int64ListExecCodec)
             }
         }
 
-        let (ctx, _guard) =
+        let (mut ctx, _guard) =
             start_localhost_context([50050, 50051, 50052], CustomSessionBuilder).await;
-
-        let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(Int64ListExecCodec);
-        ctx.state_ref()
-            .write()
-            .config_mut()
-            .set_extension(Arc::new(codec));
+        add_user_codec(&mut ctx, Int64ListExecCodec);
 
         let single_node_plan = build_plan(false)?;
         assert_snapshot!(displayable(single_node_plan.as_ref()).indent(true).to_string(), @r"
@@ -72,13 +66,22 @@ mod tests {
             DistributedPhysicalOptimizerRule::default().distribute_plan(distributed_plan)?;
 
         assert_snapshot!(displayable(&distributed_plan).indent(true).to_string(), @r"
-        SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false]
-          RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=10
-            ArrowFlightReadExec: input_tasks=10 hash_expr=[] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/, http://localhost:50051/, http://localhost:50052/, http://localhost:50050/]
-              SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false]
-                ArrowFlightReadExec: input_tasks=1 hash_expr=[numbers@0] stage_id=UUID input_stage_id=UUID input_hosts=[http://localhost:50051/]
-                  FilterExec: numbers@0 > 1
-                    Int64ListExec: length=6
+        ┌───── Stage 3 Task: partitions: 0,unassigned]
+        │partitions [out:1 <-- in:1 ] SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false]
+        │partitions [out:1 <-- in:10 ] RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=10
+        │partitions [out:10 ] ArrowFlightReadExec: Stage 2
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 2 Task: partitions: 0,unassigned]
+        │partitions [out:1 <-- in:1 ] SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false]
+        │partitions [out:1 ] ArrowFlightReadExec: Stage 1
+        │
+        └──────────────────────────────────────────────────
+        ┌───── Stage 1 Task: partitions: 0,unassigned]
+        │partitions [out:1 <-- in:1 ] FilterExec: numbers@0 > 1
+        │partitions [out:1 ] Int64ListExec: length=6
+        │
+        └──────────────────────────────────────────────────
         ");
 
         let stream = execute_stream(single_node_plan, ctx.task_ctx())?;
diff --git a/tests/error_propagation.rs b/tests/error_propagation.rs
index b559f62..16ff1cb 100644
--- a/tests/error_propagation.rs
+++ b/tests/error_propagation.rs
@@ -16,7 +16,9 @@ mod tests {
         execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
     };
     use datafusion_distributed::physical_optimizer::DistributedPhysicalOptimizerRule;
-    use datafusion_distributed::{ArrowFlightReadExec, SessionBuilder};
+    use datafusion_distributed::{
+        add_user_codec, with_user_codec, ArrowFlightReadExec, SessionBuilder,
+    };
     use datafusion_proto::physical_plan::PhysicalExtensionCodec;
     use datafusion_proto::protobuf::proto_error;
     use futures::{stream, TryStreamExt};
@@ -27,26 +29,17 @@ mod tests {
     use std::sync::Arc;
 
     #[tokio::test]
-    #[ignore]
     async fn test_error_propagation() -> Result<(), Box<dyn std::error::Error>> {
         #[derive(Clone)]
         struct CustomSessionBuilder;
         impl SessionBuilder for CustomSessionBuilder {
-            fn on_new_session(&self, mut builder: SessionStateBuilder) -> SessionStateBuilder {
-                let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(ErrorExecCodec);
-                let config = builder.config().get_or_insert_default();
-                config.set_extension(Arc::new(codec));
-                builder
+            fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
+                with_user_codec(builder, ErrorExecCodec)
             }
         }
 
-        let (ctx, _guard) =
+        let (mut ctx, _guard) =
             start_localhost_context([50050, 50051, 50053], CustomSessionBuilder).await;
-
-        let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(ErrorExecCodec);
-        ctx.state_ref()
-            .write()
-            .config_mut()
-            .set_extension(Arc::new(codec));
+        add_user_codec(&mut ctx, ErrorExecCodec);
 
         let mut plan: Arc<dyn ExecutionPlan> = Arc::new(ErrorExec::new("something failed"));