Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions examples/in_memory_cluster.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use arrow::util::pretty::pretty_format_batches;
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::flight_service_server::FlightServiceServer;
use async_trait::async_trait;
use datafusion::common::DataFusionError;
use datafusion::execution::SessionStateBuilder;
use datafusion::physical_plan::displayable;
use datafusion::prelude::{ParquetReadOptions, SessionContext};
use datafusion_distributed::{
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext,
DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, create_flight_client,
};
use futures::TryStreamExt;
use hyper_util::rt::TokioIo;
Expand Down Expand Up @@ -98,7 +97,7 @@ impl InMemoryChannelResolver {
}));

let this = Self {
channel: FlightServiceClient::new(BoxCloneSyncChannel::new(channel)),
channel: create_flight_client(BoxCloneSyncChannel::new(channel)),
};
let this_clone = this.clone();

Expand All @@ -117,7 +116,7 @@ impl InMemoryChannelResolver {

tokio::spawn(async move {
Server::builder()
.add_service(FlightServiceServer::new(endpoint))
.add_service(endpoint.into_flight_server())
.serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
.await
});
Expand Down
7 changes: 3 additions & 4 deletions examples/localhost_worker.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::flight_service_server::FlightServiceServer;
use async_trait::async_trait;
use dashmap::{DashMap, Entry};
use datafusion::common::DataFusionError;
use datafusion::execution::SessionStateBuilder;
use datafusion_distributed::{
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
DistributedSessionBuilderContext,
DistributedSessionBuilderContext, create_flight_client,
};
use std::error::Error;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
Expand Down Expand Up @@ -46,7 +45,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
})?;

Server::builder()
.add_service(FlightServiceServer::new(endpoint))
.add_service(endpoint.into_flight_server())
.serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), args.port))
.await?;

Expand Down Expand Up @@ -79,7 +78,7 @@ impl ChannelResolver for LocalhostChannelResolver {
let channel = Channel::from_shared(url.to_string())
.unwrap()
.connect_lazy();
let channel = FlightServiceClient::new(BoxCloneSyncChannel::new(channel));
let channel = create_flight_client(BoxCloneSyncChannel::new(channel));
v.insert(channel.clone());
Ok(channel)
}
Expand Down
46 changes: 46 additions & 0 deletions src/channel_resolver_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,21 @@ pub type BoxCloneSyncChannel = tower::util::BoxCloneSyncService<

