Skip to content

Commit a3e61d2

Browse files
committed
allow configuring message sizes
1 parent 2f29bd1 commit a3e61d2

File tree

6 files changed

+75
-8
lines changed

6 files changed

+75
-8
lines changed

examples/in_memory_cluster.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ async fn main() -> Result<(), Box<dyn Error>> {
7676

7777
const DUMMY_URL: &str = "http://localhost:50051";
7878

79+
/// Maximum message size for FlightData chunks in ArrowFlightEndpoint.
80+
const ENDPOINT_MESSAGE_SIZE: usize = 128 * 1024 * 1024; // 128 MB
81+
82+
/// Maximum message size for gRPC server encoding and decoding.
83+
/// This should be at least 2x the ArrowFlightEndpoint max_message_size to allow for overhead.
84+
const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024; // 256 MB
85+
7986
/// [ChannelResolver] implementation that returns gRPC clients backed by an in-memory
8087
/// tokio duplex rather than a TCP connection.
8188
#[derive(Clone)]
@@ -113,11 +120,16 @@ impl InMemoryChannelResolver {
113120
Ok(builder.build())
114121
}
115122
})
116-
.unwrap();
123+
.unwrap()
124+
.with_max_message_size(ENDPOINT_MESSAGE_SIZE);
117125

118126
tokio::spawn(async move {
119127
Server::builder()
120-
.add_service(FlightServiceServer::new(endpoint))
128+
.add_service(
129+
FlightServiceServer::new(endpoint)
130+
.max_decoding_message_size(MAX_MESSAGE_SIZE)
131+
.max_encoding_message_size(MAX_MESSAGE_SIZE),
132+
)
121133
.serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
122134
.await
123135
});

examples/localhost_worker.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ struct Args {
2525
cluster_ports: Vec<u16>,
2626
}
2727

28+
/// Maximum message size for FlightData chunks in ArrowFlightEndpoint.
29+
const ENDPOINT_MESSAGE_SIZE: usize = 128 * 1024 * 1024; // 128 MB
30+
31+
/// Maximum message size for gRPC server encoding and decoding.
32+
/// This should be at least 2x the ArrowFlightEndpoint max_message_size to allow for overhead.
33+
const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024; // 256 MB
34+
2835
#[tokio::main]
2936
async fn main() -> Result<(), Box<dyn Error>> {
3037
let args = Args::from_args();
@@ -43,10 +50,15 @@ async fn main() -> Result<(), Box<dyn Error>> {
4350
.with_default_features()
4451
.build())
4552
}
46-
})?;
53+
})?
54+
.with_max_message_size(ENDPOINT_MESSAGE_SIZE);
4755

4856
Server::builder()
49-
.add_service(FlightServiceServer::new(endpoint))
57+
.add_service(
58+
FlightServiceServer::new(endpoint)
59+
.max_decoding_message_size(MAX_MESSAGE_SIZE)
60+
.max_encoding_message_size(MAX_MESSAGE_SIZE),
61+
)
5062
.serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), args.port))
5163
.await?;
5264

src/flight_service/do_get.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ impl ArrowFlightEndpoint {
152152
// Note that we do garbage collection of unused dictionary values above, so we are not sending
153153
// unused dictionary values over the wire.
154154
.with_dictionary_handling(DictionaryHandling::Resend)
155+
.with_max_flight_data_size(usize::MAX)
155156
.build(stream.map_err(|err| {
156157
FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
157158
}));

src/flight_service/service.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,13 @@ pub struct ArrowFlightEndpoint {
2828
pub(super) task_data_entries: Arc<TTLMap<StageKey, Arc<OnceCell<TaskData>>>>,
2929
pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
3030
pub(super) hooks: ArrowFlightEndpointHooks,
31+
max_message_size: usize,
3132
}
3233

34+
/// Default maximum message size for FlightData chunks in ArrowFlightEndpoint.
35+
/// This is the size used for chunking FlightData within the endpoint.
36+
const DEFAULT_MESSAGE_SIZE: usize = 2 * 1024 * 1024; // 2 MB
37+
3338
impl ArrowFlightEndpoint {
3439
pub fn try_new(
3540
session_builder: impl DistributedSessionBuilder + Send + Sync + 'static,
@@ -40,6 +45,7 @@ impl ArrowFlightEndpoint {
4045
task_data_entries: Arc::new(ttl_map),
4146
session_builder: Arc::new(session_builder),
4247
hooks: ArrowFlightEndpointHooks::default(),
48+
max_message_size: DEFAULT_MESSAGE_SIZE,
4349
})
4450
}
4551

@@ -54,6 +60,18 @@ impl ArrowFlightEndpoint {
5460
) {
5561
self.hooks.on_plan.push(Arc::new(hook));
5662
}
63+
64+
/// Set the maximum message size for FlightData chunks.
65+
/// Defaults to 2 MB.
66+
/// If you change this, ensure you configure the server's max_encoding_message_size and
67+
/// max_decoding_message_size to at least 2x this value to allow for overhead.
68+
/// If your service communication is purely internal and there is no risk of DoS attacks,
69+
/// you may want to set this to a considerably larger value to minimize the overhead of chunking
70+
/// larger datasets.
71+
pub fn with_max_message_size(mut self, size: usize) -> Self {
72+
self.max_message_size = size;
73+
self
74+
}
5775
}
5876

