Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 4 additions & 37 deletions src/composed_extension_codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use datafusion::error::DataFusionError;
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionConfig;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use std::fmt::Debug;
use std::sync::Arc;
Expand All @@ -25,42 +24,10 @@ impl ComposedPhysicalExtensionCodec {
self.codecs.push(Arc::new(codec));
}

/// Adds a new [PhysicalExtensionCodec] from DataFusion's [SessionConfig] extensions.
///
/// If users have a custom [PhysicalExtensionCodec] for their own nodes, they should
/// populate the config extensions with a [PhysicalExtensionCodec] so that we can use
/// it while encoding/decoding plans to/from protobuf.
///
/// Example:
/// ```rust
/// # use std::sync::Arc;
/// # use datafusion::execution::FunctionRegistry;
/// # use datafusion::physical_plan::ExecutionPlan;
/// # use datafusion::prelude::SessionConfig;
/// # use datafusion_proto::physical_plan::PhysicalExtensionCodec;
///
/// #[derive(Debug)]
/// struct CustomUserCodec {}
///
/// impl PhysicalExtensionCodec for CustomUserCodec {
/// fn try_decode(&self, buf: &[u8], inputs: &[Arc<dyn ExecutionPlan>], registry: &dyn FunctionRegistry) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
/// todo!()
/// }
///
/// fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> datafusion::common::Result<()> {
/// todo!()
/// }
/// }
///
/// let mut config = SessionConfig::new();
///
/// let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(CustomUserCodec {});
/// config.set_extension(Arc::new(codec));
/// ```
pub(crate) fn push_from_config(&mut self, config: &SessionConfig) {
if let Some(user_codec) = config.get_extension::<Arc<dyn PhysicalExtensionCodec>>() {
self.codecs.push(user_codec.as_ref().clone());
}
/// Adds a new [PhysicalExtensionCodec] to the list. These codecs will be tried
/// sequentially until one works.
///
/// The codec is received as an already-built `Arc<dyn PhysicalExtensionCodec>` and is
/// appended to `self.codecs` as-is, without being wrapped in another `Arc`.
pub(crate) fn push_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>) {
self.codecs.push(codec);
}

fn try_any<T>(
Expand Down
28 changes: 18 additions & 10 deletions src/flight_service/do_get.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
use crate::errors::datafusion_error_to_tonic_status;
use crate::flight_service::service::ArrowFlightEndpoint;
use crate::plan::DistributedCodec;
use crate::stage::{stage_from_proto, ExecutionStageProto};
use crate::user_provided_codec::get_user_codec;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_server::FlightService;
use arrow_flight::Ticket;
use datafusion::execution::SessionStateBuilder;
use datafusion::optimizer::OptimizerConfig;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use futures::TryStreamExt;
use prost::Message;
use std::sync::Arc;
Expand Down Expand Up @@ -48,25 +49,32 @@ impl ArrowFlightEndpoint {
"FunctionRegistry not present in newly built SessionState",
))?;

let codec = DistributedCodec {};
let codec = Arc::new(codec) as Arc<dyn PhysicalExtensionCodec>;
let mut combined_codec = ComposedPhysicalExtensionCodec::default();
combined_codec.push(DistributedCodec);
if let Some(ref user_codec) = get_user_codec(state.config()) {
combined_codec.push_arc(Arc::clone(&user_codec));
}

let stage = stage_from_proto(stage_msg, function_registry, &self.runtime.as_ref(), codec)
.map(Arc::new)
.map_err(|err| Status::invalid_argument(format!("Cannot decode stage proto: {err}")))?;
let mut stage = stage_from_proto(
stage_msg,
function_registry,
&self.runtime.as_ref(),
&combined_codec,
)
.map_err(|err| Status::invalid_argument(format!("Cannot decode stage proto: {err}")))?;
let inner_plan = Arc::clone(&stage.plan);

// Add the extensions that might be required for ExecutionPlan nodes in the plan
let config = state.config_mut();
config.set_extension(Arc::clone(&self.channel_manager));
config.set_extension(stage.clone());
config.set_extension(Arc::new(stage));

let stream = stage
.plan
let stream = inner_plan
.execute(doget.partition as usize, state.task_ctx())
.map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;