/// Abstracts networking details so that users can implement their own network resolution
/// mechanism.
///
/// # Implementation Note
///
/// When implementing `get_flight_client_for_url`, it is recommended to use the
/// [`create_flight_client`] helper function to ensure clients are configured with
/// appropriate message size limits for internal communication. This helps avoid message
/// size errors when transferring large datasets.
#[async_trait]
pub trait ChannelResolver {
/// Gets all available worker URLs. Used during stage assignment.
fn get_urls(&self) -> Result<Vec<Url>, DataFusionError>;
/// For a given URL, get an Arrow Flight client for communicating to it.
///
/// Consider using [`create_flight_client`] to create the client with appropriate
/// default message size limits.
async fn get_flight_client_for_url(
&self,
url: &Url,
Expand All @@ -59,3 +69,39 @@ impl ChannelResolver for Arc<dyn ChannelResolver + Send + Sync> {
self.as_ref().get_flight_client_for_url(url).await
}
}

/// Creates a [`FlightServiceClient`] with high default message size limits.
///
/// This is a convenience function that wraps [`FlightServiceClient::new`] and configures
/// it with `max_decoding_message_size(usize::MAX)` and `max_encoding_message_size(usize::MAX)`
/// to avoid message size limitations for internal communication.
///
/// Users implementing custom [`ChannelResolver`]s should use this function in their
/// `get_flight_client_for_url` implementations to ensure consistent behavior with built-in
/// implementations.
///
/// # Example
///
/// ```rust,ignore
/// use datafusion_distributed::{create_flight_client, BoxCloneSyncChannel, ChannelResolver};
/// use arrow_flight::flight_service_client::FlightServiceClient;
/// use tonic::transport::Channel;
///
/// #[async_trait]
/// impl ChannelResolver for MyResolver {
/// async fn get_flight_client_for_url(
/// &self,
/// url: &Url,
/// ) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> {
/// let channel = Channel::from_shared(url.to_string())?.connect().await?;
/// Ok(create_flight_client(BoxCloneSyncChannel::new(channel)))
/// }
/// }
/// ```
pub fn create_flight_client(
channel: BoxCloneSyncChannel,
) -> FlightServiceClient<BoxCloneSyncChannel> {
FlightServiceClient::new(channel)
.max_decoding_message_size(usize::MAX)
.max_encoding_message_size(usize::MAX)
}
9 changes: 9 additions & 0 deletions src/flight_service/do_get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,15 @@ impl ArrowFlightEndpoint {
// Note that we do garbage collection of unused dictionary values above, so we are not sending
// unused dictionary values over the wire.
.with_dictionary_handling(DictionaryHandling::Resend)
// Set max flight data size to unlimited.
// This requires servers and clients to also be configured to handle unlimited sizes.
// Using unlimited sizes avoids splitting RecordBatches into multiple FlightData messages,
// which could add significant overhead for large RecordBatches.
// The only reason to split them really is if the client/server are configured with a message size limit,
// which mainly makes sense in a public network scenario where you want to avoid DoS attacks.
// Since all of our Arrow Flight communication happens within trusted data plane networks,
// we can safely use unlimited sizes here.
.with_max_flight_data_size(usize::MAX)
.build(stream.map_err(|err| {
FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
}));
Expand Down
42 changes: 41 additions & 1 deletion src/flight_service/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::common::ttl_map::{TTLMap, TTLMapConfig};
use crate::flight_service::DistributedSessionBuilder;
use crate::flight_service::do_get::TaskData;
use crate::protobuf::StageKey;
use arrow_flight::flight_service_server::FlightService;
use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
use arrow_flight::{
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
Expand All @@ -28,6 +28,7 @@ pub struct ArrowFlightEndpoint {
pub(super) task_data_entries: Arc<TTLMap<StageKey, Arc<OnceCell<TaskData>>>>,
pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
pub(super) hooks: ArrowFlightEndpointHooks,
pub(super) max_message_size: Option<usize>,
}

impl ArrowFlightEndpoint {
Expand All @@ -40,6 +41,7 @@ impl ArrowFlightEndpoint {
task_data_entries: Arc::new(ttl_map),
session_builder: Arc::new(session_builder),
hooks: ArrowFlightEndpointHooks::default(),
max_message_size: Some(usize::MAX),
})
}

Expand All @@ -54,6 +56,44 @@ impl ArrowFlightEndpoint {
) {
self.hooks.on_plan.push(Arc::new(hook));
}

/// Set the maximum message size for FlightData chunks.
///
/// Defaults to `usize::MAX` to minimize chunking overhead for internal communication.
/// See [`FlightDataEncoderBuilder::with_max_flight_data_size`] for details.
///
/// If you change this to a lower value, ensure you configure the server's
/// max_encoding_message_size and max_decoding_message_size to at least 2x this value
/// to allow for overhead. For most use cases, the default of `usize::MAX` is appropriate.
///
/// [`FlightDataEncoderBuilder::with_max_flight_data_size`]: https://arrow.apache.org/rust/arrow_flight/encode/struct.FlightDataEncoderBuilder.html#structfield.max_flight_data_size
pub fn with_max_message_size(mut self, size: usize) -> Self {
self.max_message_size = Some(size);
self
}

/// Converts this endpoint into a [`FlightServiceServer`] with high default message size limits.
///
/// This is a convenience method that wraps the endpoint in a [`FlightServiceServer`] and
/// configures it with `max_decoding_message_size(usize::MAX)` and
/// `max_encoding_message_size(usize::MAX)` to avoid message size limitations for internal
/// communication.
///
/// You can further customize the returned server by chaining additional tonic methods.
///
/// # Example
///
/// ```rust,ignore
/// let endpoint = ArrowFlightEndpoint::try_new(session_builder)?;
/// let server = endpoint.into_flight_server();
/// // Can chain additional tonic methods if needed
/// // let server = server.some_other_tonic_method(...);
/// ```
pub fn into_flight_server(self) -> FlightServiceServer<Self> {
FlightServiceServer::new(self)
.max_decoding_message_size(usize::MAX)
.max_encoding_message_size(usize::MAX)
}
}

#[async_trait]
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ mod protobuf;
#[cfg(any(feature = "integration", test))]
pub mod test_utils;

pub use channel_resolver_ext::{BoxCloneSyncChannel, ChannelResolver};
pub use channel_resolver_ext::{BoxCloneSyncChannel, ChannelResolver, create_flight_client};
pub use distributed_ext::DistributedExt;
pub use distributed_planner::{
DistributedConfig, DistributedPhysicalOptimizerRule, InputStageInfo, NetworkBoundary,
Expand Down
7 changes: 3 additions & 4 deletions src/test_utils/in_memory_channel_resolver.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use crate::{
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
DistributedSessionBuilderContext,
DistributedSessionBuilderContext, create_flight_client,
};
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::flight_service_server::FlightServiceServer;
use async_trait::async_trait;
use datafusion::common::DataFusionError;
use datafusion::execution::SessionStateBuilder;
Expand Down Expand Up @@ -40,7 +39,7 @@ impl InMemoryChannelResolver {
}));

let this = Self {
channel: FlightServiceClient::new(BoxCloneSyncChannel::new(channel)),
channel: create_flight_client(BoxCloneSyncChannel::new(channel)),
};
let this_clone = this.clone();

Expand All @@ -59,7 +58,7 @@ impl InMemoryChannelResolver {

tokio::spawn(async move {
Server::builder()
.add_service(FlightServiceServer::new(endpoint))
.add_service(endpoint.into_flight_server())
.serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
.await
});
Expand Down
7 changes: 3 additions & 4 deletions src/test_utils/localhost.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use crate::{
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
DistributedSessionBuilder, DistributedSessionBuilderContext,
MappedDistributedSessionBuilderExt,
MappedDistributedSessionBuilderExt, create_flight_client,
};
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::flight_service_server::FlightServiceServer;
use async_trait::async_trait;
use datafusion::common::DataFusionError;
use datafusion::common::runtime::JoinSet;
Expand Down Expand Up @@ -105,7 +104,7 @@ impl ChannelResolver for LocalHostChannelResolver {
) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> {
let endpoint = Channel::from_shared(url.to_string()).map_err(external_err)?;
let channel = endpoint.connect().await.map_err(external_err)?;
Ok(FlightServiceClient::new(BoxCloneSyncChannel::new(channel)))
Ok(create_flight_client(BoxCloneSyncChannel::new(channel)))
}
}

Expand All @@ -118,7 +117,7 @@ pub async fn spawn_flight_service(
let incoming = tokio_stream::wrappers::TcpListenerStream::new(incoming);

Ok(Server::builder()
.add_service(FlightServiceServer::new(endpoint))
.add_service(endpoint.into_flight_server())
.serve_with_incoming(incoming)
.await?)
}
Expand Down