Skip to content

Commit 5000e0b

Browse files
committed
Evolve ChannelResolver trait to require a FlightServiceClient instead of a channel
1 parent 3fe4b08 commit 5000e0b

File tree

9 files changed

+57
-40
lines changed

9 files changed

+57
-40
lines changed

examples/in_memory_cluster.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
use arrow::util::pretty::pretty_format_batches;
2+
use arrow_flight::flight_service_client::FlightServiceClient;
23
use arrow_flight::flight_service_server::FlightServiceServer;
34
use async_trait::async_trait;
45
use datafusion::common::DataFusionError;
56
use datafusion::execution::SessionStateBuilder;
67
use datafusion::physical_plan::displayable;
78
use datafusion::prelude::{ParquetReadOptions, SessionContext};
89
use datafusion_distributed::{
9-
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
10-
DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext,
10+
ArrowFlightEndpoint, ChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule,
11+
DistributedSessionBuilderContext,
1112
};
1213
use futures::TryStreamExt;
1314
use hyper_util::rt::TokioIo;
@@ -75,7 +76,7 @@ const DUMMY_URL: &str = "http://localhost:50051";
7576
/// tokio duplex rather than a TCP connection.
7677
#[derive(Clone)]
7778
struct InMemoryChannelResolver {
78-
channel: BoxCloneSyncChannel,
79+
channel: FlightServiceClient<tonic::transport::Channel>,
7980
}
8081

8182
impl InMemoryChannelResolver {
@@ -93,7 +94,7 @@ impl InMemoryChannelResolver {
9394
}));
9495

9596
let this = Self {
96-
channel: BoxCloneSyncChannel::new(channel),
97+
channel: FlightServiceClient::new(channel),
9798
};
9899
let this_clone = this.clone();
99100

@@ -127,10 +128,10 @@ impl ChannelResolver for InMemoryChannelResolver {
127128
Ok(vec![url::Url::parse(DUMMY_URL).unwrap()])
128129
}
129130

130-
async fn get_channel_for_url(
131+
async fn get_flight_client_for_url(
131132
&self,
132133
_: &url::Url,
133-
) -> Result<BoxCloneSyncChannel, DataFusionError> {
134+
) -> Result<FlightServiceClient<tonic::transport::Channel>, DataFusionError> {
134135
Ok(self.channel.clone())
135136
}
136137
}

examples/localhost_run.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
use arrow::util::pretty::pretty_format_batches;
2+
use arrow_flight::flight_service_client::FlightServiceClient;
23
use async_trait::async_trait;
34
use dashmap::{DashMap, Entry};
45
use datafusion::common::DataFusionError;
56
use datafusion::execution::SessionStateBuilder;
67
use datafusion::physical_plan::displayable;
78
use datafusion::prelude::{ParquetReadOptions, SessionContext};
8-
use datafusion_distributed::{
9-
BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule,
10-
};
9+
use datafusion_distributed::{ChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule};
1110
use futures::TryStreamExt;
1211
use std::error::Error;
1312
use std::sync::Arc;
@@ -83,7 +82,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
8382
#[derive(Clone)]
8483
struct LocalhostChannelResolver {
8584
ports: Vec<u16>,
86-
cached: DashMap<Url, BoxCloneSyncChannel>,
85+
cached: DashMap<Url, FlightServiceClient<Channel>>,
8786
}
8887

8988
#[async_trait]
@@ -96,14 +95,17 @@ impl ChannelResolver for LocalhostChannelResolver {
9695
.collect())
9796
}
9897