5977
#[async_trait]

src/test_utils/in_memory_channel_resolver.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ use tonic::transport::{Endpoint, Server};
1212

1313
const DUMMY_URL: &str = "http://localhost:50051";
1414

15+
/// Maximum message size for FlightData chunks in ArrowFlightEndpoint.
16+
const ENDPOINT_MESSAGE_SIZE: usize = 128 * 1024 * 1024; // 128 MB
17+
18+
/// Maximum message size for gRPC server encoding and decoding.
19+
/// This should be at least 2x the ArrowFlightEndpoint max_message_size to allow for overhead.
20+
const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024; // 256 MB
21+
1522
/// [ChannelResolver] implementation that returns gRPC clients backed by an in-memory
1623
/// tokio duplex rather than a TCP connection.
1724
#[derive(Clone)]
@@ -55,11 +62,16 @@ impl InMemoryChannelResolver {
5562
Ok(builder.build())
5663
}
5764
})
58-
.unwrap();
65+
.unwrap()
66+
.with_max_message_size(ENDPOINT_MESSAGE_SIZE);
5967

6068
tokio::spawn(async move {
6169
Server::builder()
62-
.add_service(FlightServiceServer::new(endpoint))
70+
.add_service(
71+
FlightServiceServer::new(endpoint)
72+
.max_decoding_message_size(MAX_MESSAGE_SIZE)
73+
.max_encoding_message_size(MAX_MESSAGE_SIZE),
74+
)
6375
.serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
6476
.await
6577
});

src/test_utils/localhost.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ use tokio::net::TcpListener;
1818
use tonic::transport::{Channel, Server};
1919
use url::Url;
2020

21+
/// Maximum message size for FlightData chunks in ArrowFlightEndpoint.
22+
const ENDPOINT_MESSAGE_SIZE: usize = 128 * 1024 * 1024; // 128 MB
23+
24+
/// Maximum message size for gRPC server encoding and decoding.
25+
/// This should be at least 2x the ArrowFlightEndpoint max_message_size to allow for overhead.
26+
const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024; // 256 MB
27+
2128
pub async fn start_localhost_context<B>(
2229
num_workers: usize,
2330
session_builder: B,
@@ -113,12 +120,17 @@ pub async fn spawn_flight_service(
113120
session_builder: impl DistributedSessionBuilder + Send + Sync + 'static,
114121
incoming: TcpListener,
115122
) -> Result<(), Box<dyn Error + Send + Sync>> {
116-
let endpoint = ArrowFlightEndpoint::try_new(session_builder)?;
123+
let endpoint =
124+
ArrowFlightEndpoint::try_new(session_builder)?.with_max_message_size(ENDPOINT_MESSAGE_SIZE);
117125

118126
let incoming = tokio_stream::wrappers::TcpListenerStream::new(incoming);
119127

120128
Ok(Server::builder()
121-
.add_service(FlightServiceServer::new(endpoint))
129+
.add_service(
130+
FlightServiceServer::new(endpoint)
131+
.max_decoding_message_size(MAX_MESSAGE_SIZE)
132+
.max_encoding_message_size(MAX_MESSAGE_SIZE),
133+
)
122134
.serve_with_incoming(incoming)
123135
.await?)
124136
}

0 commit comments

Comments
 (0)