Skip to content

Commit 03452bc

Browse files
hook everything up
1 parent c8c04fd commit 03452bc

File tree

3 files changed

+33
-14
lines changed

3 files changed

+33
-14
lines changed

src/common/ttl_map.rs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ where
5656
/// new creates a new Bucket
5757
fn new<V>(data: Arc<DashMap<K, V>>) -> Self
5858
where
59-
V: Default + Clone + Send + Sync + 'static,
59+
V: Send + Sync + 'static,
6060
{
6161
// TODO: To avoid unbounded growth, consider making this bounded. Alternatively, we can
6262
// introduce a dynamic GC period to ensure that GC can keep up.
@@ -84,7 +84,7 @@ where
8484
/// task is responsible for managing a subset of keys in the map.
8585
async fn task<V>(mut rx: UnboundedReceiver<BucketOp<K>>, data: Arc<DashMap<K, V>>)
8686
where
87-
V: Default + Clone + Send + Sync + 'static,
87+
V: Send + Sync + 'static,
8888
{
8989
let mut shard = HashSet::new();
9090
while let Some(op) = rx.recv().await {
@@ -207,7 +207,7 @@ where
207207

208208
/// get_or_default executes the provided closure with a reference to the map entry for the given key.
209209
/// If the key does not exist, it inserts a new entry with the default value.
210-
pub fn get_or_init<F>(&self, key: K, f: F) -> V
210+
pub fn get_or_init<F>(&self, key: K, init: F) -> V
211211
where
212212
F: FnOnce() -> V,
213213
{
@@ -218,7 +218,7 @@ where
218218

219219
let value = match self.data.entry(key.clone()) {
220220
Entry::Vacant(entry) => {
221-
let value = f();
221+
let value = init();
222222
entry.insert(value.clone());
223223
new_entry = true;
224224
value

src/flight_service/do_get.rs

Lines changed: 27 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ use arrow_flight::Ticket;
1414
use datafusion::execution::SessionState;
1515
use futures::TryStreamExt;
1616
use prost::Message;
17+
use std::sync::atomic::{AtomicUsize, Ordering};
1718
use std::sync::Arc;
1819
use tokio::sync::OnceCell;
1920
use tonic::metadata::MetadataMap;
@@ -38,6 +39,13 @@ pub struct DoGet {
3839
pub stage_key: Option<StageKey>,
3940
}
4041

42+
#[derive(Clone)]
43+
pub struct TaskData {
44+
pub(super) state: SessionState,
45+
pub(super) stage: Arc<ExecutionStage>,
46+
partition_count: Arc<AtomicUsize>,
47+
}
48+
4149
impl ArrowFlightEndpoint {
4250
pub(super) async fn get(
4351
&self,
@@ -51,7 +59,10 @@ impl ArrowFlightEndpoint {
5159

5260
let partition = doget.partition as usize;
5361
let task_number = doget.task_number as usize;
54-
let (mut state, stage) = self.get_state_and_stage(doget, metadata).await?;
62+
let task_data = self.get_state_and_stage(doget, metadata).await?;
63+
64+
let stage = task_data.stage;
65+
let mut state = task_data.state;
5566

5667
// find out which partition group we are executing
5768
let task = stage
@@ -91,15 +102,15 @@ impl ArrowFlightEndpoint {
91102
&self,
92103
doget: DoGet,
93104
metadata_map: MetadataMap,
94-
) -> Result<(SessionState, Arc<ExecutionStage>), Status> {
105+
) -> Result<TaskData, Status> {
95106
let key = doget
96107
.stage_key
97108
.ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
98-
let once_stage = self.stages.get_or_init(key, || {
99-
OnceCell::<(SessionState, Arc<ExecutionStage>)>::new()
100-
});
109+
let once_stage = self
110+
.stages
111+
.get_or_init(key.clone(), || OnceCell::<TaskData>::new());
101112

102-
let stage_data = once_stage
113+
let mut stage_data = once_stage
103114
.get_or_try_init(|| async {
104115
let stage_proto = doget
105116
.stage_proto
@@ -134,10 +145,19 @@ impl ArrowFlightEndpoint {
134145
config.set_extension(stage.clone());
135146
config.set_extension(Arc::new(ContextGrpcMetadata(headers)));
136147

137-
Ok::<_, Status>((state, stage))
148+
Ok::<_, Status>(TaskData {
149+
state,
150+
stage,
151+
partition_count: Arc::new(AtomicUsize::new(0)),
152+
})
138153
})
139154
.await?;
140155

156+
stage_data.partition_count.fetch_sub(1, Ordering::SeqCst);
157+
if stage_data.partition_count.load(Ordering::SeqCst) <= 0 {
158+
self.stages.remove(key.clone());
159+
}
160+
141161
Ok(stage_data.clone())
142162
}
143163
}

src/flight_service/service.rs

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
11
use crate::channel_manager::ChannelManager;
22
use crate::common::ttl_map::{TTLMap, TTLMapConfig};
3+
use crate::flight_service::do_get::TaskData;
34
use crate::flight_service::session_builder::DefaultSessionBuilder;
45
use crate::flight_service::DistributedSessionBuilder;
5-
use crate::stage::ExecutionStage;
66
use crate::ChannelResolver;
77
use arrow_flight::flight_service_server::FlightService;
88
use arrow_flight::{
@@ -12,7 +12,6 @@ use arrow_flight::{
1212
use async_trait::async_trait;
1313
use datafusion::error::DataFusionError;
1414
use datafusion::execution::runtime_env::RuntimeEnv;
15-
use datafusion::execution::SessionState;
1615
use futures::stream::BoxStream;
1716
use std::sync::Arc;
1817
use tokio::sync::OnceCell;
@@ -36,7 +35,7 @@ pub struct ArrowFlightEndpoint {
3635
pub(super) channel_manager: Arc<ChannelManager>,
3736
pub(super) runtime: Arc<RuntimeEnv>,
3837
#[allow(clippy::type_complexity)]
39-
pub(super) stages: TTLMap<StageKey, OnceCell<(SessionState, Arc<ExecutionStage>)>>,
38+
pub(super) stages: TTLMap<StageKey, OnceCell<TaskData>>,
4039
pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
4140
}
4241

0 commit comments

Comments (0)