Skip to content

Commit 7925912

Browse files
committed
Address public API weak points
1 parent 64920af commit 7925912

26 files changed

+241
-475
lines changed

Cargo.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ parquet = { version = "55.2.0", optional = true }
3939
arrow = { version = "55.2.0", optional = true }
4040
tokio-stream = { version = "0.1.17", optional = true }
4141
hyper-util = { version = "0.1.16", optional = true }
42-
pin-project = "1.1.10"
4342

4443
[features]
4544
integration = [
Lines changed: 83 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,96 @@
1-
use datafusion::common::not_impl_err;
1+
use datafusion::common::internal_datafusion_err;
22
use datafusion::error::DataFusionError;
3+
use datafusion::error::Result;
34
use datafusion::execution::FunctionRegistry;
45
use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
56
use datafusion::physical_plan::ExecutionPlan;
67
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
8+
use prost::Message;
79
use std::fmt::Debug;
810
use std::sync::Arc;
11+
// Code taken from https://github.com/apache/datafusion/blob/10f41887fa40d7d425c19b07857f80115460a98e/datafusion/proto/src/physical_plan/mod.rs
12+
// TODO: This code is not yet available in DF 49; once we upgrade to DF 50 we can remove this copy.
913

10-
// Idea taken from
11-
// https://github.com/apache/datafusion/blob/0eebc0c7c0ffcd1514f5c6d0f8e2b6d0c69a07f5/datafusion-examples/examples/composed_extension_codec.rs#L236-L291
14+
/// DataEncoderTuple captures the position of the encoder
15+
/// in the codec list that was used to encode the data, together with the actual encoded data
16+
#[derive(Clone, PartialEq, prost::Message)]
17+
struct DataEncoderTuple {
18+
/// The position of the encoder used to encode the data
19+
/// (to be used for decoding)
20+
#[prost(uint32, tag = 1)]
21+
pub encoder_position: u32,
1222

13-
/// A [PhysicalExtensionCodec] that holds multiple [PhysicalExtensionCodec] and tries them
14-
/// sequentially until one works.
15-
#[derive(Debug, Clone, Default)]
16-
pub(crate) struct ComposedPhysicalExtensionCodec {
23+
#[prost(bytes, tag = 2)]
24+
pub blob: Vec<u8>,
25+
}
26+
27+
/// A PhysicalExtensionCodec that tries each of multiple inner codecs in order
28+
/// until one works
29+
#[derive(Debug)]
30+
pub struct ComposedPhysicalExtensionCodec {
1731
codecs: Vec<Arc<dyn PhysicalExtensionCodec>>,
1832
}
1933

2034
impl ComposedPhysicalExtensionCodec {
21-
/// Adds a new [PhysicalExtensionCodec] to the list. These codecs will be tried
22-
/// sequentially until one works.
23-
pub(crate) fn push(&mut self, codec: impl PhysicalExtensionCodec + 'static) {
24-
self.codecs.push(Arc::new(codec));
35+
// Position in this codecs list is important as it will be used for decoding.
36+
// If a new codec is added, it should go in the last position.
37+
pub fn new(codecs: Vec<Arc<dyn PhysicalExtensionCodec>>) -> Self {
38+
Self { codecs }
2539
}
2640

27-
/// Adds a new [PhysicalExtensionCodec] to the list. These codecs will be tried
28-
/// sequentially until one works.
29-
pub(crate) fn push_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>) {
30-
self.codecs.push(codec);
41+
fn decode_protobuf<R>(
42+
&self,
43+
buf: &[u8],
44+
decode: impl FnOnce(&dyn PhysicalExtensionCodec, &[u8]) -> Result<R, DataFusionError>,
45+
) -> Result<R, DataFusionError> {
46+
let proto =
47+
DataEncoderTuple::decode(buf).map_err(|e| DataFusionError::Internal(e.to_string()))?;
48+
49+
let pos = proto.encoder_position as usize;
50+
let codec = self.codecs.get(pos).ok_or_else(|| {
51+
internal_datafusion_err!(
52+
"Can't find required codec in position {pos} in codec list with {} elements",
53+
self.codecs.len()
54+
)
55+
})?;
56+
57+
decode(codec.as_ref(), &proto.blob)
3158
}
3259

33-
fn try_any<T>(
60+
fn encode_protobuf(
3461
&self,
35-
mut f: impl FnMut(&dyn PhysicalExtensionCodec) -> Result<T, DataFusionError>,
36-
) -> Result<T, DataFusionError> {
37-
let mut errs = vec![];
38-
for codec in &self.codecs {
39-
match f(codec.as_ref()) {
40-
Ok(node) => return Ok(node),
41-
Err(err) => errs.push(err),
62+
buf: &mut Vec<u8>,
63+
mut encode: impl FnMut(&dyn PhysicalExtensionCodec, &mut Vec<u8>) -> Result<()>,
64+
) -> Result<(), DataFusionError> {
65+
let mut data = vec![];
66+
let mut last_err = None;
67+
let mut encoder_position = None;
68+
69+
// find the encoder
70+
for (position, codec) in self.codecs.iter().enumerate() {
71+
match encode(codec.as_ref(), &mut data) {
72+
Ok(_) => {
73+
encoder_position = Some(position as u32);
74+
break;
75+
}
76+
Err(err) => last_err = Some(err),
4277
}
4378
}
4479

45-
if errs.is_empty() {
46-
return not_impl_err!("Empty list of composed codecs");
47-
}
80+
let encoder_position = encoder_position.ok_or_else(|| {
81+
last_err.unwrap_or_else(|| {
82+
DataFusionError::NotImplemented("Empty list of composed codecs".to_owned())
83+
})
84+
})?;
4885

49-
let mut msg = "None of the provided PhysicalExtensionCodec worked:".to_string();
50-
for err in &errs {
51-
msg += &format!("\n {err}");
52-
}
53-
not_impl_err!("{msg}")
86+
// encode with encoder position
87+
let proto = DataEncoderTuple {
88+
encoder_position,
89+
blob: data,
90+
};
91+
proto
92+
.encode(buf)
93+
.map_err(|e| DataFusionError::Internal(e.to_string()))
5494
}
5595
}
5696

@@ -60,39 +100,27 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec {
60100
buf: &[u8],
61101
inputs: &[Arc<dyn ExecutionPlan>],
62102
registry: &dyn FunctionRegistry,
63-
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
64-
self.try_any(|codec| codec.try_decode(buf, inputs, registry))
103+
) -> Result<Arc<dyn ExecutionPlan>> {
104+
self.decode_protobuf(buf, |codec, data| codec.try_decode(data, inputs, registry))
65105
}
66106

67-
fn try_encode(
68-
&self,
69-
node: Arc<dyn ExecutionPlan>,
70-
buf: &mut Vec<u8>,
71-
) -> Result<(), DataFusionError> {
72-
self.try_any(|codec| codec.try_encode(node.clone(), buf))
107+
fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> Result<()> {
108+
self.encode_protobuf(buf, |codec, data| codec.try_encode(Arc::clone(&node), data))
73109
}
74110

75-
fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>, DataFusionError> {
76-
self.try_any(|codec| codec.try_decode_udf(name, buf))
111+
fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> {
112+
self.decode_protobuf(buf, |codec, data| codec.try_decode_udf(name, data))
77113
}
78114

79-
fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<(), DataFusionError> {
80-
self.try_any(|codec| codec.try_encode_udf(node, buf))
115+
fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> {
116+
self.encode_protobuf(buf, |codec, data| codec.try_encode_udf(node, data))
81117
}
82118

83-
fn try_decode_udaf(
84-
&self,
85-
name: &str,
86-
buf: &[u8],
87-
) -> Result<Arc<AggregateUDF>, DataFusionError> {
88-
self.try_any(|codec| codec.try_decode_udaf(name, buf))
119+
fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> {
120+
self.decode_protobuf(buf, |codec, data| codec.try_decode_udaf(name, data))
89121
}
90122

91-
fn try_encode_udaf(
92-
&self,
93-
node: &AggregateUDF,
94-
buf: &mut Vec<u8>,
95-
) -> Result<(), DataFusionError> {
96-
self.try_any(|codec| codec.try_encode_udaf(node, buf))
123+
fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> {
124+
self.encode_protobuf(buf, |codec, data| codec.try_encode_udaf(node, data))
97125
}
98126
}

src/distributed_ext.rs

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@ use crate::channel_resolver_ext::set_distributed_channel_resolver;
33
use crate::config_extension_ext::{
44
set_distributed_option_extension, set_distributed_option_extension_from_headers,
55
};
6-
use crate::protobuf::set_distributed_user_codec;
6+
use crate::protobuf::{set_distributed_user_codec, set_distributed_user_codec_arc};
77
use datafusion::common::DataFusionError;
88
use datafusion::config::ConfigExtension;
99
use datafusion::execution::{SessionState, SessionStateBuilder};
1010
use datafusion::prelude::{SessionConfig, SessionContext};
1111
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
1212
use delegate::delegate;
1313
use http::HeaderMap;
14+
use std::sync::Arc;
1415

1516
/// Extends DataFusion with distributed capabilities.
1617
pub trait DistributedExt: Sized {
@@ -125,7 +126,8 @@ pub trait DistributedExt: Sized {
125126
) -> Result<(), DataFusionError>;
126127

127128
/// Injects a user-defined [PhysicalExtensionCodec] that is capable of encoding/decoding
128-
/// custom execution nodes.
129+
/// custom execution nodes. Multiple user-defined [PhysicalExtensionCodec]s can be added
130+
/// by calling this method several times.
129131
///
130132
/// Example:
131133
///
@@ -166,6 +168,12 @@ pub trait DistributedExt: Sized {
166168
/// Same as [DistributedExt::with_distributed_user_codec] but with an in-place mutation
167169
fn set_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(&mut self, codec: T);
168170

171+
/// Same as [DistributedExt::with_distributed_user_codec] but with a dynamic argument.
172+
fn with_distributed_user_codec_arc(self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;
173+
174+
/// Same as [DistributedExt::set_distributed_user_codec] but with a dynamic argument.
175+
fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>);
176+
169177
/// Injects a [ChannelResolver] implementation for Distributed DataFusion to resolve worker
170178
/// nodes. When running in distributed mode, setting a [ChannelResolver] is required.
171179
///
@@ -233,6 +241,10 @@ impl DistributedExt for SessionConfig {
233241
set_distributed_user_codec(self, codec)
234242
}
235243

244+
fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>) {
245+
set_distributed_user_codec_arc(self, codec)
246+
}
247+
236248
fn set_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(
237249
&mut self,
238250
resolver: T,
@@ -254,6 +266,10 @@ impl DistributedExt for SessionConfig {
254266
#[expr($;self)]
255267
fn with_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(mut self, codec: T) -> Self;
256268

269+
#[call(set_distributed_user_codec_arc)]
270+
#[expr($;self)]
271+
fn with_distributed_user_codec_arc(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;
272+
257273
#[call(set_distributed_channel_resolver)]
258274
#[expr($;self)]
259275
fn with_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(mut self, resolver: T) -> Self;
@@ -279,6 +295,11 @@ impl DistributedExt for SessionStateBuilder {
279295
#[expr($;self)]
280296
fn with_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(mut self, codec: T) -> Self;
281297

298+
fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>);
299+
#[call(set_distributed_user_codec_arc)]
300+
#[expr($;self)]
301+
fn with_distributed_user_codec_arc(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;
302+
282303
fn set_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(&mut self, resolver: T);
283304
#[call(set_distributed_channel_resolver)]
284305
#[expr($;self)]
@@ -305,6 +326,11 @@ impl DistributedExt for SessionState {
305326
#[expr($;self)]
306327
fn with_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(mut self, codec: T) -> Self;
307328

329+
fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>);
330+
#[call(set_distributed_user_codec_arc)]
331+
#[expr($;self)]
332+
fn with_distributed_user_codec_arc(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;
333+
308334
fn set_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(&mut self, resolver: T);
309335
#[call(set_distributed_channel_resolver)]
310336
#[expr($;self)]
@@ -331,6 +357,11 @@ impl DistributedExt for SessionContext {
331357
#[expr($;self)]
332358
fn with_distributed_user_codec<T: PhysicalExtensionCodec + 'static>(self, codec: T) -> Self;
333359

360+
fn set_distributed_user_codec_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>);
361+
#[call(set_distributed_user_codec_arc)]
362+
#[expr($;self)]
363+
fn with_distributed_user_codec_arc(self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self;
364+
334365
fn set_distributed_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(&mut self, resolver: T);
335366
#[call(set_distributed_channel_resolver)]
336367
#[expr($;self)]

src/protobuf/errors/arrow_error.rs renamed to src/errors/arrow_error.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1+
use crate::errors::io_error::IoErrorProto;
12
use datafusion::arrow::error::ArrowError;
23

3-
use crate::protobuf::errors::io_error::IoErrorProto;
4-
54
#[derive(Clone, PartialEq, ::prost::Message)]
65
pub struct ArrowErrorProto {
76
#[prost(string, optional, tag = "1")]

src/protobuf/errors/datafusion_error.rs renamed to src/errors/datafusion_error.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
use crate::protobuf::errors::arrow_error::ArrowErrorProto;
2-
use crate::protobuf::errors::io_error::IoErrorProto;
3-
use crate::protobuf::errors::objectstore_error::ObjectStoreErrorProto;
4-
use crate::protobuf::errors::parquet_error::ParquetErrorProto;
5-
use crate::protobuf::errors::parser_error::ParserErrorProto;
6-
use crate::protobuf::errors::schema_error::SchemaErrorProto;
1+
use crate::errors::arrow_error::ArrowErrorProto;
2+
use crate::errors::io_error::IoErrorProto;
3+
use crate::errors::objectstore_error::ObjectStoreErrorProto;
4+
use crate::errors::parquet_error::ParquetErrorProto;
5+
use crate::errors::parser_error::ParserErrorProto;
6+
use crate::errors::schema_error::SchemaErrorProto;
77
use datafusion::common::{DataFusionError, Diagnostic};
88
use datafusion::logical_expr::sqlparser::parser::ParserError;
99
use std::error::Error;
File renamed without changes.

src/protobuf/errors/mod.rs renamed to src/errors/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
#![allow(clippy::upper_case_acronyms, clippy::vec_box)]
22

3+
use crate::errors::datafusion_error::DataFusionErrorProto;
34
use arrow_flight::error::FlightError;
45
use datafusion::common::internal_datafusion_err;
56
use datafusion::error::DataFusionError;
67
use prost::Message;
78

8-
use crate::protobuf::errors::datafusion_error::DataFusionErrorProto;
9-
109
mod arrow_error;
1110
mod datafusion_error;
1211
mod io_error;
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)