Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 83 additions & 55 deletions src/common/composed_extension_codec.rs
Original file line number Diff line number Diff line change
@@ -1,56 +1,96 @@
use datafusion::common::not_impl_err;
use datafusion::common::internal_datafusion_err;
use datafusion::error::DataFusionError;
use datafusion::error::Result;
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
use datafusion::physical_plan::ExecutionPlan;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use prost::Message;
use std::fmt::Debug;
use std::sync::Arc;
// Code taken from https://github.com/apache/datafusion/blob/10f41887fa40d7d425c19b07857f80115460a98e/datafusion/proto/src/physical_plan/mod.rs
// TODO: It's not yet on DF 49, once upgrading to DF 50 we can remove this

// Idea taken from
// https://github.com/apache/datafusion/blob/0eebc0c7c0ffcd1514f5c6d0f8e2b6d0c69a07f5/datafusion-examples/examples/composed_extension_codec.rs#L236-L291
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code was previously taken from an example upstream. However, in version 50, there's a better version of this same code living in the datafusion-proto package, so I took it.

/// DataEncoderTuple captures the position of the encoder
/// in the codec list that was used to encode the data and actual encoded data
#[derive(Clone, PartialEq, prost::Message)]
struct DataEncoderTuple {
/// The position of encoder used to encode data
/// (to be used for decoding)
#[prost(uint32, tag = 1)]
pub encoder_position: u32,

/// A [PhysicalExtensionCodec] that holds multiple [PhysicalExtensionCodec] and tries them
/// sequentially until one works.
#[derive(Debug, Clone, Default)]
pub(crate) struct ComposedPhysicalExtensionCodec {
#[prost(bytes, tag = 2)]
pub blob: Vec<u8>,
}

/// A PhysicalExtensionCodec that tries one of multiple inner codecs
/// until one works
#[derive(Debug)]
pub struct ComposedPhysicalExtensionCodec {
codecs: Vec<Arc<dyn PhysicalExtensionCodec>>,
}

impl ComposedPhysicalExtensionCodec {
/// Adds a new [PhysicalExtensionCodec] to the list. These codecs will be tried
/// sequentially until one works.
pub(crate) fn push(&mut self, codec: impl PhysicalExtensionCodec + 'static) {
self.codecs.push(Arc::new(codec));
// Position in this codecs list is important as it will be used for decoding.
// If new codec is added it should go to last position.
pub fn new(codecs: Vec<Arc<dyn PhysicalExtensionCodec>>) -> Self {
Self { codecs }
}

/// Adds a new [PhysicalExtensionCodec] to the list. These codecs will be tried
/// sequentially until one works.
pub(crate) fn push_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>) {
self.codecs.push(codec);
fn decode_protobuf<R>(
&self,
buf: &[u8],
decode: impl FnOnce(&dyn PhysicalExtensionCodec, &[u8]) -> Result<R, DataFusionError>,
) -> Result<R, DataFusionError> {
let proto =
DataEncoderTuple::decode(buf).map_err(|e| DataFusionError::Internal(e.to_string()))?;

let pos = proto.encoder_position as usize;
let codec = self.codecs.get(pos).ok_or_else(|| {
internal_datafusion_err!(
"Can't find required codec in position {pos} in codec list with {} elements",
self.codecs.len()
)
})?;

decode(codec.as_ref(), &proto.blob)
}

fn try_any<T>(
fn encode_protobuf(
&self,
mut f: impl FnMut(&dyn PhysicalExtensionCodec) -> Result<T, DataFusionError>,
) -> Result<T, DataFusionError> {
let mut errs = vec![];
for codec in &self.codecs {
match f(codec.as_ref()) {
Ok(node) => return Ok(node),
Err(err) => errs.push(err),
buf: &mut Vec<u8>,
mut encode: impl FnMut(&dyn PhysicalExtensionCodec, &mut Vec<u8>) -> Result<()>,
) -> Result<(), DataFusionError> {
let mut data = vec![];
let mut last_err = None;
let mut encoder_position = None;

// find the encoder
for (position, codec) in self.codecs.iter().enumerate() {
match encode(codec.as_ref(), &mut data) {
Ok(_) => {
encoder_position = Some(position as u32);
break;
}
Err(err) => last_err = Some(err),
}
}

if errs.is_empty() {
return not_impl_err!("Empty list of composed codecs");
}
let encoder_position = encoder_position.ok_or_else(|| {
last_err.unwrap_or_else(|| {
DataFusionError::NotImplemented("Empty list of composed codecs".to_owned())
})
})?;

let mut msg = "None of the provided PhysicalExtensionCodec worked:".to_string();
for err in &errs {
msg += &format!("\n {err}");
}
not_impl_err!("{msg}")
// encode with encoder position
let proto = DataEncoderTuple {
encoder_position,
blob: data,
};
proto
.encode(buf)
.map_err(|e| DataFusionError::Internal(e.to_string()))
}
}

Expand All @@ -60,39 +100,27 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec {
buf: &[u8],
inputs: &[Arc<dyn ExecutionPlan>],
registry: &dyn FunctionRegistry,
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
self.try_any(|codec| codec.try_decode(buf, inputs, registry))
) -> Result<Arc<dyn ExecutionPlan>> {
self.decode_protobuf(buf, |codec, data| codec.try_decode(data, inputs, registry))
}

fn try_encode(
&self,
node: Arc<dyn ExecutionPlan>,
buf: &mut Vec<u8>,
) -> Result<(), DataFusionError> {
self.try_any(|codec| codec.try_encode(node.clone(), buf))
fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> Result<()> {
self.encode_protobuf(buf, |codec, data| codec.try_encode(Arc::clone(&node), data))
}

fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>, DataFusionError> {
self.try_any(|codec| codec.try_decode_udf(name, buf))
fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> {
self.decode_protobuf(buf, |codec, data| codec.try_decode_udf(name, data))
}

fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<(), DataFusionError> {
self.try_any(|codec| codec.try_encode_udf(node, buf))
fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> {
self.encode_protobuf(buf, |codec, data| codec.try_encode_udf(node, data))
}

fn try_decode_udaf(
&self,
name: &str,
buf: &[u8],
) -> Result<Arc<AggregateUDF>, DataFusionError> {
self.try_any(|codec| codec.try_decode_udaf(name, buf))
fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> {
self.decode_protobuf(buf, |codec, data| codec.try_decode_udaf(name, data))
}

fn try_encode_udaf(
&self,
node: &AggregateUDF,
buf: &mut Vec<u8>,
) -> Result<(), DataFusionError> {
self.try_any(|codec| codec.try_encode_udaf(node, buf))
fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> {
self.encode_protobuf(buf, |codec, data| codec.try_encode_udaf(node, data))
}
}
35 changes: 33 additions & 2 deletions src/distributed_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ use crate::channel_resolver_ext::set_distributed_channel_resolver;
use crate::config_extension_ext::{
set_distributed_option_extension, set_distributed_option_extension_from_headers,
};
use crate::protobuf::set_distributed_user_codec;
use crate::protobuf::{set_distributed_user_codec, set_distributed_user_codec_arc};
use datafusion::common::DataFusionError;
use datafusion::config::ConfigExtension;
use datafusion::execution::{SessionState, SessionStateBuilder};
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use delegate::delegate;
use http::HeaderMap;
use std::sync::Arc;

/// Extends DataFusion with distributed capabilities.
pub trait DistributedExt: Sized {
Expand Down Expand Up @@ -125,7 +126,8 @@ pub trait DistributedExt: Sized {
) -> Result<(), DataFusionError>;

/// Injects a user-defined [PhysicalExtensionCodec] that is capable of encoding/decoding
/// custom execution nodes.
/// custom execution nodes. Multiple user-defined [PhysicalExtensionCodec] can be added
/// by calling this method several times.
///
/// Example:
///
Expand Down Expand Up @@ -166,6 +168,12 @@ pub trait DistributedExt: Sized {
/// Same as [DistributedExt::with_distributed_user_codec] but with an in-place mutation
fn set_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(&mut self, codec: T);

/// Same as [DistributedExt::with_distributed_user_codec] but with a dynamic argument.
fn with_distributed_user_codec_arc(self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;

/// Same as [DistributedExt::set_distributed_user_codec] but with a dynamic argument.
fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>);

/// Injects a [ChannelResolver] implementation for Distributed DataFusion to resolve worker
/// nodes. When running in distributed mode, setting a [ChannelResolver] is required.
///
Expand Down Expand Up @@ -233,6 +241,10 @@ impl DistributedExt for SessionConfig {
set_distributed_user_codec(self, codec)
}

fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>) {
set_distributed_user_codec_arc(self, codec)
}

fn set_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(
&mut self,
resolver: T,
Expand All @@ -254,6 +266,10 @@ impl DistributedExt for SessionConfig {
#[expr($;self)]
fn with_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(mut self, codec: T) -> Self;

#[call(set_distributed_user_codec_arc)]
#[expr($;self)]
fn with_distributed_user_codec_arc(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;

#[call(set_distributed_channel_resolver)]
#[expr($;self)]
fn with_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(mut self, resolver: T) -> Self;
Expand All @@ -279,6 +295,11 @@ impl DistributedExt for SessionStateBuilder {
#[expr($;self)]
fn with_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(mut self, codec: T) -> Self;

fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>);
#[call(set_distributed_user_codec_arc)]
#[expr($;self)]
fn with_distributed_user_codec_arc(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;

fn set_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(&mut self, resolver: T);
#[call(set_distributed_channel_resolver)]
#[expr($;self)]
Expand All @@ -305,6 +326,11 @@ impl DistributedExt for SessionState {
#[expr($;self)]
fn with_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(mut self, codec: T) -> Self;

fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>);
#[call(set_distributed_user_codec_arc)]
#[expr($;self)]
fn with_distributed_user_codec_arc(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;

fn set_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(&mut self, resolver: T);
#[call(set_distributed_channel_resolver)]
#[expr($;self)]
Expand All @@ -331,6 +357,11 @@ impl DistributedExt for SessionContext {
#[expr($;self)]
fn with_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(self, codec: T) -> Self;

fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>);
#[call(set_distributed_user_codec_arc)]
#[expr($;self)]
fn with_distributed_user_codec_arc(self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;

fn set_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(&mut self, resolver: T);
#[call(set_distributed_channel_resolver)]
#[expr($;self)]
Expand Down
5 changes: 4 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ pub use channel_resolver_ext::{BoxCloneSyncChannel, ChannelResolver};
pub use distributed_ext::DistributedExt;
pub use distributed_physical_optimizer_rule::DistributedPhysicalOptimizerRule;
pub use execution_plans::display_plan_graphviz;
pub use execution_plans::{ExecutionTask, NetworkShuffleExec, PartitionIsolatorExec, StageExec};
pub use execution_plans::{
DistributedTaskContext, ExecutionTask, NetworkCoalesceExec, NetworkShuffleExec,
PartitionIsolatorExec, StageExec,
};
pub use flight_service::{
ArrowFlightEndpoint, DefaultSessionBuilder, DistributedSessionBuilder,
DistributedSessionBuilderContext, MappedDistributedSessionBuilder,
Expand Down
11 changes: 4 additions & 7 deletions src/protobuf/distributed_codec.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::get_distributed_user_codec;
use super::get_distributed_user_codecs;
use crate::common::ComposedPhysicalExtensionCodec;
use crate::execution_plans::{NetworkCoalesceExec, NetworkCoalesceReady, NetworkShuffleReadyExec};
use crate::{NetworkShuffleExec, PartitionIsolatorExec};
Expand All @@ -24,12 +24,9 @@ pub struct DistributedCodec;

impl DistributedCodec {
pub fn new_combined_with_user(cfg: &SessionConfig) -> impl PhysicalExtensionCodec + use<> {
let mut combined_codec = ComposedPhysicalExtensionCodec::default();
combined_codec.push(DistributedCodec {});
if let Some(ref user_codec) = get_distributed_user_codec(cfg) {
combined_codec.push_arc(Arc::clone(user_codec));
}
combined_codec
let mut codecs: Vec<Arc<dyn PhysicalExtensionCodec>> = vec![Arc::new(DistributedCodec {})];
codecs.extend(get_distributed_user_codecs(cfg));
ComposedPhysicalExtensionCodec::new(codecs)
Comment on lines +27 to +29
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now users can provide more than 1 codec

}
}

Expand Down
4 changes: 3 additions & 1 deletion src/protobuf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ pub(crate) use errors::{
map_status_to_datafusion_error,
};
pub(crate) use stage_proto::{StageExecProto, StageKey, proto_from_stage, stage_from_proto};
pub(crate) use user_codec::{get_distributed_user_codec, set_distributed_user_codec};
pub(crate) use user_codec::{
get_distributed_user_codecs, set_distributed_user_codec, set_distributed_user_codec_arc,
};
25 changes: 20 additions & 5 deletions src/protobuf/user_codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,32 @@ use datafusion::prelude::SessionConfig;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use std::sync::Arc;

pub struct UserProvidedCodec(Arc<dyn PhysicalExtensionCodec>);
pub struct UserProvidedCodecs(Vec<Arc<dyn PhysicalExtensionCodec>>);

pub(crate) fn set_distributed_user_codec_arc(
cfg: &mut SessionConfig,
codec: Arc<dyn PhysicalExtensionCodec>,
) {
let mut codecs = match cfg.get_extension::<UserProvidedCodecs>() {
None => vec![],
Some(prev) => prev.0.clone(),
};
codecs.push(codec);
cfg.set_extension(Arc::new(UserProvidedCodecs(codecs)))
}

pub(crate) fn set_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(
cfg: &mut SessionConfig,
codec: T,
) {
cfg.set_extension(Arc::new(UserProvidedCodec(Arc::new(codec))))
set_distributed_user_codec_arc(cfg, Arc::new(codec))
}

pub(crate) fn get_distributed_user_codec(
pub(crate) fn get_distributed_user_codecs(
cfg: &SessionConfig,
) -> Option<Arc<dyn PhysicalExtensionCodec>> {
Some(Arc::clone(&cfg.get_extension::<UserProvidedCodec>()?.0))
) -> Vec<Arc<dyn PhysicalExtensionCodec>> {
match cfg.get_extension::<UserProvidedCodecs>() {
None => vec![],
Some(v) => v.0.clone(),
}
}