Skip to content

Commit ce4e907

Browse files
authored
Evolve ChannelResolver trait for requiring a FlightClient instead of a channel (#172)
1 parent 3fe4b08 commit ce4e907

File tree

10 files changed

+50
-30
lines changed

10 files changed

+50
-30
lines changed

examples/in_memory_cluster.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use arrow::util::pretty::pretty_format_batches;
2+
use arrow_flight::flight_service_client::FlightServiceClient;
23
use arrow_flight::flight_service_server::FlightServiceServer;
34
use async_trait::async_trait;
45
use datafusion::common::DataFusionError;
@@ -75,7 +76,7 @@ const DUMMY_URL: &str = "http://localhost:50051";
7576
/// tokio duplex rather than a TCP connection.
7677
#[derive(Clone)]
7778
struct InMemoryChannelResolver {
78-
channel: BoxCloneSyncChannel,
79+
channel: FlightServiceClient<BoxCloneSyncChannel>,
7980
}
8081

8182
impl InMemoryChannelResolver {
@@ -93,7 +94,7 @@ impl InMemoryChannelResolver {
9394
}));
9495

9596
let this = Self {
96-
channel: BoxCloneSyncChannel::new(channel),
97+
channel: FlightServiceClient::new(BoxCloneSyncChannel::new(channel)),
9798
};
9899
let this_clone = this.clone();
99100

@@ -127,10 +128,10 @@ impl ChannelResolver for InMemoryChannelResolver {
127128
Ok(vec![url::Url::parse(DUMMY_URL).unwrap()])
128129
}
129130

130-
async fn get_channel_for_url(
131+
async fn get_flight_client_for_url(
131132
&self,
132133
_: &url::Url,
133-
) -> Result<BoxCloneSyncChannel, DataFusionError> {
134+
) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> {
134135
Ok(self.channel.clone())
135136
}
136137
}

examples/localhost_run.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use arrow::util::pretty::pretty_format_batches;
2+
use arrow_flight::flight_service_client::FlightServiceClient;
23
use async_trait::async_trait;
34
use dashmap::{DashMap, Entry};
45
use datafusion::common::DataFusionError;
@@ -83,7 +84,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
8384
#[derive(Clone)]
8485
struct LocalhostChannelResolver {
8586
ports: Vec<u16>,
86-
cached: DashMap<Url, BoxCloneSyncChannel>,
87+
cached: DashMap<Url, FlightServiceClient<BoxCloneSyncChannel>>,
8788
}
8889

8990
#[async_trait]
@@ -96,14 +97,17 @@ impl ChannelResolver for LocalhostChannelResolver {
9697
.collect())
9798
}
9899

