Skip to content

Commit 7024098

Browse files
committed
use usize::MAX and add helpers to create clients/servers
1 parent d94747b commit 7024098

File tree

7 files changed

+92
-77
lines changed

7 files changed

+92
-77
lines changed

examples/in_memory_cluster.rs

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
use arrow::util::pretty::pretty_format_batches;
22
use arrow_flight::flight_service_client::FlightServiceClient;
3-
use arrow_flight::flight_service_server::FlightServiceServer;
43
use async_trait::async_trait;
54
use datafusion::common::DataFusionError;
65
use datafusion::execution::SessionStateBuilder;
76
use datafusion::physical_plan::displayable;
87
use datafusion::prelude::{ParquetReadOptions, SessionContext};
98
use datafusion_distributed::{
109
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
11-
DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext,
10+
DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, create_flight_client,
1211
};
1312
use futures::TryStreamExt;
1413
use hyper_util::rt::TokioIo;
@@ -76,13 +75,6 @@ async fn main() -> Result<(), Box<dyn Error>> {
7675

7776
const DUMMY_URL: &str = "http://localhost:50051";
7877

79-
/// Maximum message size for FlightData chunks in ArrowFlightEndpoint.
80-
const ENDPOINT_MESSAGE_SIZE: usize = 128 * 1024 * 1024; // 128 MB
81-
82-
/// Maximum message size for gRPC server encoding and decoding.
83-
/// This should be 2x the ArrowFlightEndpoint max_message_size to allow for overhead.
84-
const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024; // 256 MB
85-
8678
/// [ChannelResolver] implementation that returns gRPC clients backed by an in-memory
8779
/// tokio duplex rather than a TCP connection.
8880
#[derive(Clone)]
@@ -105,7 +97,7 @@ impl InMemoryChannelResolver {
10597
}));
10698

10799
let this = Self {
108-
channel: FlightServiceClient::new(BoxCloneSyncChannel::new(channel)),
100+
channel: create_flight_client(BoxCloneSyncChannel::new(channel)),
109101
};
110102
let this_clone = this.clone();
111103

@@ -120,16 +112,11 @@ impl InMemoryChannelResolver {
120112
Ok(builder.build())
121113
}
122114
})
123-
.unwrap()
124-
.with_max_message_size(ENDPOINT_MESSAGE_SIZE);
115+
.unwrap();
125116

126117
tokio::spawn(async move {
127118
Server::builder()
128-
.add_service(
129-
FlightServiceServer::new(endpoint)
130-
.max_decoding_message_size(MAX_MESSAGE_SIZE)
131-
.max_encoding_message_size(MAX_MESSAGE_SIZE),
132-
)
119+
.add_service(endpoint.into_flight_server())
133120
.serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
134121
.await
135122
});

examples/localhost_worker.rs

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
use arrow_flight::flight_service_client::FlightServiceClient;
2-
use arrow_flight::flight_service_server::FlightServiceServer;
32
use async_trait::async_trait;
43
use dashmap::{DashMap, Entry};
54
use datafusion::common::DataFusionError;
65
use datafusion::execution::SessionStateBuilder;
76
use datafusion_distributed::{
87
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
9-
DistributedSessionBuilderContext,
8+
DistributedSessionBuilderContext, create_flight_client,
109
};
1110
use std::error::Error;
1211
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
@@ -25,13 +24,6 @@ struct Args {
2524
cluster_ports: Vec<u16>,
2625
}
2726

28-
/// Maximum message size for FlightData chunks in ArrowFlightEndpoint.
29-
const ENDPOINT_MESSAGE_SIZE: usize = 128 * 1024 * 1024; // 128 MB
30-
31-
/// Maximum message size for gRPC server encoding and decoding.
32-
/// This should be 2x the ArrowFlightEndpoint max_message_size to allow for overhead.
33-
const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024; // 256 MB
34-
3527
#[tokio::main]
3628
async fn main() -> Result<(), Box<dyn Error>> {
3729
let args = Args::from_args();
@@ -50,15 +42,10 @@ async fn main() -> Result<(), Box<dyn Error>> {
5042
.with_default_features()
5143
.build())
5244
}
53-
})?
54-
.with_max_message_size(ENDPOINT_MESSAGE_SIZE);
45+
})?;
5546

