Skip to content

Commit a02666a

Browse files
authored
Split channel resolver in two (#265)
* Split channel resolver in two * Simplify WorkerResolverExtension and ChannelResolverExtension * Add default builder to ArrowFlightEndpoint * Add some docs * Listen to clippy * Split get_flight_client_for_url in two * Fix conflicts * Remove unnecessary channel resolver * Improve WorkerResolver docs * Use one ChannelResolver per runtime * Improve error reporting on client connection failure * Add a from_session_builder method for constructing an InMemoryChannelResolver * Add ChannelResolver and WorkerResolver default implementations for Arcs * Add test for default channel resolver * Extend doc comment * Fix typo
1 parent b7e5555 commit a02666a

35 files changed

+723
-551
lines changed

Cargo.lock

Lines changed: 42 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ object_store = "0.12.3"
3838
bytes = "1.10.1"
3939
pin-project = "1.1.10"
4040
tokio-stream = "0.1.17"
41+
moka = { version = "0.12", features = ["sync"] }
4142

4243
# integration_tests deps
4344
insta = { version = "1.43.1", features = ["filters"], optional = true }

benchmarks/cdk/bin/worker.rs

Lines changed: 25 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,19 @@
1-
use arrow_flight::flight_service_client::FlightServiceClient;
21
use async_trait::async_trait;
32
use aws_config::BehaviorVersion;
43
use aws_sdk_ec2::Client as Ec2Client;
54
use axum::{Json, Router, extract::Query, http::StatusCode, routing::get};
6-
use dashmap::{DashMap, Entry};
75
use datafusion::common::DataFusionError;
86
use datafusion::common::instant::Instant;
97
use datafusion::common::runtime::SpawnedTask;
10-
use datafusion::execution::{SessionState, SessionStateBuilder};
8+
use datafusion::execution::SessionStateBuilder;
119
use datafusion::physical_plan::execute_stream;
1210
use datafusion::prelude::SessionContext;
1311
use datafusion_distributed::{
14-
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
15-
DistributedPhysicalOptimizerRule, DistributedSessionBuilder, DistributedSessionBuilderContext,
16-
create_flight_client, display_plan_ascii,
12+
ArrowFlightEndpoint, DistributedExt, DistributedPhysicalOptimizerRule,
13+
DistributedSessionBuilderContext, WorkerResolver, display_plan_ascii,
1714
};
1815
use futures::{StreamExt, TryFutureExt};
1916
use log::{error, info, warn};
20-
use object_store::ObjectStore;
2117
use object_store::aws::AmazonS3Builder;
2218
use serde::Serialize;
2319
use std::collections::HashMap;
@@ -27,7 +23,7 @@ use std::sync::atomic::AtomicBool;
2723
use std::sync::{Arc, RwLock};
2824
use std::time::Duration;
2925
use structopt::StructOpt;
30-
use tonic::transport::{Channel, Server};
26+
use tonic::transport::Server;
3127
use url::Url;
3228

3329
#[derive(Serialize)]
@@ -44,43 +40,6 @@ struct Cmd {
4440
bucket: String,
4541
}
4642

47-
#[derive(Clone)]
48-
struct BenchSessionStateBuilder {
49-
s3_url: Url,
50-
s3: Arc<dyn ObjectStore>,
51-
channel_resolver: Ec2ChannelResolver,
52-
}
53-
54-
impl BenchSessionStateBuilder {
55-
fn new(s3_url: Url) -> Result<Self, Box<dyn Error>> {
56-
let s3 = AmazonS3Builder::from_env()
57-
.with_bucket_name(s3_url.host().unwrap().to_string())
58-
.build()?;
59-
Ok(Self {
60-
s3_url,
61-
s3: Arc::new(s3),
62-
channel_resolver: Ec2ChannelResolver::new(),
63-
})
64-
}
65-
}
66-
67-
#[async_trait]
68-
impl DistributedSessionBuilder for BenchSessionStateBuilder {
69-
async fn build_session_state(
70-
&self,
71-
ctx: DistributedSessionBuilderContext,
72-
) -> Result<SessionState, DataFusionError> {
73-
let state = SessionStateBuilder::new()
74-
.with_default_features()
75-
.with_runtime_env(ctx.runtime_env)
76-
.with_object_store(&self.s3_url, Arc::clone(&self.s3))
77-
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
78-
.with_distributed_channel_resolver(self.channel_resolver.clone())
79-
.build();
80-
Ok(state)
81-
}
82-
}
83-
8443
#[tokio::main]
8544
async fn main() -> Result<(), Box<dyn Error>> {
8645
env_logger::builder()
@@ -98,15 +57,27 @@ async fn main() -> Result<(), Box<dyn Error>> {
9857

9958
// Register S3 object store
10059
let s3_url = Url::parse(&format!("s3://{}", cmd.bucket))?;
101-
let state_builder = BenchSessionStateBuilder::new(s3_url)?;
10260

10361
info!("Building shared SessionContext for the whole lifetime of the HTTP listener...");
104-
let state = state_builder
105-
.build_session_state(Default::default())
106-
.await?;
62+
let s3 = Arc::new(
63+
AmazonS3Builder::from_env()
64+
.with_bucket_name(s3_url.host().unwrap().to_string())
65+
.build()?,
66+
);
67+
let state = SessionStateBuilder::new()
68+
.with_default_features()
69+
.with_object_store(&s3_url, Arc::clone(&s3) as _)
70+
.with_distributed_worker_resolver(Ec2WorkerResolver::new())
71+
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
72+
.build();
10773
let ctx = SessionContext::from(state);
10874

109-
let arrow_flight_endpoint = ArrowFlightEndpoint::try_new(state_builder.clone())?;
75+
let arrow_flight_endpoint =
76+
ArrowFlightEndpoint::from_session_builder(move |ctx: DistributedSessionBuilderContext| {
77+
let s3 = s3.clone();
78+
let s3_url = s3_url.clone();
79+
async move { Ok(ctx.builder.with_object_store(&s3_url, s3).build()) }
80+
});
11081
let http_server = axum::serve(
11182
listener,
11283
Router::new().route(
@@ -213,9 +184,8 @@ fn err(s: impl Display) -> (StatusCode, String) {
213184
}
214185

215186
#[derive(Clone)]
216-
struct Ec2ChannelResolver {
187+
struct Ec2WorkerResolver {
217188
urls: Arc<RwLock<Vec<Url>>>,
218-
channels: Arc<DashMap<Url, BoxCloneSyncChannel>>,
219189
}
220190

221191
async fn background_ec2_worker_resolver(urls: Arc<RwLock<Vec<Url>>>) {
@@ -273,35 +243,17 @@ async fn background_ec2_worker_resolver(urls: Arc<RwLock<Vec<Url>>>) {
273243
});
274244
}
275245

276-
impl Ec2ChannelResolver {
246+
impl Ec2WorkerResolver {
277247
fn new() -> Self {
278248
let urls = Arc::new(RwLock::new(Vec::new()));
279-
let channels = Arc::new(DashMap::new());
280249
tokio::spawn(background_ec2_worker_resolver(urls.clone()));
281-
Self { urls, channels }
250+
Self { urls }
282251
}
283252
}
284253

285254
#[async_trait]
286-
impl ChannelResolver for Ec2ChannelResolver {
255+
impl WorkerResolver for Ec2WorkerResolver {
287256
fn get_urls(&self) -> Result<Vec<Url>, DataFusionError> {
288257
Ok(self.urls.read().unwrap().clone())
289258
}
290-
291-
async fn get_flight_client_for_url(
292-
&self,
293-
url: &Url,
294-
) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> {
295-
let channel = match self.channels.entry(url.clone()) {
296-
Entry::Occupied(v) => v.get().clone(),
297-
Entry::Vacant(v) => {
298-
let endpoint = Channel::from_shared(url.to_string()).unwrap();
299-
let channel = endpoint.connect_lazy();
300-
let channel = BoxCloneSyncChannel::new(channel);
301-
v.insert(channel.clone());
302-
channel
303-
}
304-
};
305-
Ok(create_flight_client(channel))
306-
}
307259
}

benchmarks/src/tpch/run.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use datafusion::physical_plan::display::DisplayableExecutionPlan;
4343
use datafusion::physical_plan::{collect, displayable};
4444
use datafusion::prelude::*;
4545
use datafusion_distributed::test_utils::localhost::{
46-
LocalHostChannelResolver, spawn_flight_service,
46+
LocalHostWorkerResolver, spawn_flight_service,
4747
};
4848
use datafusion_distributed::{
4949
DistributedExt, DistributedPhysicalOptimizerRule, DistributedSessionBuilder,
@@ -143,7 +143,7 @@ impl DistributedSessionBuilder for RunOpt {
143143
.with_default_features()
144144
.with_config(config)
145145
.with_distributed_user_codec(InMemoryCacheExecCodec)
146-
.with_distributed_channel_resolver(LocalHostChannelResolver::new(self.workers.clone()))
146+
.with_distributed_worker_resolver(LocalHostWorkerResolver::new(self.workers.clone()))
147147
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
148148
.with_distributed_option_extension_from_headers::<WarmingUpMarker>(&ctx.headers)?
149149
.with_distributed_files_per_task(

cli/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ edition = "2024"
55

66
[dependencies]
77
datafusion = { version = "51" }
8-
datafusion-distributed = { path = "..", features = ["avro"] }
8+
datafusion-distributed = { path = "..", features = ["avro", "integration"] }
99
datafusion-cli = { version = "51", default-features = false }
1010
tokio = { version = "1.46.1", features = ["full"] }
1111
clap = { version = "4", features = ["derive"] }

cli/src/main.rs

Lines changed: 5 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717

1818
// File mainly copied from https://github.com/apache/datafusion/blob/main/datafusion-cli/src/main.rs
1919

20-
use arrow_flight::flight_service_client::FlightServiceClient;
21-
use async_trait::async_trait;
2220
use clap::Parser;
2321
use datafusion::common::config_err;
2422
use datafusion::config::ConfigOptions;
@@ -35,16 +33,14 @@ use datafusion_cli::{
3533
print_format::PrintFormat,
3634
print_options::{MaxRows, PrintOptions},
3735
};
38-
use datafusion_distributed::{
39-
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
40-
DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, create_flight_client,
36+
use datafusion_distributed::test_utils::in_memory_channel_resolver::{
37+
InMemoryChannelResolver, InMemoryWorkerResolver,
4138
};
42-
use hyper_util::rt::TokioIo;
39+
use datafusion_distributed::{DistributedExt, DistributedPhysicalOptimizerRule};
4340
use std::env;
4441
use std::path::Path;
4542
use std::process::ExitCode;
4643
use std::sync::Arc;
47-
use tonic::transport::{Endpoint, Server};
4844

4945
#[derive(Debug, Parser, PartialEq)]
5046
#[clap(author, version, about, long_about= None)]
@@ -153,7 +149,8 @@ async fn main_inner() -> Result<()> {
153149
.with_config(session_config)
154150
.with_runtime_env(runtime_env)
155151
.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
156-
.with_distributed_channel_resolver(InMemoryChannelResolver::new())
152+
.with_distributed_worker_resolver(InMemoryWorkerResolver::new(16))
153+
.with_distributed_channel_resolver(InMemoryChannelResolver::default())
157154
.build();
158155

159156
// enable dynamic file query
@@ -265,69 +262,3 @@ fn parse_command(command: &str) -> Result<String, String> {
265262
Err("-c flag expects only non empty commands".to_string())
266263
}
267264
}
268-
269-
const DUMMY_URL: &str = "http://localhost:50051";
270-
271-
/// [ChannelResolver] implementation that returns gRPC clients baked by an in-memory
272-
/// tokio duplex rather than a TCP connection.
273-
#[derive(Clone)]
274-
struct InMemoryChannelResolver {
275-
channel: FlightServiceClient<BoxCloneSyncChannel>,
276-
}
277-
278-
impl InMemoryChannelResolver {
279-
fn new() -> Self {
280-
let (client, server) = tokio::io::duplex(1024 * 1024);
281-
282-
let mut client = Some(client);
283-
let channel = Endpoint::try_from(DUMMY_URL)
284-
.expect("Invalid dummy URL for building an endpoint. This should never happen")
285-
.connect_with_connector_lazy(tower::service_fn(move |_| {
286-
let client = client
287-
.take()
288-
.expect("Client taken twice. This should never happen");
289-
async move { Ok::<_, std::io::Error>(TokioIo::new(client)) }
290-
}));
291-
292-
let this = Self {
293-
channel: create_flight_client(BoxCloneSyncChannel::new(channel)),
294-
};
295-
let this_clone = this.clone();
296-
297-
let endpoint =
298-
ArrowFlightEndpoint::try_new(move |ctx: DistributedSessionBuilderContext| {
299-
let this = this.clone();
300-
async move {
301-
let builder = SessionStateBuilder::new()
302-
.with_default_features()
303-
.with_distributed_channel_resolver(this)
304-
.with_runtime_env(ctx.runtime_env.clone());
305-
Ok(builder.build())
306-
}
307-
})
308-
.unwrap();
309-
310-
tokio::spawn(async move {
311-
Server::builder()
312-
.add_service(endpoint.into_flight_server())
313-
.serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
314-
.await
315-
});
316-
317-
this_clone
318-
}
319-
}
320-
321-
#[async_trait]
322-
impl ChannelResolver for InMemoryChannelResolver {
323-
fn get_urls(&self) -> std::result::Result<Vec<url::Url>, DataFusionError> {
324-
Ok(vec![url::Url::parse(DUMMY_URL).unwrap(); 16]) // simulate 16 workers
325-
}
326-
327-
async fn get_flight_client_for_url(
328-
&self,
329-
_: &url::Url,
330-
) -> std::result::Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> {
331-
Ok(self.channel.clone())
332-
}
333-
}

0 commit comments

Comments
 (0)