Skip to content

Commit a0dc5e5

Browse files
committed
Include ChannelManager in the DistributedExt trait
1 parent 248f7fb commit a0dc5e5

File tree

11 files changed

+242
-119
lines changed

11 files changed

+242
-119
lines changed

src/channel_manager.rs

Lines changed: 0 additions & 71 deletions
This file was deleted.

src/channel_manager_ext.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
use async_trait::async_trait;
2+
use datafusion::error::DataFusionError;
3+
use datafusion::prelude::SessionConfig;
4+
use std::sync::Arc;
5+
use tonic::body::BoxBody;
6+
use url::Url;
7+
8+
pub(crate) fn set_channel_resolver(
9+
cfg: &mut SessionConfig,
10+
channel_resolver: impl ChannelResolver + Send + Sync + 'static,
11+
) {
12+
cfg.set_extension(Arc::new(ChannelResolverExtension(Arc::new(
13+
channel_resolver,
14+
))));
15+
}
16+
17+
pub(crate) fn get_channel_resolver(
18+
cfg: &SessionConfig,
19+
) -> Option<Arc<dyn ChannelResolver + Send + Sync>> {
20+
cfg.get_extension::<ChannelResolverExtension>()
21+
.map(|cm| cm.0.clone())
22+
}
23+
24+
#[derive(Clone)]
25+
struct ChannelResolverExtension(Arc<dyn ChannelResolver + Send + Sync>);
26+
27+
pub type BoxCloneSyncChannel = tower::util::BoxCloneSyncService<
28+
http::Request<BoxBody>,
29+
http::Response<BoxBody>,
30+
tonic::transport::Error,
31+
>;
32+
33+
/// Abstracts networking details so that users can implement their own network resolution
34+
/// mechanism.
35+
#[async_trait]
36+
pub trait ChannelResolver {
37+
/// Gets all available worker URLs. Used during stage assignment.
38+
fn get_urls(&self) -> Result<Vec<Url>, DataFusionError>;
39+
/// For a given URL, get a channel for communicating to it.
40+
async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError>;
41+
}
42+
43+
#[async_trait]
44+
impl ChannelResolver for Arc<dyn ChannelResolver + Send + Sync> {
45+
fn get_urls(&self) -> Result<Vec<Url>, DataFusionError> {
46+
self.as_ref().get_urls()
47+
}
48+
49+
async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
50+
self.as_ref().get_channel_for_url(url).await
51+
}
52+
}

src/distributed_ext.rs

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
use crate::channel_manager_ext::set_channel_resolver;
12
use crate::config_extension_ext::{
23
set_distributed_option_extension, set_distributed_option_extension_from_headers,
34
};
45
use crate::user_codec_ext::set_user_codec;
6+
use crate::ChannelResolver;
57
use datafusion::common::DataFusionError;
68
use datafusion::config::ConfigExtension;
79
use datafusion::execution::{SessionState, SessionStateBuilder};
@@ -122,17 +124,19 @@ pub trait DistributedExt: Sized {
122124
headers: &HeaderMap,
123125
) -> Result<(), DataFusionError>;
124126