5647
Server::builder()
57-
.add_service(
58-
FlightServiceServer::new(endpoint)
59-
.max_decoding_message_size(MAX_MESSAGE_SIZE)
60-
.max_encoding_message_size(MAX_MESSAGE_SIZE),
61-
)
48+
.add_service(endpoint.into_flight_server())
6249
.serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), args.port))
6350
.await?;
6451

@@ -91,7 +78,7 @@ impl ChannelResolver for LocalhostChannelResolver {
9178
let channel = Channel::from_shared(url.to_string())
9279
.unwrap()
9380
.connect_lazy();
94-
let channel = FlightServiceClient::new(BoxCloneSyncChannel::new(channel));
81+
let channel = create_flight_client(BoxCloneSyncChannel::new(channel));
9582
v.insert(channel.clone());
9683
Ok(channel)
9784
}

src/channel_resolver_ext.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,21 @@ pub type BoxCloneSyncChannel = tower::util::BoxCloneSyncService<
3535

3636
/// Abstracts networking details so that users can implement their own network resolution
3737
/// mechanism.
38+
///
39+
/// # Implementation Note
40+
///
41+
/// When implementing `get_flight_client_for_url`, it is recommended to use the
42+
/// [`create_flight_client`] helper function to ensure clients are configured with
43+
/// appropriate message size limits for internal communication. This helps avoid message
44+
/// size errors when transferring large datasets.
3845
#[async_trait]
3946
pub trait ChannelResolver {
4047
/// Gets all available worker URLs. Used during stage assignment.
4148
fn get_urls(&self) -> Result<Vec<Url>, DataFusionError>;
4249
/// For a given URL, get an Arrow Flight client for communicating to it.
50+
///
51+
/// Consider using [`create_flight_client`] to create the client with appropriate
52+
/// default message size limits.
4353
async fn get_flight_client_for_url(
4454
&self,
4555
url: &Url,
@@ -59,3 +69,39 @@ impl ChannelResolver for Arc<dyn ChannelResolver + Send + Sync> {
5969
self.as_ref().get_flight_client_for_url(url).await
6070
}
6171
}
72+
73+
/// Creates a [`FlightServiceClient`] with high default message size limits.
74+
///
75+
/// This is a convenience function that wraps [`FlightServiceClient::new`] and configures
76+
/// it with `max_decoding_message_size(usize::MAX)` and `max_encoding_message_size(usize::MAX)`
77+
/// to avoid message size limitations for internal communication.
78+
///
79+
/// Users implementing custom [`ChannelResolver`]s should use this function in their
80+
/// `get_flight_client_for_url` implementations to ensure consistent behavior with built-in
81+
/// implementations.
82+
///
83+
/// # Example
84+
///
85+
/// ```rust,ignore
86+
/// use datafusion_distributed::{create_flight_client, BoxCloneSyncChannel, ChannelResolver};
87+
/// use arrow_flight::flight_service_client::FlightServiceClient;
88+
/// use tonic::transport::Channel;
89+
///
90+
/// #[async_trait]
91+
/// impl ChannelResolver for MyResolver {
92+
/// async fn get_flight_client_for_url(
93+
/// &self,
94+
/// url: &Url,
95+
/// ) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> {
96+
/// let channel = Channel::from_shared(url.to_string())?.connect().await?;
97+
/// Ok(create_flight_client(BoxCloneSyncChannel::new(channel)))
98+
/// }
99+
/// }
100+
/// ```
101+
pub fn create_flight_client(
102+
channel: BoxCloneSyncChannel,
103+
) -> FlightServiceClient<BoxCloneSyncChannel> {
104+
FlightServiceClient::new(channel)
105+
.max_decoding_message_size(usize::MAX)
106+
.max_encoding_message_size(usize::MAX)
107+
}

src/flight_service/service.rs

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::common::ttl_map::{TTLMap, TTLMapConfig};
22
use crate::flight_service::DistributedSessionBuilder;
33
use crate::flight_service::do_get::TaskData;
44
use crate::protobuf::StageKey;
5-
use arrow_flight::flight_service_server::FlightService;
5+
use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
66
use arrow_flight::{
77
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
88
HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
@@ -41,7 +41,7 @@ impl ArrowFlightEndpoint {
4141
task_data_entries: Arc::new(ttl_map),
4242
session_builder: Arc::new(session_builder),
4343
hooks: ArrowFlightEndpointHooks::default(),
44-
max_message_size: None,
44+
max_message_size: Some(usize::MAX),
4545
})
4646
}
4747