99-
async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
100+
async fn get_flight_client_for_url(
101+
&self,
102+
url: &Url,
103+
) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> {
100104
match self.cached.entry(url.clone()) {
101105
Entry::Occupied(v) => Ok(v.get().clone()),
102106
Entry::Vacant(v) => {
103107
let channel = Channel::from_shared(url.to_string())
104108
.unwrap()
105109
.connect_lazy();
106-
let channel = BoxCloneSyncChannel::new(channel);
110+
let channel = FlightServiceClient::new(BoxCloneSyncChannel::new(channel));
107111
v.insert(channel.clone());
108112
Ok(channel)
109113
}

examples/localhost_worker.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use arrow_flight::flight_service_client::FlightServiceClient;
12
use arrow_flight::flight_service_server::FlightServiceServer;
23
use async_trait::async_trait;
34
use dashmap::{DashMap, Entry};
@@ -55,7 +56,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
5556
#[derive(Clone)]
5657
struct LocalhostChannelResolver {
5758
ports: Vec<u16>,
58-
cached: DashMap<Url, BoxCloneSyncChannel>,
59+
cached: DashMap<Url, FlightServiceClient<BoxCloneSyncChannel>>,
5960
}
6061

6162
#[async_trait]
@@ -68,14 +69,17 @@ impl ChannelResolver for LocalhostChannelResolver {
6869
.collect())
6970
}
7071

71-
async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
72+
async fn get_flight_client_for_url(
73+
&self,
74+
url: &Url,
75+
) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> {
7276
match self.cached.entry(url.clone()) {
7377
Entry::Occupied(v) => Ok(v.get().clone()),
7478
Entry::Vacant(v) => {
7579
let channel = Channel::from_shared(url.to_string())
7680
.unwrap()
7781
.connect_lazy();
78-
let channel = BoxCloneSyncChannel::new(channel);
82+
let channel = FlightServiceClient::new(BoxCloneSyncChannel::new(channel));
7983
v.insert(channel.clone());
8084
Ok(channel)
8185
}

src/channel_resolver_ext.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use arrow_flight::flight_service_client::FlightServiceClient;
12
use async_trait::async_trait;
23
use datafusion::common::exec_datafusion_err;
34
use datafusion::error::DataFusionError;
@@ -38,8 +39,11 @@ pub type BoxCloneSyncChannel = tower::util::BoxCloneSyncService<
3839
pub trait ChannelResolver {
3940
/// Gets all available worker URLs. Used during stage assignment.
4041
fn get_urls(&self) -> Result<Vec<Url>, DataFusionError>;
41-
/// For a given URL, get a channel for communicating to it.
42-
async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError>;
42+
/// For a given URL, get an Arrow Flight client for communicating with it.
43+
async fn get_flight_client_for_url(
44+
&self,
45+
url: &Url,
46+
) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError>;
4347
}
4448

4549
#[async_trait]
@@ -48,7 +52,10 @@ impl ChannelResolver for Arc<dyn ChannelResolver + Send + Sync> {
4852
self.as_ref().get_urls()
4953
}
5054

51-
async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
52-
self.as_ref().get_channel_for_url(url).await
55+
async fn get_flight_client_for_url(
56+
&self,
57+
url: &Url,
58+
) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> {
59+
self.as_ref().get_flight_client_for_url(url).await
5360
}
5461
}

src/distributed_ext.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ pub trait DistributedExt: Sized {
180180
/// Example:
181181
///
182182
/// ```
183+
/// # use arrow_flight::flight_service_client::FlightServiceClient;
183184
/// # use async_trait::async_trait;
184185
/// # use datafusion::common::DataFusionError;
185186
/// # use datafusion::execution::{SessionState, SessionStateBuilder};
@@ -195,7 +196,7 @@ pub trait DistributedExt: Sized {
195196
/// todo!()
196197
/// }
197198
///
198-
/// async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
199+
/// async fn get_flight_client_for_url(&self, url: &Url) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> {
199200
/// todo!()
200201
/// }
201202
/// }

src/execution_plans/network_coalesce.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ use crate::protobuf::{map_flight_to_datafusion_error, map_status_to_datafusion_e
1212
use arrow_flight::Ticket;
1313
use arrow_flight::decode::FlightRecordBatchStream;
1414
use arrow_flight::error::FlightError;
15-
use arrow_flight::flight_service_client::FlightServiceClient;
1615
use dashmap::DashMap;
1716
use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err};
1817
use datafusion::error::DataFusionError;
@@ -285,8 +284,8 @@ impl ExecutionPlan for NetworkCoalesceExec {
285284

286285
let metrics_collection_capture = self_ready.metrics_collection.clone();
287286
let stream = async move {
288-
let channel = channel_resolver.get_channel_for_url(&url).await?;
289-
let stream = FlightServiceClient::new(channel)
287+
let mut client = channel_resolver.get_flight_client_for_url(&url).await?;
288+
let stream = client
290289
.do_get(ticket)
291290
.await
292291
.map_err(map_status_to_datafusion_error)?

src/execution_plans/network_shuffle.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ use crate::protobuf::{map_flight_to_datafusion_error, map_status_to_datafusion_e
1212
use arrow_flight::Ticket;
1313
use arrow_flight::decode::FlightRecordBatchStream;
1414
use arrow_flight::error::FlightError;
15-
use arrow_flight::flight_service_client::FlightServiceClient;
1615
use dashmap::DashMap;
1716
use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err};
1817
use datafusion::error::DataFusionError;
@@ -337,8 +336,8 @@ impl ExecutionPlan for NetworkShuffleExec {
337336
"NetworkShuffleExec: task is unassigned, cannot proceed"
338337
))?;
339338

340-
let channel = channel_resolver.get_channel_for_url(&url).await?;
341-
let stream = FlightServiceClient::new(channel)
339+
let mut client = channel_resolver.get_flight_client_for_url(&url).await?;
340+
let stream = client
342341
.do_get(ticket)
343342
.await
344343
.map_err(map_status_to_datafusion_error)?

src/flight_service/do_get.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@ use crate::protobuf::{
1010
AppMetadata, DistributedCodec, FlightAppMetadata, MetricsCollection, StageKey, TaskMetrics,
1111
datafusion_error_to_tonic_status, stage_from_proto,
1212
};
13-
use arrow::array::RecordBatch;
14-
use arrow::datatypes::SchemaRef;
15-
use arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
1613
use arrow_flight::FlightData;
1714
use arrow_flight::Ticket;
1815
use arrow_flight::encode::FlightDataEncoderBuilder;
1916
use arrow_flight::error::FlightError;
2017
use arrow_flight::flight_service_server::FlightService;
2118
use bytes::Bytes;
19+
use datafusion::arrow::array::RecordBatch;
20+
use datafusion::arrow::datatypes::SchemaRef;
21+
use datafusion::arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
2222
use datafusion::common::exec_datafusion_err;
2323
use datafusion::execution::SendableRecordBatchStream;
2424
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;

src/test_utils/in_memory_channel_resolver.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::{
22
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
33
DistributedSessionBuilderContext,
44
};
5+
use arrow_flight::flight_service_client::FlightServiceClient;
56
use arrow_flight::flight_service_server::FlightServiceServer;
67
use async_trait::async_trait;
78
use datafusion::common::DataFusionError;
@@ -15,7 +16,7 @@ const DUMMY_URL: &str = "http://localhost:50051";
1516
/// tokio duplex rather than a TCP connection.
1617
#[derive(Clone)]
1718
pub struct InMemoryChannelResolver {
18-
channel: BoxCloneSyncChannel,
19+
channel: FlightServiceClient<BoxCloneSyncChannel>,
1920
}
2021

2122
impl Default for InMemoryChannelResolver {
@@ -39,7 +40,7 @@ impl InMemoryChannelResolver {
3940
}));
4041

4142
let this = Self {
42-
channel: BoxCloneSyncChannel::new(channel),
43+
channel: FlightServiceClient::new(BoxCloneSyncChannel::new(channel)),
4344
};
4445
let this_clone = this.clone();
4546

@@ -73,10 +74,10 @@ impl ChannelResolver for InMemoryChannelResolver {
7374
Ok(vec![url::Url::parse(DUMMY_URL).unwrap()])
7475
}
7576

76-
async fn get_channel_for_url(
77+
async fn get_flight_client_for_url(
7778
&self,
7879
_: &url::Url,
79-
) -> Result<BoxCloneSyncChannel, DataFusionError> {
80+
) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> {
8081
Ok(self.channel.clone())
8182
}
8283
}

src/test_utils/localhost.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::{
33
DistributedSessionBuilder, DistributedSessionBuilderContext,
44
MappedDistributedSessionBuilderExt,
55
};
6+
use arrow_flight::flight_service_client::FlightServiceClient;
67
use arrow_flight::flight_service_server::FlightServiceServer;
78
use async_trait::async_trait;
89
use datafusion::common::DataFusionError;
@@ -98,10 +99,13 @@ impl ChannelResolver for LocalHostChannelResolver {
9899
.map(|url| Url::parse(&url).map_err(external_err))
99100
.collect::<Result<Vec<Url>, _>>()
100101
}
101-
async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
102+
async fn get_flight_client_for_url(
103+
&self,
104+
url: &Url,
105+
) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> {
102106
let endpoint = Channel::from_shared(url.to_string()).map_err(external_err)?;
103107
let channel = endpoint.connect().await.map_err(external_err)?;
104-
Ok(BoxCloneSyncChannel::new(channel))
108+
Ok(FlightServiceClient::new(BoxCloneSyncChannel::new(channel)))
105109
}
106110
}
107111

0 commit comments

Comments
 (0)