let flight_data_stream = FlightDataEncoderBuilder::new()
.with_schema(stage.plan.schema().clone())
.with_schema(inner_plan.schema().clone())
.build(stream.map_err(|err| {
FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
}));
Expand Down
8 changes: 3 additions & 5 deletions src/flight_service/session_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub trait SessionBuilder {
/// # use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
/// # use datafusion::physical_plan::ExecutionPlan;
/// # use datafusion_proto::physical_plan::PhysicalExtensionCodec;
/// # use datafusion_distributed::{SessionBuilder};
/// # use datafusion_distributed::{with_user_codec, SessionBuilder};
///
/// #[derive(Debug)]
/// struct CustomExecCodec;
Expand All @@ -35,10 +35,8 @@ pub trait SessionBuilder {
/// struct CustomSessionBuilder;
/// impl SessionBuilder for CustomSessionBuilder {
/// fn on_new_session(&self, mut builder: SessionStateBuilder) -> SessionStateBuilder {
/// let config = builder.config().get_or_insert_default();
/// let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(CustomExecCodec);
/// config.set_extension(Arc::new(codec));
/// builder
/// // Add your UDFs, optimization rules, etc...
/// with_user_codec(builder, CustomExecCodec)
/// }
/// }
/// ```
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ mod test_utils;
pub mod physical_optimizer;
pub mod stage;
pub mod task;
mod user_provided_codec;

pub use channel_manager::{BoxCloneSyncChannel, ChannelManager, ChannelResolver};
pub use flight_service::{ArrowFlightEndpoint, SessionBuilder};
pub use plan::ArrowFlightReadExec;
pub use user_provided_codec::{add_user_codec, with_user_codec};
18 changes: 1 addition & 17 deletions src/physical_optimizer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::sync::Arc;

use super::stage::ExecutionStage;
use crate::{plan::PartitionIsolatorExec, ArrowFlightReadExec};
use datafusion::common::tree_node::TreeNodeRecursion;
use datafusion::error::DataFusionError;
Expand All @@ -15,15 +16,9 @@ use datafusion::{
displayable, repartition::RepartitionExec, ExecutionPlan, ExecutionPlanProperties,
},
};
use datafusion_proto::physical_plan::PhysicalExtensionCodec;

use super::stage::ExecutionStage;

#[derive(Debug, Default)]
pub struct DistributedPhysicalOptimizerRule {
/// Optional codec to assist in serializing and deserializing any custom
/// ExecutionPlan nodes
codec: Option<Arc<dyn PhysicalExtensionCodec>>,
/// maximum number of partitions per task. This is used to determine how many
/// tasks to create for each stage
partitions_per_task: Option<usize>,
Expand All @@ -32,18 +27,10 @@ pub struct DistributedPhysicalOptimizerRule {
impl DistributedPhysicalOptimizerRule {
pub fn new() -> Self {
DistributedPhysicalOptimizerRule {
codec: None,
partitions_per_task: None,
}
}

/// Set a codec to use to assist in serializing and deserializing
/// custom ExecutionPlan nodes.
pub fn with_codec(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self {
self.codec = Some(codec);
self
}

/// Set the maximum number of partitions per task. This is used to determine how many
/// tasks to create for each stage.
///
Expand Down Expand Up @@ -156,9 +143,6 @@ impl DistributedPhysicalOptimizerRule {
if let Some(partitions_per_task) = self.partitions_per_task {
stage = stage.with_maximum_partitions_per_task(partitions_per_task);
}
if let Some(codec) = self.codec.as_ref() {
stage = stage.with_codec(codec.clone());
}
stage.depth = depth;

Ok(stage)
Expand Down
20 changes: 13 additions & 7 deletions src/plan/arrow_flight_read.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use super::combined::CombinedRecordBatchStream;
use crate::channel_manager::ChannelManager;
use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
use crate::errors::tonic_status_to_datafusion_error;
use crate::flight_service::DoGet;
use crate::stage::{ExecutionStage, ExecutionStageProto};
use crate::plan::DistributedCodec;
use crate::stage::{proto_from_stage, ExecutionStage};
use crate::user_provided_codec::get_user_codec;
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_client::FlightServiceClient;
Expand Down Expand Up @@ -170,12 +173,14 @@ impl ExecutionPlan for ArrowFlightReadExec {
this.stage_num
))?;

let child_stage_tasks = child_stage.tasks.clone();
let child_stage_proto = ExecutionStageProto::try_from(child_stage).map_err(|e| {
internal_datafusion_err!(
"ArrowFlightReadExec: failed to convert stage to proto: {}",
e
)
let mut combined_codec = ComposedPhysicalExtensionCodec::default();
combined_codec.push(DistributedCodec {});
if let Some(ref user_codec) = get_user_codec(context.session_config()) {
combined_codec.push_arc(Arc::clone(user_codec));
}

let child_stage_proto = proto_from_stage(child_stage, &combined_codec).map_err(|e| {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Intuitively this makes sense, but I think I will only fully understand the role of `combined_codec` once I start using it.

internal_datafusion_err!("ArrowFlightReadExec: failed to convert stage to proto: {e}")
})?;

let ticket_bytes = DoGet {
Expand All @@ -191,6 +196,7 @@ impl ExecutionPlan for ArrowFlightReadExec {

let schema = child_stage.plan.schema();

let child_stage_tasks = child_stage.tasks.clone();
let stream = async move {
let futs = child_stage_tasks.iter().enumerate().map(|(i, task)| async {
let url = task.url()?.ok_or(internal_datafusion_err!(
Expand Down
2 changes: 1 addition & 1 deletion src/stage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ mod proto;
mod stage;

pub use display::display_stage_graphviz;
pub use proto::{stage_from_proto, ExecutionStageProto};
pub use proto::{proto_from_stage, stage_from_proto, ExecutionStageProto};
pub use stage::ExecutionStage;
59 changes: 23 additions & 36 deletions src/stage/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use datafusion_proto::{
protobuf::PhysicalPlanNode,
};

use crate::{plan::DistributedCodec, task::ExecutionTask};
use crate::task::ExecutionTask;

use super::ExecutionStage;

Expand All @@ -35,54 +35,42 @@ pub struct ExecutionStageProto {
pub tasks: Vec<ExecutionTask>,
}

impl TryFrom<&ExecutionStage> for ExecutionStageProto {
type Error = DataFusionError;

fn try_from(stage: &ExecutionStage) -> Result<Self, Self::Error> {
let codec = stage.codec.clone().unwrap_or(Arc::new(DistributedCodec {}));

let proto_plan =
PhysicalPlanNode::try_from_physical_plan(stage.plan.clone(), codec.as_ref())?;
let inputs = stage
.child_stages_iter()
.map(|s| Box::new(ExecutionStageProto::try_from(s).unwrap()))
.collect();

Ok(ExecutionStageProto {
num: stage.num as u64,
name: stage.name(),
plan: Some(Box::new(proto_plan)),
inputs,
tasks: stage.tasks.clone(),
})
}
}

impl TryFrom<ExecutionStage> for ExecutionStageProto {
type Error = DataFusionError;
pub fn proto_from_stage(
stage: &ExecutionStage,
codec: &dyn PhysicalExtensionCodec,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since a codec is now needed here, we can no longer express this conversion as a simple `TryFrom` implementation.

) -> Result<ExecutionStageProto, DataFusionError> {
let proto_plan = PhysicalPlanNode::try_from_physical_plan(stage.plan.clone(), codec)?;
let inputs = stage
.child_stages_iter()
.map(|s| Ok(Box::new(proto_from_stage(s, codec)?)))
.collect::<Result<Vec<_>>>()?;

fn try_from(stage: ExecutionStage) -> Result<Self, Self::Error> {
ExecutionStageProto::try_from(&stage)
}
Ok(ExecutionStageProto {
num: stage.num as u64,
name: stage.name(),
plan: Some(Box::new(proto_plan)),
inputs,
tasks: stage.tasks.clone(),
})
}

pub fn stage_from_proto(
msg: ExecutionStageProto,
registry: &dyn FunctionRegistry,
runtime: &RuntimeEnv,
codec: Arc<dyn PhysicalExtensionCodec>,
codec: &dyn PhysicalExtensionCodec,
) -> Result<ExecutionStage> {
let plan_node = msg.plan.ok_or(internal_datafusion_err!(
"ExecutionStageMsg is missing the plan"
))?;

let plan = plan_node.try_into_physical_plan(registry, runtime, codec.as_ref())?;
let plan = plan_node.try_into_physical_plan(registry, runtime, codec)?;

let inputs = msg
.inputs
.into_iter()
.map(|s| {
stage_from_proto(*s, registry, runtime, codec.clone())
stage_from_proto(*s, registry, runtime, codec)
.map(|s| Arc::new(s) as Arc<dyn ExecutionPlan>)
})
.collect::<Result<Vec<_>>>()?;
Expand All @@ -93,7 +81,6 @@ pub fn stage_from_proto(
plan,
inputs,
tasks: msg.tasks,
codec: Some(codec),
depth: 0,
})
}
Expand All @@ -116,6 +103,7 @@ mod tests {
use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
use prost::Message;

use crate::stage::proto::proto_from_stage;
use crate::stage::{proto::stage_from_proto, ExecutionStage, ExecutionStageProto};

// create a simple mem table
Expand Down Expand Up @@ -158,12 +146,11 @@ mod tests {
plan: physical_plan,
inputs: vec![],
tasks: vec![],
codec: Some(Arc::new(DefaultPhysicalExtensionCodec {})),
depth: 0,
};

// Convert to proto message
let stage_msg = ExecutionStageProto::try_from(&stage)?;
let stage_msg = proto_from_stage(&stage, &DefaultPhysicalExtensionCodec {})?;

// Serialize to bytes
let mut buf = Vec::new();
Expand All @@ -180,7 +167,7 @@ mod tests {
decoded_msg,
&ctx,
ctx.runtime_env().as_ref(),
Arc::new(DefaultPhysicalExtensionCodec {}),
&DefaultPhysicalExtensionCodec {},
)?;

// Compare original and round-tripped stages
Expand Down
Loading