Skip to content

Commit 22c3802

Browse files
test_utils: add in memory channel resolver
This change copies the in-memory channel resolver from the `examples` folder to `test_utils` so it is possible to execute distributed queries easily in unit tests without the overhead of network communication.
1 parent dcf7ed1 commit 22c3802

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
use crate::{
2+
ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
3+
DistributedSessionBuilderContext,
4+
};
5+
use arrow_flight::flight_service_server::FlightServiceServer;
6+
use async_trait::async_trait;
7+
use datafusion::common::DataFusionError;
8+
use datafusion::execution::SessionStateBuilder;
9+
use hyper_util::rt::TokioIo;
10+
use tonic::transport::{Endpoint, Server};
11+
12+
const DUMMY_URL: &str = "http://localhost:50051";
13+
14+
/// [ChannelResolver] implementation that returns gRPC clients backed by an in-memory
15+
/// tokio duplex rather than a TCP connection.
16+
#[derive(Clone)]
17+
pub struct InMemoryChannelResolver {
18+
channel: BoxCloneSyncChannel,
19+
}
20+
21+
impl Default for InMemoryChannelResolver {
22+
fn default() -> Self {
23+
Self::new()
24+
}
25+
}
26+
27+
impl InMemoryChannelResolver {
28+
pub fn new() -> Self {
29+
let (client, server) = tokio::io::duplex(1024 * 1024);
30+
31+
let mut client = Some(client);
32+
let channel = Endpoint::try_from(DUMMY_URL)
33+
.expect("Invalid dummy URL for building an endpoint. This should never happen")
34+
.connect_with_connector_lazy(tower::service_fn(move |_| {
35+
let client = client
36+
.take()
37+
.expect("Client taken twice. This should never happen");
38+
async move { Ok::<_, std::io::Error>(TokioIo::new(client)) }
39+
}));
40+
41+
let this = Self {
42+
channel: BoxCloneSyncChannel::new(channel),
43+
};
44+
let this_clone = this.clone();
45+
46+
let endpoint =
47+
ArrowFlightEndpoint::try_new(move |ctx: DistributedSessionBuilderContext| {
48+
let this = this.clone();
49+
async move {
50+
let builder = SessionStateBuilder::new()
51+
.with_default_features()
52+
.with_distributed_channel_resolver(this)
53+
.with_runtime_env(ctx.runtime_env.clone());
54+
Ok(builder.build())
55+
}
56+
})
57+
.unwrap();
58+
59+
tokio::spawn(async move {
60+
Server::builder()
61+
.add_service(FlightServiceServer::new(endpoint))
62+
.serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
63+
.await
64+
});
65+
66+
this_clone
67+
}
68+
}
69+
70+
#[async_trait]
71+
impl ChannelResolver for InMemoryChannelResolver {
72+
fn get_urls(&self) -> Result<Vec<url::Url>, DataFusionError> {
73+
Ok(vec![url::Url::parse(DUMMY_URL).unwrap()])
74+
}
75+
76+
async fn get_channel_for_url(
77+
&self,
78+
_: &url::Url,
79+
) -> Result<BoxCloneSyncChannel, DataFusionError> {
80+
Ok(self.channel.clone())
81+
}
82+
}

src/test_utils/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
pub mod in_memory_channel_resolver;
12
pub mod insta;
23
pub mod localhost;
34
pub mod mock_exec;

0 commit comments

Comments
 (0)