99-
async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
98+
async fn get_flight_client_for_url(
99+
&self,
100+
url: &Url,
101+
) -> Result<FlightServiceClient<Channel>, DataFusionError> {
100102
match self.cached.entry(url.clone()) {
101103
Entry::Occupied(v) => Ok(v.get().clone()),
102104
Entry::Vacant(v) => {
103105
let channel = Channel::from_shared(url.to_string())
104106
.unwrap()
105107
.connect_lazy();
106-
let channel = BoxCloneSyncChannel::new(channel);
108+
let channel = FlightServiceClient::new(channel);
107109
v.insert(channel.clone());
108110
Ok(channel)
109111
}

examples/localhost_worker.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
use arrow_flight::flight_service_client::FlightServiceClient;
12
use arrow_flight::flight_service_server::FlightServiceServer;
23
use async_trait::async_trait;
34
use dashmap::{DashMap, Entry};
45
use datafusion::common::DataFusionError;
56
use datafusion::execution::SessionStateBuilder;
67
use datafusion_distributed::{
7-
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
8-
DistributedSessionBuilderContext,
8+
ArrowFlightEndpoint, ChannelResolver, DistributedExt, DistributedSessionBuilderContext,
99
};
1010
use std::error::Error;
1111
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
@@ -55,7 +55,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
5555
#[derive(Clone)]
5656
struct LocalhostChannelResolver {
5757
ports: Vec<u16>,
58-
cached: DashMap<Url, BoxCloneSyncChannel>,
58+
cached: DashMap<Url, FlightServiceClient<Channel>>,
5959
}
6060

6161
#[async_trait]
@@ -68,14 +68,17 @@ impl ChannelResolver for LocalhostChannelResolver {
6868
.collect())
6969
}
7070

71-
async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
71+
async fn get_flight_client_for_url(
72+
&self,
73+
url: &Url,
74+
) -> Result<FlightServiceClient<Channel>, DataFusionError> {
7275
match self.cached.entry(url.clone()) {
7376
Entry::Occupied(v) => Ok(v.get().clone()),
7477
Entry::Vacant(v) => {
7578
let channel = Channel::from_shared(url.to_string())
7679
.unwrap()
7780
.connect_lazy();
78-
let channel = BoxCloneSyncChannel::new(channel);
81+
let channel = FlightServiceClient::new(channel);
7982
v.insert(channel.clone());
8083
Ok(channel)
8184
}

src/channel_resolver_ext.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
use arrow_flight::flight_service_client::FlightServiceClient;
12
use async_trait::async_trait;
23
use datafusion::common::exec_datafusion_err;
34
use datafusion::error::DataFusionError;
45
use datafusion::prelude::SessionConfig;
56
use std::sync::Arc;
67
use tonic::body::BoxBody;
8+
use tonic::transport::Channel;
79
use url::Url;
810

911
pub(crate) fn set_distributed_channel_resolver(
@@ -38,8 +40,11 @@ pub type BoxCloneSyncChannel = tower::util::BoxCloneSyncService<
3840
pub trait ChannelResolver {
3941
/// Gets all available worker URLs. Used during stage assignment.
4042
fn get_urls(&self) -> Result<Vec<Url>, DataFusionError>;
41-
/// For a given URL, get a channel for communicating to it.
42-
async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError>;
43+
/// For a given URL, get an Arrow Flight client for communicating with it.
44+
async fn get_flight_client_for_url(
45+
&self,
46+
url: &Url,
47+
) -> Result<FlightServiceClient<Channel>, DataFusionError>;
4348
}
4449

4550
#[async_trait]
@@ -48,7 +53,10 @@ impl ChannelResolver for Arc<dyn ChannelResolver + Send + Sync> {
4853
self.as_ref().get_urls()
4954
}
5055

51-
async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
52-
self.as_ref().get_channel_for_url(url).await
56+
async fn get_flight_client_for_url(
57+
&self,
58+
url: &Url,
59+
) -> Result<FlightServiceClient<Channel>, DataFusionError> {
60+
self.as_ref().get_flight_client_for_url(url).await
5361
}
5462
}

src/distributed_ext.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,12 @@ pub trait DistributedExt: Sized {
180180
/// Example:
181181
///
182182
/// ```
183+
/// # use arrow_flight::flight_service_client::FlightServiceClient;
183184
/// # use async_trait::async_trait;
184185
/// # use datafusion::common::DataFusionError;
185186
/// # use datafusion::execution::{SessionState, SessionStateBuilder};
186187
/// # use datafusion::prelude::SessionConfig;
188+
/// # use tonic::transport::Channel;
187189
/// # use url::Url;
188190
/// # use datafusion_distributed::{BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedSessionBuilderContext};
189191
///
@@ -195,7 +197,7 @@ pub trait DistributedExt: Sized {
195197
/// todo!()
196198
/// }
197199
///
198-
/// async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
200+
/// async fn get_flight_client_for_url(&self, url: &Url) -> Result<FlightServiceClient<Channel>, DataFusionError> {
199201
/// todo!()
200202
/// }
201203
/// }

src/execution_plans/network_coalesce.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ use crate::protobuf::{map_flight_to_datafusion_error, map_status_to_datafusion_e
1212
use arrow_flight::Ticket;
1313
use arrow_flight::decode::FlightRecordBatchStream;
1414
use arrow_flight::error::FlightError;
15-
use arrow_flight::flight_service_client::FlightServiceClient;
1615
use dashmap::DashMap;
1716
use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err};
1817
use datafusion::error::DataFusionError;
@@ -285,8 +284,8 @@ impl ExecutionPlan for NetworkCoalesceExec {
285284

286285
let metrics_collection_capture = self_ready.metrics_collection.clone();
287286
let stream = async move {
288-
let channel = channel_resolver.get_channel_for_url(&url).await?;
289-
let stream = FlightServiceClient::new(channel)
287+
let mut client = channel_resolver.get_flight_client_for_url(&url).await?;
288+
let stream = client
290289
.do_get(ticket)
291290
.await
292291
.map_err(map_status_to_datafusion_error)?

src/execution_plans/network_shuffle.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ use crate::protobuf::{map_flight_to_datafusion_error, map_status_to_datafusion_e
1212
use arrow_flight::Ticket;
1313
use arrow_flight::decode::FlightRecordBatchStream;
1414
use arrow_flight::error::FlightError;
15-
use arrow_flight::flight_service_client::FlightServiceClient;
1615
use dashmap::DashMap;
1716
use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err};
1817
use datafusion::error::DataFusionError;
@@ -337,8 +336,8 @@ impl ExecutionPlan for NetworkShuffleExec {
337336
"NetworkShuffleExec: task is unassigned, cannot proceed"
338337
))?;
339338

340-
let channel = channel_resolver.get_channel_for_url(&url).await?;
341-
let stream = FlightServiceClient::new(channel)
339+
let mut client = channel_resolver.get_flight_client_for_url(&url).await?;
340+
let stream = client
342341
.do_get(ticket)
343342
.await
344343
.map_err(map_status_to_datafusion_error)?

src/test_utils/in_memory_channel_resolver.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
use crate::{
2-
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
3-
DistributedSessionBuilderContext,
2+
ArrowFlightEndpoint, ChannelResolver, DistributedExt, DistributedSessionBuilderContext,
43
};
4+
use arrow_flight::flight_service_client::FlightServiceClient;
55
use arrow_flight::flight_service_server::FlightServiceServer;
66
use async_trait::async_trait;
77
use datafusion::common::DataFusionError;
88
use datafusion::execution::SessionStateBuilder;
99
use hyper_util::rt::TokioIo;
10-
use tonic::transport::{Endpoint, Server};
10+
use tonic::transport::{Channel, Endpoint, Server};
1111

1212
const DUMMY_URL: &str = "http://localhost:50051";
1313

1414
/// [ChannelResolver] implementation that returns gRPC clients backed by an in-memory
1515
/// tokio duplex rather than a TCP connection.
1616
#[derive(Clone)]
1717
pub struct InMemoryChannelResolver {
18-
channel: BoxCloneSyncChannel,
18+
channel: FlightServiceClient<Channel>,
1919
}
2020

2121
impl Default for InMemoryChannelResolver {
@@ -39,7 +39,7 @@ impl InMemoryChannelResolver {
3939
}));
4040

4141
let this = Self {
42-
channel: BoxCloneSyncChannel::new(channel),
42+
channel: FlightServiceClient::new(channel),
4343
};
4444
let this_clone = this.clone();
4545

@@ -73,10 +73,10 @@ impl ChannelResolver for InMemoryChannelResolver {
7373
Ok(vec![url::Url::parse(DUMMY_URL).unwrap()])
7474
}
7575

76-
async fn get_channel_for_url(
76+
async fn get_flight_client_for_url(
7777
&self,
7878
_: &url::Url,
79-
) -> Result<BoxCloneSyncChannel, DataFusionError> {
79+
) -> Result<FlightServiceClient<tonic::transport::Channel>, DataFusionError> {
8080
Ok(self.channel.clone())
8181
}
8282
}

src/test_utils/localhost.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use crate::{
2-
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
3-
DistributedSessionBuilder, DistributedSessionBuilderContext,
4-
MappedDistributedSessionBuilderExt,
2+
ArrowFlightEndpoint, ChannelResolver, DistributedExt, DistributedSessionBuilder,
3+
DistributedSessionBuilderContext, MappedDistributedSessionBuilderExt,
54
};
5+
use arrow_flight::flight_service_client::FlightServiceClient;
66
use arrow_flight::flight_service_server::FlightServiceServer;
77
use async_trait::async_trait;
88
use datafusion::common::DataFusionError;
@@ -98,10 +98,13 @@ impl ChannelResolver for LocalHostChannelResolver {
9898
.map(|url| Url::parse(&url).map_err(external_err))
9999
.collect::<Result<Vec<Url>, _>>()
100100
}
101-
async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
101+
async fn get_flight_client_for_url(
102+
&self,
103+
url: &Url,
104+
) -> Result<FlightServiceClient<Channel>, DataFusionError> {
102105
let endpoint = Channel::from_shared(url.to_string()).map_err(external_err)?;
103106
let channel = endpoint.connect().await.map_err(external_err)?;
104-
Ok(BoxCloneSyncChannel::new(channel))
107+
Ok(FlightServiceClient::new(channel))
105108
}
106109
}
107110

0 commit comments

Comments
 (0)