Skip to content

Commit 4b46791

Browse files
do_get: use TTLMap in ArrowFlightEndpoint instead of DashMap
1 parent 5c92110 commit 4b46791

File tree

3 files changed

+17
-13
lines changed

3 files changed

+17
-13
lines changed

src/flight_service/do_get.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use datafusion::execution::SessionState;
1515
use futures::TryStreamExt;
1616
use prost::Message;
1717
use std::sync::Arc;
18+
use tokio::sync::OnceCell;
1819
use tonic::metadata::MetadataMap;
1920
use tonic::{Request, Response, Status};
2021

@@ -94,12 +95,11 @@ impl ArrowFlightEndpoint {
9495
let key = doget
9596
.stage_key
9697
.ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
97-
let once_stage = {
98-
let entry = self.stages.entry(key).or_default();
99-
Arc::clone(&entry)
100-
};
98+
let once_stage = self.stages.get_or_init(key, || {
99+
OnceCell::<(SessionState, Arc<ExecutionStage>)>::new()
100+
});
101101

102-
let (state, stage) = once_stage
102+
let stage_data = once_stage
103103
.get_or_try_init(|| async {
104104
let stage_proto = doget
105105
.stage_proto
@@ -138,6 +138,6 @@ impl ArrowFlightEndpoint {
138138
})
139139
.await?;
140140

141-
Ok((state.clone(), stage.clone()))
141+
Ok(stage_data.clone())
142142
}
143143
}

src/flight_service/service.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::channel_manager::ChannelManager;
2+
use crate::common::ttl_map::{TTLMap, TTLMapConfig};
23
use crate::flight_service::session_builder::DefaultSessionBuilder;
34
use crate::flight_service::DistributedSessionBuilder;
45
use crate::stage::ExecutionStage;
@@ -9,7 +10,7 @@ use arrow_flight::{
910
HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
1011
};
1112
use async_trait::async_trait;
12-
use dashmap::DashMap;
13+
use datafusion::error::DataFusionError;
1314
use datafusion::execution::runtime_env::RuntimeEnv;
1415
use datafusion::execution::SessionState;
1516
use futures::stream::BoxStream;
@@ -35,18 +36,21 @@ pub struct ArrowFlightEndpoint {
3536
pub(super) channel_manager: Arc<ChannelManager>,
3637
pub(super) runtime: Arc<RuntimeEnv>,
3738
#[allow(clippy::type_complexity)]
38-
pub(super) stages: DashMap<StageKey, Arc<OnceCell<(SessionState, Arc<ExecutionStage>)>>>,
39+
pub(super) stages: TTLMap<StageKey, OnceCell<(SessionState, Arc<ExecutionStage>)>>,
3940
pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
4041
}
4142

4243
impl ArrowFlightEndpoint {
43-
pub fn new(channel_resolver: impl ChannelResolver + Send + Sync + 'static) -> Self {
44-
Self {
44+
pub fn new(
45+
channel_resolver: impl ChannelResolver + Send + Sync + 'static,
46+
) -> Result<Self, DataFusionError> {
47+
let ttl_map = TTLMap::try_new(TTLMapConfig::default())?;
48+
Ok(Self {
4549
channel_manager: Arc::new(ChannelManager::new(channel_resolver)),
4650
runtime: Arc::new(RuntimeEnv::default()),
47-
stages: DashMap::new(),
51+
stages: ttl_map,
4852
session_builder: Arc::new(DefaultSessionBuilder),
49-
}
53+
})
5054
}
5155

5256
pub fn with_session_builder(

src/test_utils/localhost.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ pub async fn spawn_flight_service(
106106
session_builder: impl DistributedSessionBuilder + Send + Sync + 'static,
107107
incoming: TcpListener,
108108
) -> Result<(), Box<dyn Error + Send + Sync>> {
109-
let mut endpoint = ArrowFlightEndpoint::new(channel_resolver);
109+
let mut endpoint = ArrowFlightEndpoint::new(channel_resolver)?;
110110
endpoint.with_session_builder(session_builder);
111111

112112
let incoming = tokio_stream::wrappers::TcpListenerStream::new(incoming);

0 commit comments

Comments
 (0)