@@ -59,20 +59,41 @@ impl ArrowFlightEndpoint {
5959

6060
/// Set the maximum message size for FlightData chunks.
6161
///
62-
/// Defaults to None, which uses `arrow-rs` default, currently 2MB.
62+
/// Defaults to `usize::MAX` to minimize chunking overhead for internal communication.
6363
/// See [`FlightDataEncoderBuilder::with_max_flight_data_size`] for details.
6464
///
65-
/// If you change this, ensure you configure the server's max_encoding_message_size and
66-
/// max_decoding_message_size to at least 2x this value to allow for overhead.
67-
/// If your service communication is purely internal and there is no risk of DOS attacks,
68-
/// you may want to set this to a considerably larger value to minimize the overhead of chunking
69-
/// larger datasets.
65+
/// If you change this to a lower value, ensure you configure the server's
66+
/// max_encoding_message_size and max_decoding_message_size to at least 2x this value
67+
/// to allow for overhead. For most use cases, the default of `usize::MAX` is appropriate.
7068
///
7169
/// [`FlightDataEncoderBuilder::with_max_flight_data_size`]: https://arrow.apache.org/rust/arrow_flight/encode/struct.FlightDataEncoderBuilder.html#structfield.max_flight_data_size
7270
pub fn with_max_message_size(mut self, size: usize) -> Self {
7371
self.max_message_size = Some(size);
7472
self
7573
}
74+
75+
/// Converts this endpoint into a [`FlightServiceServer`] with high default message size limits.
76+
///
77+
/// This is a convenience method that wraps the endpoint in a [`FlightServiceServer`] and
78+
/// configures it with `max_decoding_message_size(usize::MAX)` and
79+
/// `max_encoding_message_size(usize::MAX)` to avoid message size limitations for internal
80+
/// communication.
81+
///
82+
/// You can further customize the returned server by chaining additional tonic methods.
83+
///
84+
/// # Example
85+
///
86+
/// ```rust,ignore
87+
/// let endpoint = ArrowFlightEndpoint::try_new(session_builder)?;
88+
/// let server = endpoint.into_flight_server();
89+
/// // Can chain additional tonic methods if needed
90+
/// // let server = server.some_other_tonic_method(...);
91+
/// ```
92+
pub fn into_flight_server(self) -> FlightServiceServer<Self> {
93+
FlightServiceServer::new(self)
94+
.max_decoding_message_size(usize::MAX)
95+
.max_encoding_message_size(usize::MAX)
96+
}
7697
}
7798

7899
#[async_trait]

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ mod protobuf;
1414
#[cfg(any(feature = "integration", test))]
1515
pub mod test_utils;
1616

17-
pub use channel_resolver_ext::{BoxCloneSyncChannel, ChannelResolver};
17+
pub use channel_resolver_ext::{BoxCloneSyncChannel, ChannelResolver, create_flight_client};
1818
pub use distributed_ext::DistributedExt;
1919
pub use distributed_planner::{
2020
DistributedConfig, DistributedPhysicalOptimizerRule, InputStageInfo, NetworkBoundary,

src/test_utils/in_memory_channel_resolver.rs

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
use crate::{
22
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
3-
DistributedSessionBuilderContext,
3+
DistributedSessionBuilderContext, create_flight_client,
44
};
55
use arrow_flight::flight_service_client::FlightServiceClient;
6-
use arrow_flight::flight_service_server::FlightServiceServer;
76
use async_trait::async_trait;
87
use datafusion::common::DataFusionError;
98
use datafusion::execution::SessionStateBuilder;
@@ -12,13 +11,6 @@ use tonic::transport::{Endpoint, Server};
1211

1312
const DUMMY_URL: &str = "http://localhost:50051";
1413

15-
/// Maximum message size for FlightData chunks in ArrowFlightEndpoint.
16-
const ENDPOINT_MESSAGE_SIZE: usize = 128 * 1024 * 1024; // 128 MB
17-
18-
/// Maximum message size for gRPC server encoding and decoding.
19-
/// This should be 2x the ArrowFlightEndpoint max_message_size to allow for overhead.
20-
const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024; // 256 MB
21-
2214
/// [ChannelResolver] implementation that returns gRPC clients backed by an in-memory
2315
/// tokio duplex rather than a TCP connection.
2416
#[derive(Clone)]
@@ -47,7 +39,7 @@ impl InMemoryChannelResolver {
4739
}));
4840