125-
/// Injects a user-defined codec that is capable of encoding/decoding custom execution nodes.
127+
/// Injects a user-defined [PhysicalExtensionCodec] that is capable of encoding/decoding
128+
/// custom execution nodes.
126129
///
127130
/// Example:
128131
///
129132
/// ```
130133
/// # use std::sync::Arc;
134+
/// # use datafusion::common::DataFusionError;
131135
/// # use datafusion::execution::{SessionState, FunctionRegistry, SessionStateBuilder};
132136
/// # use datafusion::physical_plan::ExecutionPlan;
133137
/// # use datafusion::prelude::SessionConfig;
134138
/// # use datafusion_proto::physical_plan::PhysicalExtensionCodec;
135-
/// # use datafusion_distributed::DistributedExt;
139+
/// # use datafusion_distributed::{DistributedExt, DistributedSessionBuilderContext};
136140
///
137141
/// #[derive(Debug)]
138142
/// struct CustomExecCodec;
@@ -148,11 +152,61 @@ pub trait DistributedExt: Sized {
148152
/// }
149153
///
150154
/// let config = SessionConfig::new().with_user_codec(CustomExecCodec);
155+
///
156+
/// async fn build_state(ctx: DistributedSessionBuilderContext) -> Result<SessionState, DataFusionError> {
157+
/// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will
158+
/// // know how to deserialize the CustomExtension from the gRPC metadata.
159+
/// Ok(SessionStateBuilder::new()
160+
/// .with_user_codec(CustomExecCodec)
161+
/// .build())
162+
/// }
151163
/// ```
152164
fn with_user_codec<T: PhysicalExtensionCodec + 'static>(self, codec: T) -> Self;
153165

154166
/// Same as [DistributedExt::with_user_codec] but with an in-place mutation
155167
fn set_user_codec<T: PhysicalExtensionCodec + 'static>(&mut self, codec: T);
168+
169+
/// Injects a [ChannelResolver] implementation for Distributed DataFusion to resolve worker
170+
/// nodes. When running in distributed mode, setting a [ChannelResolver] is required.
171+
///
172+
/// Example:
173+
///
174+
/// ```
175+
/// # use async_trait::async_trait;
176+
/// # use datafusion::common::DataFusionError;
177+
/// # use datafusion::execution::{SessionState, SessionStateBuilder};
178+
/// # use datafusion::prelude::SessionConfig;
179+
/// # use url::Url;
180+
/// # use datafusion_distributed::{BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedSessionBuilderContext};
181+
///
182+
/// struct CustomChannelResolver;
183+
///
184+
/// #[async_trait]
185+
/// impl ChannelResolver for CustomChannelResolver {
186+
/// fn get_urls(&self) -> Result<Vec<Url>, DataFusionError> {
187+
/// todo!()
188+
/// }
189+
///
190+
/// async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
191+
/// todo!()
192+
/// }
193+
/// }
194+
///
195+
/// let config = SessionConfig::new().with_channel_resolver(CustomChannelResolver);
196+
///
197+
/// async fn build_state(ctx: DistributedSessionBuilderContext) -> Result<SessionState, DataFusionError> {
198+
/// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will
199+
/// // know how to deserialize the CustomExtension from the gRPC metadata.
200+
/// Ok(SessionStateBuilder::new()
201+
/// .with_channel_resolver(CustomChannelResolver)
202+
/// .build())
203+
/// }
204+
/// ```
205+
fn with_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(self, resolver: T)
206+
-> Self;
207+
208+
/// Same as [DistributedExt::with_channel_resolver] but with an in-place mutation.
209+
fn set_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(&mut self, resolver: T);
156210
}
157211

158212
impl DistributedExt for SessionConfig {
@@ -174,17 +228,27 @@ impl DistributedExt for SessionConfig {
174228
set_user_codec(self, codec)
175229
}
176230

231+
fn set_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(&mut self, resolver: T) {
232+
set_channel_resolver(self, resolver)
233+
}
234+
177235
delegate! {
178236
to self {
179237
#[call(set_distributed_option_extension)]
180238
#[expr($?;Ok(self))]
181239
fn with_distributed_option_extension<T: ConfigExtension + Default>(mut self, t: T) -> Result<Self, DataFusionError>;
240+
182241
#[call(set_distributed_option_extension_from_headers)]
183242
#[expr($?;Ok(self))]
184243
fn with_distributed_option_extension_from_headers<T: ConfigExtension + Default>(mut self, headers: &HeaderMap) -> Result<Self, DataFusionError>;
244+
185245
#[call(set_user_codec)]
186246
#[expr($;self)]
187247
fn with_user_codec<T: PhysicalExtensionCodec + 'static>(mut self, codec: T) -> Self;
248+
249+
#[call(set_channel_resolver)]
250+
#[expr($;self)]
251+
fn with_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(mut self, resolver: T) -> Self;
188252
}
189253
}
190254
}
@@ -196,14 +260,21 @@ impl DistributedExt for SessionStateBuilder {
196260
#[call(set_distributed_option_extension)]
197261
#[expr($?;Ok(self))]
198262
fn with_distributed_option_extension<T: ConfigExtension + Default>(mut self, t: T) -> Result<Self, DataFusionError>;
263+
199264
fn set_distributed_option_extension_from_headers<T: ConfigExtension + Default>(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>;
200265
#[call(set_distributed_option_extension_from_headers)]
201266
#[expr($?;Ok(self))]
202267
fn with_distributed_option_extension_from_headers<T: ConfigExtension + Default>(mut self, headers: &HeaderMap) -> Result<Self, DataFusionError>;
268+
203269
fn set_user_codec<T: PhysicalExtensionCodec + 'static>(&mut self, codec: T);
204270
#[call(set_user_codec)]
205271
#[expr($;self)]
206272
fn with_user_codec<T: PhysicalExtensionCodec + 'static>(mut self, codec: T) -> Self;
273+
274+
fn set_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(&mut self, resolver: T);
275+
#[call(set_channel_resolver)]
276+
#[expr($;self)]
277+
fn with_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(mut self, resolver: T) -> Self;
207278
}
208279
}
209280
}
@@ -215,14 +286,21 @@ impl DistributedExt for SessionState {
215286
#[call(set_distributed_option_extension)]
216287
#[expr($?;Ok(self))]
217288
fn with_distributed_option_extension<T: ConfigExtension + Default>(mut self, t: T) -> Result<Self, DataFusionError>;
289+
218290
fn set_distributed_option_extension_from_headers<T: ConfigExtension + Default>(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>;
219291
#[call(set_distributed_option_extension_from_headers)]
220292
#[expr($?;Ok(self))]
221293
fn with_distributed_option_extension_from_headers<T: ConfigExtension + Default>(mut self, headers: &HeaderMap) -> Result<Self, DataFusionError>;
294+
222295
fn set_user_codec<T: PhysicalExtensionCodec + 'static>(&mut self, codec: T);
223296
#[call(set_user_codec)]
224297
#[expr($;self)]
225298
fn with_user_codec<T: PhysicalExtensionCodec + 'static>(mut self, codec: T) -> Self;
299+
300+
fn set_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(&mut self, resolver: T);
301+
#[call(set_channel_resolver)]
302+
#[expr($;self)]
303+
fn with_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(mut self, resolver: T) -> Self;
226304
}
227305
}
228306
}
@@ -234,14 +312,21 @@ impl DistributedExt for SessionContext {
234312
#[call(set_distributed_option_extension)]
235313
#[expr($?;Ok(self))]
236314
fn with_distributed_option_extension<T: ConfigExtension + Default>(self, t: T) -> Result<Self, DataFusionError>;
315+
237316
fn set_distributed_option_extension_from_headers<T: ConfigExtension + Default>(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>;
238317
#[call(set_distributed_option_extension_from_headers)]
239318
#[expr($?;Ok(self))]
240319
fn with_distributed_option_extension_from_headers<T: ConfigExtension + Default>(self, headers: &HeaderMap) -> Result<Self, DataFusionError>;
320+
241321
fn set_user_codec<T: PhysicalExtensionCodec + 'static>(&mut self, codec: T);
242322
#[call(set_user_codec)]
243323
#[expr($;self)]
244324
fn with_user_codec<T: PhysicalExtensionCodec + 'static>(self, codec: T) -> Self;
325+
326+
fn set_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(&mut self, resolver: T);
327+
#[call(set_channel_resolver)]
328+
#[expr($;self)]
329+
fn with_channel_resolver<T: ChannelResolver + Send + Sync + 'static>(self, resolver: T) -> Self;
245330
}
246331
}
247332
}

