Skip to content

Commit b2a5d0e

Browse files
committed
Address public API weakpoints
1 parent 64920af commit b2a5d0e

File tree

6 files changed

+147
-71
lines changed

6 files changed

+147
-71
lines changed
Lines changed: 83 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,96 @@
1-
use datafusion::common::not_impl_err;
1+
use datafusion::common::internal_datafusion_err;
22
use datafusion::error::DataFusionError;
3+
use datafusion::error::Result;
34
use datafusion::execution::FunctionRegistry;
45
use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
56
use datafusion::physical_plan::ExecutionPlan;
67
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
8+
use prost::Message;
79
use std::fmt::Debug;
810
use std::sync::Arc;
11+
// Code taken from https://github.com/apache/datafusion/blob/10f41887fa40d7d425c19b07857f80115460a98e/datafusion/proto/src/physical_plan/mod.rs
12+
// TODO: It's not yet on DF 49, once upgrading to DF 50 we can remove this
913

10-
// Idea taken from
11-
// https://github.com/apache/datafusion/blob/0eebc0c7c0ffcd1514f5c6d0f8e2b6d0c69a07f5/datafusion-examples/examples/composed_extension_codec.rs#L236-L291
14+
/// DataEncoderTuple captures the position of the encoder
15+
/// in the codec list that was used to encode the data and actual encoded data
16+
#[derive(Clone, PartialEq, prost::Message)]
17+
struct DataEncoderTuple {
18+
/// The position of encoder used to encode data
19+
/// (to be used for decoding)
20+
#[prost(uint32, tag = 1)]
21+
pub encoder_position: u32,
1222

13-
/// A [PhysicalExtensionCodec] that holds multiple [PhysicalExtensionCodec] and tries them
14-
/// sequentially until one works.
15-
#[derive(Debug, Clone, Default)]
16-
pub(crate) struct ComposedPhysicalExtensionCodec {
23+
#[prost(bytes, tag = 2)]
24+
pub blob: Vec<u8>,
25+
}
26+
27+
/// A PhysicalExtensionCodec that tries one of multiple inner codecs
28+
/// until one works
29+
#[derive(Debug)]
30+
pub struct ComposedPhysicalExtensionCodec {
1731
codecs: Vec<Arc<dyn PhysicalExtensionCodec>>,
1832
}
1933

2034
impl ComposedPhysicalExtensionCodec {
21-
/// Adds a new [PhysicalExtensionCodec] to the list. These codecs will be tried
22-
/// sequentially until one works.
23-
pub(crate) fn push(&mut self, codec: impl PhysicalExtensionCodec + 'static) {
24-
self.codecs.push(Arc::new(codec));
35+
// Position in this codecs list is important as it will be used for decoding.
36+
// If new codec is added it should go to last position.
37+
pub fn new(codecs: Vec<Arc<dyn PhysicalExtensionCodec>>) -> Self {
38+
Self { codecs }
2539
}
2640

27-
/// Adds a new [PhysicalExtensionCodec] to the list. These codecs will be tried
28-
/// sequentially until one works.
29-
pub(crate) fn push_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>) {
30-
self.codecs.push(codec);
41+
fn decode_protobuf<R>(
42+
&self,
43+
buf: &[u8],
44+
decode: impl FnOnce(&dyn PhysicalExtensionCodec, &[u8]) -> Result<R, DataFusionError>,
45+
) -> Result<R, DataFusionError> {
46+
let proto =
47+
DataEncoderTuple::decode(buf).map_err(|e| DataFusionError::Internal(e.to_string()))?;
48+
49+
let pos = proto.encoder_position as usize;
50+
let codec = self.codecs.get(pos).ok_or_else(|| {
51+
internal_datafusion_err!(
52+
"Can't find required codec in position {pos} in codec list with {} elements",
53+
self.codecs.len()
54+
)
55+
})?;
56+
57+
decode(codec.as_ref(), &proto.blob)
3158
}
3259

33-
fn try_any<T>(
60+
fn encode_protobuf(
3461
&self,
35-
mut f: impl FnMut(&dyn PhysicalExtensionCodec) -> Result<T, DataFusionError>,
36-
) -> Result<T, DataFusionError> {
37-
let mut errs = vec![];
38-
for codec in &self.codecs {
39-
match f(codec.as_ref()) {
40-
Ok(node) => return Ok(node),
41-
Err(err) => errs.push(err),
62+
buf: &mut Vec<u8>,
63+
mut encode: impl FnMut(&dyn PhysicalExtensionCodec, &mut Vec<u8>) -> Result<()>,
64+
) -> Result<(), DataFusionError> {
65+
let mut data = vec![];
66+
let mut last_err = None;
67+
let mut encoder_position = None;
68+
69+
// find the encoder
70+
for (position, codec) in self.codecs.iter().enumerate() {
71+
match encode(codec.as_ref(), &mut data) {
72+
Ok(_) => {
73+
encoder_position = Some(position as u32);
74+
break;
75+
}
76+
Err(err) => last_err = Some(err),
4277
}
4378
}
4479

45-
if errs.is_empty() {
46-
return not_impl_err!("Empty list of composed codecs");
47-
}
80+
let encoder_position = encoder_position.ok_or_else(|| {
81+
last_err.unwrap_or_else(|| {
82+
DataFusionError::NotImplemented("Empty list of composed codecs".to_owned())
83+
})
84+
})?;
4885

49-
let mut msg = "None of the provided PhysicalExtensionCodec worked:".to_string();
50-
for err in &errs {
51-
msg += &format!("\n {err}");
52-
}
53-
not_impl_err!("{msg}")
86+
// encode with encoder position
87+
let proto = DataEncoderTuple {
88+
encoder_position,
89+
blob: data,
90+
};
91+
proto
92+
.encode(buf)
93+
.map_err(|e| DataFusionError::Internal(e.to_string()))
5494
}
5595
}
5696

@@ -60,39 +100,27 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec {
60100
buf: &[u8],
61101
inputs: &[Arc<dyn ExecutionPlan>],
62102
registry: &dyn FunctionRegistry,
63-
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
64-
self.try_any(|codec| codec.try_decode(buf, inputs, registry))
103+
) -> Result<Arc<dyn ExecutionPlan>> {
104+
self.decode_protobuf(buf, |codec, data| codec.try_decode(data, inputs, registry))
65105
}
66106

67-
fn try_encode(
68-
&self,
69-
node: Arc<dyn ExecutionPlan>,
70-
buf: &mut Vec<u8>,
71-
) -> Result<(), DataFusionError> {
72-
self.try_any(|codec| codec.try_encode(node.clone(), buf))
107+
fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> Result<()> {
108+
self.encode_protobuf(buf, |codec, data| codec.try_encode(Arc::clone(&node), data))
73109
}
74110

75-
fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>, DataFusionError> {
76-
self.try_any(|codec| codec.try_decode_udf(name, buf))
111+
fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> {
112+
self.decode_protobuf(buf, |codec, data| codec.try_decode_udf(name, data))
77113
}
78114

79-
fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<(), DataFusionError> {
80-
self.try_any(|codec| codec.try_encode_udf(node, buf))
115+
fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> {
116+
self.encode_protobuf(buf, |codec, data| codec.try_encode_udf(node, data))
81117
}
82118

83-
fn try_decode_udaf(
84-
&self,
85-
name: &str,
86-
buf: &[u8],
87-
) -> Result<Arc<AggregateUDF>, DataFusionError> {
88-
self.try_any(|codec| codec.try_decode_udaf(name, buf))
119+
fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> {
120+
self.decode_protobuf(buf, |codec, data| codec.try_decode_udaf(name, data))
89121
}
90122

91-
fn try_encode_udaf(
92-
&self,
93-
node: &AggregateUDF,
94-
buf: &mut Vec<u8>,
95-
) -> Result<(), DataFusionError> {
96-
self.try_any(|codec| codec.try_encode_udaf(node, buf))
123+
fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> {
124+
self.encode_protobuf(buf, |codec, data| codec.try_encode_udaf(node, data))
97125
}
98126
}

src/distributed_ext.rs

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@ use crate::channel_resolver_ext::set_distributed_channel_resolver;
33
use crate::config_extension_ext::{
44
set_distributed_option_extension, set_distributed_option_extension_from_headers,
55
};
6-
use crate::protobuf::set_distributed_user_codec;
6+
use crate::protobuf::{set_distributed_user_codec, set_distributed_user_codec_arc};
77
use datafusion::common::DataFusionError;
88
use datafusion::config::ConfigExtension;
99
use datafusion::execution::{SessionState, SessionStateBuilder};
1010
use datafusion::prelude::{SessionConfig, SessionContext};
1111
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
1212
use delegate::delegate;
1313
use http::HeaderMap;
14+
use std::sync::Arc;
1415

1516
/// Extends DataFusion with distributed capabilities.
1617
pub trait DistributedExt: Sized {
@@ -125,7 +126,8 @@ pub trait DistributedExt: Sized {
125126
) -> Result<(), DataFusionError>;
126127

127128
/// Injects a user-defined [PhysicalExtensionCodec] that is capable of encoding/decoding
128-
/// custom execution nodes.
129+
/// custom execution nodes. Multiple user-defined [PhysicalExtensionCodec] can be added
130+
/// by calling this method several times.
129131
///
130132
/// Example:
131133
///
@@ -166,6 +168,12 @@ pub trait DistributedExt: Sized {
166168
/// Same as [DistributedExt::with_distributed_user_codec] but with an in-place mutation
167169
fn set_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(&mut self, codec: T);
168170

171+
/// Same as [DistributedExt::with_distributed_user_codec] but with a dynamic argument.
172+
fn with_distributed_user_codec_arc(self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;
173+
174+
/// Same as [DistributedExt::set_distributed_user_codec] but with a dynamic argument.
175+
fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>);
176+
169177
/// Injects a [ChannelResolver] implementation for Distributed DataFusion to resolve worker
170178
/// nodes. When running in distributed mode, setting a [ChannelResolver] is required.
171179
///
@@ -233,6 +241,10 @@ impl DistributedExt for SessionConfig {
233241
set_distributed_user_codec(self, codec)
234242
}
235243

244+
fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>) {
245+
set_distributed_user_codec_arc(self, codec)
246+
}
247+
236248
fn set_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(
237249
&mut self,
238250
resolver: T,
@@ -254,6 +266,10 @@ impl DistributedExt for SessionConfig {
254266
#[expr($;self)]
255267
fn with_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(mut self, codec: T) -> Self;
256268

269+
#[call(set_distributed_user_codec_arc)]
270+
#[expr($;self)]
271+
fn with_distributed_user_codec_arc(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;
272+
257273
#[call(set_distributed_channel_resolver)]
258274
#[expr($;self)]
259275
fn with_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(mut self, resolver: T) -> Self;
@@ -279,6 +295,11 @@ impl DistributedExt for SessionStateBuilder {
279295
#[expr($;self)]
280296
fn with_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(mut self, codec: T) -> Self;
281297

298+
fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>);
299+
#[call(set_distributed_user_codec_arc)]
300+
#[expr($;self)]
301+
fn with_distributed_user_codec_arc(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;
302+
282303
fn set_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(&mut self, resolver: T);
283304
#[call(set_distributed_channel_resolver)]
284305
#[expr($;self)]
@@ -305,6 +326,11 @@ impl DistributedExt for SessionState {
305326
#[expr($;self)]
306327
fn with_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(mut self, codec: T) -> Self;
307328

329+
fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>);
330+
#[call(set_distributed_user_codec_arc)]
331+
#[expr($;self)]
332+
fn with_distributed_user_codec_arc(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;
333+
308334
fn set_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(&mut self, resolver: T);
309335
#[call(set_distributed_channel_resolver)]
310336
#[expr($;self)]
@@ -331,6 +357,11 @@ impl DistributedExt for SessionContext {
331357
#[expr($;self)]
332358
fn with_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(self, codec: T) -> Self;
333359

360+
fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>);
361+
#[call(set_distributed_user_codec_arc)]
362+
#[expr($;self)]
363+
fn with_distributed_user_codec_arc(self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;
364+
334365
fn set_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(&mut self, resolver: T);
335366
#[call(set_distributed_channel_resolver)]
336367
#[expr($;self)]

src/lib.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ pub use channel_resolver_ext::{BoxCloneSyncChannel, ChannelResolver};
1717
pub use distributed_ext::DistributedExt;
1818
pub use distributed_physical_optimizer_rule::DistributedPhysicalOptimizerRule;
1919
pub use execution_plans::display_plan_graphviz;
20-
pub use execution_plans::{ExecutionTask, NetworkShuffleExec, PartitionIsolatorExec, StageExec};
20+
pub use execution_plans::{
21+
DistributedTaskContext, ExecutionTask, NetworkCoalesceExec, NetworkShuffleExec,
22+
PartitionIsolatorExec, StageExec,
23+
};
2124
pub use flight_service::{
2225
ArrowFlightEndpoint, DefaultSessionBuilder, DistributedSessionBuilder,
2326
DistributedSessionBuilderContext, MappedDistributedSessionBuilder,

src/protobuf/distributed_codec.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use super::get_distributed_user_codec;
1+
use super::get_distributed_user_codecs;
22
use crate::common::ComposedPhysicalExtensionCodec;
33
use crate::execution_plans::{NetworkCoalesceExec, NetworkCoalesceReady, NetworkShuffleReadyExec};
44
use crate::{NetworkShuffleExec, PartitionIsolatorExec};
@@ -24,12 +24,9 @@ pub struct DistributedCodec;
2424

2525
impl DistributedCodec {
2626
pub fn new_combined_with_user(cfg: &SessionConfig) -> impl PhysicalExtensionCodec + use<> {
27-
let mut combined_codec = ComposedPhysicalExtensionCodec::default();
28-
combined_codec.push(DistributedCodec {});
29-
if let Some(ref user_codec) = get_distributed_user_codec(cfg) {
30-
combined_codec.push_arc(Arc::clone(user_codec));
31-
}
32-
combined_codec
27+
let mut codecs: Vec<Arc<dyn PhysicalExtensionCodec>> = vec![Arc::new(DistributedCodec {})];
28+
codecs.extend(get_distributed_user_codecs(cfg));
29+
ComposedPhysicalExtensionCodec::new(codecs)
3330
}
3431
}
3532

src/protobuf/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,6 @@ pub(crate) use errors::{
1212
map_status_to_datafusion_error,
1313
};
1414
pub(crate) use stage_proto::{StageExecProto, StageKey, proto_from_stage, stage_from_proto};
15-
pub(crate) use user_codec::{get_distributed_user_codec, set_distributed_user_codec};
15+
pub(crate) use user_codec::{
16+
get_distributed_user_codecs, set_distributed_user_codec, set_distributed_user_codec_arc,
17+
};

src/protobuf/user_codec.rs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,32 @@ use datafusion::prelude::SessionConfig;
22
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
33
use std::sync::Arc;
44

5-
pub struct UserProvidedCodec(Arc<dyn PhysicalExtensionCodec>);
5+
pub struct UserProvidedCodecs(Vec<Arc<dyn PhysicalExtensionCodec>>);
6+
7+
pub(crate) fn set_distributed_user_codec_arc(
8+
cfg: &mut SessionConfig,
9+
codec: Arc<dyn PhysicalExtensionCodec>,
10+
) {
11+
let mut codecs = match cfg.get_extension::<UserProvidedCodecs>() {
12+
None => vec![],
13+
Some(prev) => prev.0.clone(),
14+
};
15+
codecs.push(codec);
16+
cfg.set_extension(Arc::new(UserProvidedCodecs(codecs)))
17+
}
618

719
pub(crate) fn set_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(
820
cfg: &mut SessionConfig,
921
codec: T,
1022
) {
11-
cfg.set_extension(Arc::new(UserProvidedCodec(Arc::new(codec))))
23+
set_distributed_user_codec_arc(cfg, Arc::new(codec))
1224
}
1325

14-
pub(crate) fn get_distributed_user_codec(
26+
pub(crate) fn get_distributed_user_codecs(
1527
cfg: &SessionConfig,
16-
) -> Option<Arc<dyn PhysicalExtensionCodec>> {
17-
Some(Arc::clone(&cfg.get_extension::<UserProvidedCodec>()?.0))
28+
) -> Vec<Arc<dyn PhysicalExtensionCodec>> {
29+
match cfg.get_extension::<UserProvidedCodecs>() {
30+
None => vec![],
31+
Some(v) => v.0.clone(),
32+
}
1833
}

0 commit comments

Comments
 (0)