diff --git a/examples/in_memory_cluster.rs b/examples/in_memory_cluster.rs index 9a084f8..7171a75 100644 --- a/examples/in_memory_cluster.rs +++ b/examples/in_memory_cluster.rs @@ -1,6 +1,5 @@ use arrow::util::pretty::pretty_format_batches; use arrow_flight::flight_service_client::FlightServiceClient; -use arrow_flight::flight_service_server::FlightServiceServer; use async_trait::async_trait; use datafusion::common::DataFusionError; use datafusion::execution::SessionStateBuilder; @@ -8,7 +7,7 @@ use datafusion::physical_plan::displayable; use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion_distributed::{ ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt, - DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, + DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, create_flight_client, }; use futures::TryStreamExt; use hyper_util::rt::TokioIo; @@ -98,7 +97,7 @@ impl InMemoryChannelResolver { })); let this = Self { - channel: FlightServiceClient::new(BoxCloneSyncChannel::new(channel)), + channel: create_flight_client(BoxCloneSyncChannel::new(channel)), }; let this_clone = this.clone(); @@ -117,7 +116,7 @@ impl InMemoryChannelResolver { tokio::spawn(async move { Server::builder() - .add_service(FlightServiceServer::new(endpoint)) + .add_service(endpoint.into_flight_server()) .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) .await }); diff --git a/examples/localhost_worker.rs b/examples/localhost_worker.rs index 809175d..f02afbc 100644 --- a/examples/localhost_worker.rs +++ b/examples/localhost_worker.rs @@ -1,12 +1,11 @@ use arrow_flight::flight_service_client::FlightServiceClient; -use arrow_flight::flight_service_server::FlightServiceServer; use async_trait::async_trait; use dashmap::{DashMap, Entry}; use datafusion::common::DataFusionError; use datafusion::execution::SessionStateBuilder; use datafusion_distributed::{ ArrowFlightEndpoint, BoxCloneSyncChannel, 
ChannelResolver, DistributedExt, - DistributedSessionBuilderContext, + DistributedSessionBuilderContext, create_flight_client, }; use std::error::Error; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; @@ -46,7 +45,7 @@ async fn main() -> Result<(), Box> { })?; Server::builder() - .add_service(FlightServiceServer::new(endpoint)) + .add_service(endpoint.into_flight_server()) .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), args.port)) .await?; @@ -79,7 +78,7 @@ impl ChannelResolver for LocalhostChannelResolver { let channel = Channel::from_shared(url.to_string()) .unwrap() .connect_lazy(); - let channel = FlightServiceClient::new(BoxCloneSyncChannel::new(channel)); + let channel = create_flight_client(BoxCloneSyncChannel::new(channel)); v.insert(channel.clone()); Ok(channel) } diff --git a/src/channel_resolver_ext.rs b/src/channel_resolver_ext.rs index c8c1c5b..f49b64e 100644 --- a/src/channel_resolver_ext.rs +++ b/src/channel_resolver_ext.rs @@ -35,11 +35,21 @@ pub type BoxCloneSyncChannel = tower::util::BoxCloneSyncService< /// Abstracts networking details so that users can implement their own network resolution /// mechanism. +/// +/// # Implementation Note +/// +/// When implementing `get_flight_client_for_url`, it is recommended to use the +/// [`create_flight_client`] helper function to ensure clients are configured with +/// appropriate message size limits for internal communication. This helps avoid message +/// size errors when transferring large datasets. #[async_trait] pub trait ChannelResolver { /// Gets all available worker URLs. Used during stage assignment. fn get_urls(&self) -> Result, DataFusionError>; /// For a given URL, get an Arrow Flight client for communicating to it. + /// + /// Consider using [`create_flight_client`] to create the client with appropriate + /// default message size limits. 
async fn get_flight_client_for_url( &self, url: &Url, @@ -59,3 +69,39 @@ impl ChannelResolver for Arc { self.as_ref().get_flight_client_for_url(url).await } } + +/// Creates a [`FlightServiceClient`] with high default message size limits. +/// +/// This is a convenience function that wraps [`FlightServiceClient::new`] and configures +/// it with `max_decoding_message_size(usize::MAX)` and `max_encoding_message_size(usize::MAX)` +/// to avoid message size limitations for internal communication. +/// +/// Users implementing custom [`ChannelResolver`]s should use this function in their +/// `get_flight_client_for_url` implementations to ensure consistent behavior with built-in +/// implementations. +/// +/// # Example +/// +/// ```rust,ignore +/// use datafusion_distributed::{create_flight_client, BoxCloneSyncChannel, ChannelResolver}; +/// use arrow_flight::flight_service_client::FlightServiceClient; +/// use tonic::transport::Channel; +/// +/// #[async_trait] +/// impl ChannelResolver for MyResolver { +/// async fn get_flight_client_for_url( +/// &self, +/// url: &Url, +/// ) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> { +/// let channel = Channel::from_shared(url.to_string())?.connect().await?; +/// Ok(create_flight_client(BoxCloneSyncChannel::new(channel))) +/// } +/// } +/// ``` +pub fn create_flight_client( + channel: BoxCloneSyncChannel, +) -> FlightServiceClient<BoxCloneSyncChannel> { + FlightServiceClient::new(channel) + .max_decoding_message_size(usize::MAX) + .max_encoding_message_size(usize::MAX) +} diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index e51aecd..485fbee 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -152,6 +152,15 @@ impl ArrowFlightEndpoint { // Note that we do garbage collection of unused dictionary values above, so we are not sending // unused dictionary values over the wire. .with_dictionary_handling(DictionaryHandling::Resend) + // Set max flight data size to unlimited. 
+ // This requires servers and clients to also be configured to handle unlimited sizes. + // Using unlimited sizes avoids splitting RecordBatches into multiple FlightData messages, + // which could add significant overhead for large RecordBatches. + // The only reason to split them really is if the client/server are configured with a message size limit, + // which mainly makes sense in a public network scenario where you want to avoid DoS attacks. + // Since all of our Arrow Flight communication happens within trusted data plane networks, + // we default to unlimited sizes; `with_max_message_size` lowers the cap when needed. + .with_max_flight_data_size(self.max_message_size.unwrap_or(usize::MAX)) .build(stream.map_err(|err| { FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err))) })); diff --git a/src/flight_service/service.rs b/src/flight_service/service.rs index 4b94346..881e992 100644 --- a/src/flight_service/service.rs +++ b/src/flight_service/service.rs @@ -2,7 +2,7 @@ use crate::common::ttl_map::{TTLMap, TTLMapConfig}; use crate::flight_service::DistributedSessionBuilder; use crate::flight_service::do_get::TaskData; use crate::protobuf::StageKey; -use arrow_flight::flight_service_server::FlightService; +use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; use arrow_flight::{ Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket, @@ -28,6 +28,7 @@ pub struct ArrowFlightEndpoint { pub(super) task_data_entries: Arc>>>, pub(super) session_builder: Arc, pub(super) hooks: ArrowFlightEndpointHooks, + pub(super) max_message_size: Option<usize>, } impl ArrowFlightEndpoint { @@ -40,6 +41,7 @@ impl ArrowFlightEndpoint { task_data_entries: Arc::new(ttl_map), session_builder: Arc::new(session_builder), hooks: ArrowFlightEndpointHooks::default(), + max_message_size: Some(usize::MAX), }) } @@ -54,6 +56,44 @@ impl ArrowFlightEndpoint { ) { self.hooks.on_plan.push(Arc::new(hook)); } + + /// Set the maximum message 
size for FlightData chunks. + /// + /// Defaults to `usize::MAX` to minimize chunking overhead for internal communication. + /// See [`FlightDataEncoderBuilder::with_max_flight_data_size`] for details. + /// + /// If you change this to a lower value, ensure you configure the server's + /// `max_encoding_message_size` and `max_decoding_message_size` to at least 2x this value + /// to allow for overhead. For most use cases, the default of `usize::MAX` is appropriate. + /// + /// [`FlightDataEncoderBuilder::with_max_flight_data_size`]: https://arrow.apache.org/rust/arrow_flight/encode/struct.FlightDataEncoderBuilder.html#method.with_max_flight_data_size + pub fn with_max_message_size(mut self, size: usize) -> Self { + self.max_message_size = Some(size); + self + } + + /// Converts this endpoint into a [`FlightServiceServer`] with high default message size limits. + /// + /// This is a convenience method that wraps the endpoint in a [`FlightServiceServer`] and + /// configures it with `max_decoding_message_size(usize::MAX)` and + /// `max_encoding_message_size(usize::MAX)` to avoid message size limitations for internal + /// communication. + /// + /// You can further customize the returned server by chaining additional tonic methods. 
+ /// + /// # Example + /// + /// ```rust,ignore + /// let endpoint = ArrowFlightEndpoint::try_new(session_builder)?; + /// let server = endpoint.into_flight_server(); + /// // Can chain additional tonic methods if needed + /// // let server = server.some_other_tonic_method(...); + /// ``` + pub fn into_flight_server(self) -> FlightServiceServer<Self> { + FlightServiceServer::new(self) + .max_decoding_message_size(usize::MAX) + .max_encoding_message_size(usize::MAX) + } } #[async_trait] diff --git a/src/lib.rs b/src/lib.rs index 20491bf..c5d7c0f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,7 @@ mod protobuf; #[cfg(any(feature = "integration", test))] pub mod test_utils; -pub use channel_resolver_ext::{BoxCloneSyncChannel, ChannelResolver}; +pub use channel_resolver_ext::{BoxCloneSyncChannel, ChannelResolver, create_flight_client}; pub use distributed_ext::DistributedExt; pub use distributed_planner::{ DistributedConfig, DistributedPhysicalOptimizerRule, InputStageInfo, NetworkBoundary, diff --git a/src/test_utils/in_memory_channel_resolver.rs b/src/test_utils/in_memory_channel_resolver.rs index cd412fd..47fcd15 100644 --- a/src/test_utils/in_memory_channel_resolver.rs +++ b/src/test_utils/in_memory_channel_resolver.rs @@ -1,9 +1,8 @@ use crate::{ ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt, - DistributedSessionBuilderContext, + DistributedSessionBuilderContext, create_flight_client, }; use arrow_flight::flight_service_client::FlightServiceClient; -use arrow_flight::flight_service_server::FlightServiceServer; use async_trait::async_trait; use datafusion::common::DataFusionError; use datafusion::execution::SessionStateBuilder; @@ -40,7 +39,7 @@ impl InMemoryChannelResolver { })); let this = Self { - channel: FlightServiceClient::new(BoxCloneSyncChannel::new(channel)), + channel: create_flight_client(BoxCloneSyncChannel::new(channel)), }; let this_clone = this.clone(); @@ -59,7 +58,7 @@ impl InMemoryChannelResolver { tokio::spawn(async 
move { Server::builder() - .add_service(FlightServiceServer::new(endpoint)) + .add_service(endpoint.into_flight_server()) .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) .await }); diff --git a/src/test_utils/localhost.rs b/src/test_utils/localhost.rs index f168531..8c0f5e7 100644 --- a/src/test_utils/localhost.rs +++ b/src/test_utils/localhost.rs @@ -1,10 +1,9 @@ use crate::{ ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedSessionBuilder, DistributedSessionBuilderContext, - MappedDistributedSessionBuilderExt, + MappedDistributedSessionBuilderExt, create_flight_client, }; use arrow_flight::flight_service_client::FlightServiceClient; -use arrow_flight::flight_service_server::FlightServiceServer; use async_trait::async_trait; use datafusion::common::DataFusionError; use datafusion::common::runtime::JoinSet; @@ -105,7 +104,7 @@ impl ChannelResolver for LocalHostChannelResolver { ) -> Result, DataFusionError> { let endpoint = Channel::from_shared(url.to_string()).map_err(external_err)?; let channel = endpoint.connect().await.map_err(external_err)?; - Ok(FlightServiceClient::new(BoxCloneSyncChannel::new(channel))) + Ok(create_flight_client(BoxCloneSyncChannel::new(channel))) } } @@ -118,7 +117,7 @@ pub async fn spawn_flight_service( let incoming = tokio_stream::wrappers::TcpListenerStream::new(incoming); Ok(Server::builder() - .add_service(FlightServiceServer::new(endpoint)) + .add_service(endpoint.into_flight_server()) .serve_with_incoming(incoming) .await?) }