src/flight_service/do_get.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ impl ArrowFlightEndpoint {
130130

131131
// Add the extensions that might be required for ExecutionPlan nodes in the plan
132132
let config = state.config_mut();
133-
config.set_extension(Arc::clone(&self.channel_manager));
134133
config.set_extension(stage.clone());
135134
config.set_extension(Arc::new(ContextGrpcMetadata(headers)));
136135

src/flight_service/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ pub(crate) use do_get::DoGet;
77
pub use service::{ArrowFlightEndpoint, StageKey};
88
pub use session_builder::{
99
DefaultSessionBuilder, DistributedSessionBuilder, DistributedSessionBuilderContext,
10+
MappedDistributedSessionBuilder, MappedDistributedSessionBuilderExt,
1011
};

src/flight_service/service.rs

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
use crate::channel_manager::ChannelManager;
2-
use crate::flight_service::session_builder::DefaultSessionBuilder;
31
use crate::flight_service::DistributedSessionBuilder;
42
use crate::stage::ExecutionStage;
5-
use crate::ChannelResolver;
63
use arrow_flight::flight_service_server::FlightService;
74
use arrow_flight::{
85
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
@@ -32,29 +29,20 @@ pub struct StageKey {
3229
}
3330

3431
pub struct ArrowFlightEndpoint {
35-
pub(super) channel_manager: Arc<ChannelManager>,
3632
pub(super) runtime: Arc<RuntimeEnv>,
3733
#[allow(clippy::type_complexity)]
3834
pub(super) stages: DashMap<StageKey, Arc<OnceCell<(SessionState, Arc<ExecutionStage>)>>>,
3935
pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
4036
}
4137

4238
impl ArrowFlightEndpoint {
43-
pub fn new(channel_resolver: impl ChannelResolver + Send + Sync + 'static) -> Self {
39+
pub fn new(session_builder: impl DistributedSessionBuilder + Send + Sync + 'static) -> Self {
4440
Self {
45-
channel_manager: Arc::new(ChannelManager::new(channel_resolver)),
4641
runtime: Arc::new(RuntimeEnv::default()),
4742
stages: DashMap::new(),
48-
session_builder: Arc::new(DefaultSessionBuilder),
43+
session_builder: Arc::new(session_builder),
4944
}
5045
}
51-
52-
pub fn with_session_builder(
53-
&mut self,
54-
session_builder: impl DistributedSessionBuilder + Send + Sync + 'static,
55-
) {
56-
self.session_builder = Arc::new(session_builder);
57-
}
5846
}
5947

6048
#[async_trait]

0 commit comments

Comments
 (0)