4941
let this = Self {
50-
channel: FlightServiceClient::new(BoxCloneSyncChannel::new(channel)),
42+
channel: create_flight_client(BoxCloneSyncChannel::new(channel)),
5143
};
5244
let this_clone = this.clone();
5345

@@ -62,16 +54,11 @@ impl InMemoryChannelResolver {
6254
Ok(builder.build())
6355
}
6456
})
65-
.unwrap()
66-
.with_max_message_size(ENDPOINT_MESSAGE_SIZE);
57+
.unwrap();
6758

6859
tokio::spawn(async move {
6960
Server::builder()
70-
.add_service(
71-
FlightServiceServer::new(endpoint)
72-
.max_decoding_message_size(MAX_MESSAGE_SIZE)
73-
.max_encoding_message_size(MAX_MESSAGE_SIZE),
74-
)
61+
.add_service(endpoint.into_flight_server())
7562
.serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
7663
.await
7764
});

src/test_utils/localhost.rs

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
use crate::{
22
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
33
DistributedSessionBuilder, DistributedSessionBuilderContext,
4-
MappedDistributedSessionBuilderExt,
4+
MappedDistributedSessionBuilderExt, create_flight_client,
55
};
66
use arrow_flight::flight_service_client::FlightServiceClient;
7-
use arrow_flight::flight_service_server::FlightServiceServer;
87
use async_trait::async_trait;
98
use datafusion::common::DataFusionError;
109
use datafusion::common::runtime::JoinSet;
@@ -18,13 +17,6 @@ use tokio::net::TcpListener;
1817
use tonic::transport::{Channel, Server};
1918
use url::Url;
2019

21-
/// Maximum message size for FlightData chunks in ArrowFlightEndpoint.
22-
const ENDPOINT_MESSAGE_SIZE: usize = 128 * 1024 * 1024; // 128 MB
23-
24-
/// Maximum message size for gRPC server encoding and decoding.
25-
/// This should be 2x the ArrowFlightEndpoint max_message_size to allow for overhead.
26-
const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024; // 256 MB
27-
2820
pub async fn start_localhost_context<B>(
2921
num_workers: usize,
3022
session_builder: B,
@@ -112,25 +104,20 @@ impl ChannelResolver for LocalHostChannelResolver {
112104
) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> {
113105
let endpoint = Channel::from_shared(url.to_string()).map_err(external_err)?;
114106
let channel = endpoint.connect().await.map_err(external_err)?;
115-
Ok(FlightServiceClient::new(BoxCloneSyncChannel::new(channel)))
107+
Ok(create_flight_client(BoxCloneSyncChannel::new(channel)))
116108
}
117109
}
118110

119111
pub async fn spawn_flight_service(
120112
session_builder: impl DistributedSessionBuilder + Send + Sync + 'static,
121113
incoming: TcpListener,
122114
) -> Result<(), Box<dyn Error + Send + Sync>> {
123-
let endpoint =
124-
ArrowFlightEndpoint::try_new(session_builder)?.with_max_message_size(ENDPOINT_MESSAGE_SIZE);
115+
let endpoint = ArrowFlightEndpoint::try_new(session_builder)?;
125116

126117
let incoming = tokio_stream::wrappers::TcpListenerStream::new(incoming);
127118

128119
Ok(Server::builder()
129-
.add_service(
130-
FlightServiceServer::new(endpoint)
131-
.max_decoding_message_size(MAX_MESSAGE_SIZE)
132-
.max_encoding_message_size(MAX_MESSAGE_SIZE),
133-
)
120+
.add_service(endpoint.into_flight_server())
134121
.serve_with_incoming(incoming)
135122
.await?)
136123
}

0 commit comments

Comments
 (0)