Commit acebefc

do_get: use TTLMap in ArrowFlightEndpoint instead of DashMap
1 parent 3a8d1d7 commit acebefc

4 files changed: +241 -22 lines

src/common/ttl_map.rs

Lines changed: 16 additions & 4 deletions
@@ -56,7 +56,7 @@ where
     /// new creates a new Bucket
     fn new<V>(data: Arc<DashMap<K, V>>) -> Self
     where
-        V: Default + Clone + Send + Sync + 'static,
+        V: Send + Sync + 'static,
     {
         // TODO: To avoid unbounded growth, consider making this bounded. Alternatively, we can
         // introduce a dynamic GC period to ensure that GC can keep up.
@@ -84,7 +84,7 @@ where
     /// task is responsible for managing a subset of keys in the map.
     async fn task<V>(mut rx: UnboundedReceiver<BucketOp<K>>, data: Arc<DashMap<K, V>>)
     where
-        V: Default + Clone + Send + Sync + 'static,
+        V: Send + Sync + 'static,
     {
         let mut shard = HashSet::new();
         while let Some(op) = rx.recv().await {
@@ -207,7 +207,7 @@ where

     /// get_or_default executes the provided closure with a reference to the map entry for the given key.
     /// If the key does not exist, it inserts a new entry with the default value.
-    pub fn get_or_init<F>(&self, key: K, f: F) -> V
+    pub fn get_or_init<F>(&self, key: K, init: F) -> V
     where
         F: FnOnce() -> V,
     {
@@ -218,7 +218,7 @@ where

         let value = match self.data.entry(key.clone()) {
             Entry::Vacant(entry) => {
-                let value = f();
+                let value = init();
                 entry.insert(value.clone());
                 new_entry = true;
                 value
@@ -257,6 +257,18 @@ where
         self.data.remove(&key);
     }

+    /// Returns the number of entries currently stored in the map
+    #[cfg(test)]
+    pub fn len(&self) -> usize {
+        self.data.len()
+    }
+
+    /// Returns an iterator over the keys currently stored in the map
+    #[cfg(test)]
+    pub fn keys(&self) -> impl Iterator<Item = K> + '_ {
+        self.data.iter().map(|entry| entry.key().clone())
+    }
+
     /// run_gc_loop will continuously clear expired entries from the map, checking every `period`. The
     /// function terminates if `shutdown` is signalled.
     async fn run_gc_loop(time: Arc<AtomicU64>, period: Duration, buckets: &[Bucket<K>]) {
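
Context for these hunks: get_or_init looks up the entry for a key, runs the init closure only if the slot is vacant, and returns a clone of the stored value. A rough standalone sketch of that flow, assuming only the dashmap entry API (illustrative names, not the crate's actual code):

use dashmap::mapref::entry::Entry;
use dashmap::DashMap;

// Sketch: run `init` only for a vacant slot, store the result, and hand
// a clone back to the caller. The read path is what needs `V: Clone`.
fn get_or_init_sketch<K, V, F>(data: &DashMap<K, V>, key: K, init: F) -> V
where
    K: Eq + std::hash::Hash,
    V: Clone,
    F: FnOnce() -> V,
{
    match data.entry(key) {
        Entry::Vacant(slot) => {
            let value = init(); // executed at most once per vacant key
            slot.insert(value.clone());
            value
        }
        Entry::Occupied(slot) => slot.get().clone(),
    }
}

The relaxed bound in the first two hunks applies to the bucket's background task, which only tracks and removes keys, so it plausibly never needed Default + Clone. Dropping Default in particular lines up with the do_get change below, which replaces entry(key).or_default() with get_or_init.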

src/flight_service/do_get.rs

Lines changed: 213 additions & 9 deletions
@@ -14,7 +14,9 @@ use arrow_flight::Ticket;
 use datafusion::execution::SessionState;
 use futures::TryStreamExt;
 use prost::Message;
+use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
+use tokio::sync::OnceCell;
 use tonic::metadata::MetadataMap;
 use tonic::{Request, Response, Status};

@@ -37,6 +39,13 @@ pub struct DoGet {
     pub stage_key: Option<StageKey>,
 }

+#[derive(Clone, Debug)]
+pub struct TaskData {
+    pub(super) state: SessionState,
+    pub(super) stage: Arc<ExecutionStage>,
+    partition_count: Arc<AtomicUsize>,
+}
+
 impl ArrowFlightEndpoint {
     pub(super) async fn get(
         &self,
@@ -50,7 +59,10 @@ impl ArrowFlightEndpoint {

         let partition = doget.partition as usize;
         let task_number = doget.task_number as usize;
-        let (mut state, stage) = self.get_state_and_stage(doget, metadata).await?;
+        let task_data = self.get_state_and_stage(doget, metadata).await?;
+
+        let stage = task_data.stage;
+        let mut state = task_data.state;

         // find out which partition group we are executing
         let task = stage
@@ -90,16 +102,15 @@ impl ArrowFlightEndpoint {
         &self,
         doget: DoGet,
         metadata_map: MetadataMap,
-    ) -> Result<(SessionState, Arc<ExecutionStage>), Status> {
+    ) -> Result<TaskData, Status> {
         let key = doget
             .stage_key
             .ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
-        let once_stage = {
-            let entry = self.stages.entry(key).or_default();
-            Arc::clone(&entry)
-        };
+        let once_stage = self
+            .stages
+            .get_or_init(key.clone(), || Arc::new(OnceCell::<TaskData>::new()));

-        let (state, stage) = once_stage
+        let stage_data = once_stage
             .get_or_try_init(|| async {
                 let stage_proto = doget
                     .stage_proto
@@ -134,10 +145,203 @@ impl ArrowFlightEndpoint {
                 config.set_extension(stage.clone());
                 config.set_extension(Arc::new(ContextGrpcMetadata(headers)));

-                Ok::<_, Status>((state, stage))
+                // Initialize partition count to the number of partitions in the stage
+                let total_partitions = stage.plan.properties().partitioning.partition_count();
+                Ok::<_, Status>(TaskData {
+                    state,
+                    stage,
+                    partition_count: Arc::new(AtomicUsize::new(total_partitions)),
+                })
             })
             .await?;

-        Ok((state.clone(), stage.clone()))
+        let remaining_partitions = stage_data.partition_count.fetch_sub(1, Ordering::SeqCst);
+        if remaining_partitions <= 1 {
+            self.stages.remove(key.clone());
+        }
+
+        Ok(stage_data.clone())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow_flight::Ticket;
+    use prost::{bytes::Bytes, Message};
+    use uuid::Uuid;
+
+    #[tokio::test]
+    async fn test_task_data_partition_counting() {
+        use crate::task::ExecutionTask;
+        use arrow_flight::Ticket;
+        use prost::{bytes::Bytes, Message};
+        use tonic::Request;
+        use url::Url;
+
+        // Create a mock channel resolver for ArrowFlightEndpoint
+        #[derive(Clone)]
+        struct MockChannelResolver;
+
+        #[async_trait::async_trait]
+        impl crate::ChannelResolver for MockChannelResolver {
+            fn get_urls(&self) -> Result<Vec<Url>, datafusion::error::DataFusionError> {
+                Ok(vec![])
+            }
+
+            async fn get_channel_for_url(
+                &self,
+                _url: &Url,
+            ) -> Result<crate::BoxCloneSyncChannel, datafusion::error::DataFusionError>
+            {
+                Err(datafusion::error::DataFusionError::NotImplemented(
+                    "Mock resolver".to_string(),
+                ))
+            }
+        }
+
+        // Create ArrowFlightEndpoint
+        let endpoint =
+            ArrowFlightEndpoint::new(MockChannelResolver).expect("Failed to create endpoint");
+
+        // Test parameters: one stage with 3 tasks; each task has 3 partitions
+        let num_tasks = 3;
+        let num_partitions_per_task = 3;
+        let stage_id = 1;
+        let query_id_uuid = Uuid::new_v4();
+        let query_id = query_id_uuid.as_bytes().to_vec();
+
+        // Create ExecutionStage proto with multiple tasks
+        let mut tasks = Vec::new();
+        for i in 0..num_tasks {
+            tasks.push(ExecutionTask {
+                url_str: None,
+                partition_group: vec![i], // Different partition group for each task
+            });
+        }
+
+        let stage_proto = ExecutionStageProto {
+            query_id: query_id.clone(),
+            num: 1,
+            name: format!("test_stage_{}", 1),
+            plan: Some(Box::new(create_mock_physical_plan_proto(
+                num_partitions_per_task,
+            ))),
+            inputs: vec![],
+            tasks,
+        };
+
+        let task_keys = vec![
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 0,
+            },
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 1,
+            },
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 2,
+            },
+        ];
+
+        let stage_proto_for_closure = stage_proto.clone();
+        let endpoint_ref = &endpoint;
+        let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| {
+            let stage_proto = stage_proto_for_closure.clone();
+            // Create DoGet message
+            let doget = DoGet {
+                stage_proto: Some(stage_proto),
+                task_number,
+                partition,
+                stage_key: Some(stage_key),
+            };
+
+            // Create Flight ticket
+            let ticket = Ticket {
+                ticket: Bytes::from(doget.encode_to_vec()),
+            };
+
+            // Call the actual get() method
+            let request = Request::new(ticket);
+            endpoint_ref.get(request).await
+        };
+
+        // For each task, call do_get() for each partition except the last.
+        for task_number in 0..num_tasks {
+            for partition in 0..num_partitions_per_task - 1 {
+                let result = do_get(
+                    partition as u64,
+                    task_number,
+                    task_keys[task_number as usize].clone(),
+                )
+                .await;
+                assert!(result.is_ok());
+            }
+        }
+
+        // Check that the endpoint has not evicted any tasks.
+        assert_eq!(endpoint.stages.len(), num_tasks as usize);
+
+        // Run the last partition of task 0. Any partition number works. Verify that the task state
+        // is evicted because all partitions have been processed.
+        let result = do_get(1, 0, task_keys[0].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 2);
+        assert!(stored_stage_keys.contains(&task_keys[1]));
+        assert!(stored_stage_keys.contains(&task_keys[2]));
+
+        // Run the last partition of task 1.
+        let result = do_get(1, 1, task_keys[1].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 1);
+        assert!(stored_stage_keys.contains(&task_keys[2]));
+
+        // Run the last partition of the last task.
+        let result = do_get(1, 2, task_keys[2].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 0);
+    }
+
+    // Helper to create a mock physical plan proto
+    fn create_mock_physical_plan_proto(
+        partitions: usize,
+    ) -> datafusion_proto::protobuf::PhysicalPlanNode {
+        use datafusion_proto::protobuf::partitioning::PartitionMethod;
+        use datafusion_proto::protobuf::{
+            Partitioning, PhysicalPlanNode, RepartitionExecNode, Schema,
+        };
+
+        // Create a repartition node that will have the desired partition count
+        PhysicalPlanNode {
+            physical_plan_type: Some(
+                datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Repartition(
+                    Box::new(RepartitionExecNode {
+                        input: Some(Box::new(PhysicalPlanNode {
+                            physical_plan_type: Some(
+                                datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Empty(
+                                    datafusion_proto::protobuf::EmptyExecNode {
+                                        schema: Some(Schema {
+                                            columns: vec![],
+                                            metadata: std::collections::HashMap::new(),
+                                        })
+                                    }
+                                )
+                            ),
+                        })),
+                        partitioning: Some(Partitioning {
+                            partition_method: Some(PartitionMethod::RoundRobin(partitions as u64)),
+                        }),
+                    })
+                )
+            ),
+        }
     }
 }
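
The get_state_and_stage hunk above combines two concurrency devices: a per-key tokio OnceCell, so concurrent do_get calls for the same StageKey run the expensive stage initialization exactly once, and an AtomicUsize countdown seeded with the stage's partition count, so whichever caller serves the final partition evicts the map entry. A condensed, self-contained sketch of that combination (the Slot type and string payload are illustrative stand-ins, not the crate's types):

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::OnceCell;

// Stand-in for TaskData: built once per key, then shared by all callers.
#[derive(Clone)]
struct Slot {
    value: String,
    remaining: Arc<AtomicUsize>,
}

async fn get_slot(cell: &OnceCell<Slot>, total_partitions: usize) -> Result<Slot, std::io::Error> {
    let slot = cell
        .get_or_try_init(|| async {
            // Runs at most once; concurrent callers await the same future.
            Ok::<_, std::io::Error>(Slot {
                value: "decoded stage state".to_string(),
                remaining: Arc::new(AtomicUsize::new(total_partitions)),
            })
        })
        .await?;

    // fetch_sub returns the pre-decrement value, so observing 1 means this
    // caller is serving the final partition; in the endpoint, this is the
    // point where the TTLMap entry for the StageKey gets removed.
    if slot.remaining.fetch_sub(1, Ordering::SeqCst) <= 1 {
        // evict the map entry here
    }
    Ok(slot.clone())
}

Note that eviction is purely call-count based: the caller that observes 1 is simply the Nth and last call for the key, regardless of which partition number its request carried. That is why the test above can finish each task with an arbitrary partition number.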

src/flight_service/service.rs

Lines changed: 11 additions & 8 deletions
@@ -1,17 +1,17 @@
 use crate::channel_manager::ChannelManager;
+use crate::common::ttl_map::{TTLMap, TTLMapConfig};
+use crate::flight_service::do_get::TaskData;
 use crate::flight_service::session_builder::DefaultSessionBuilder;
 use crate::flight_service::DistributedSessionBuilder;
-use crate::stage::ExecutionStage;
 use crate::ChannelResolver;
 use arrow_flight::flight_service_server::FlightService;
 use arrow_flight::{
     Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
     HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
 };
 use async_trait::async_trait;
-use dashmap::DashMap;
+use datafusion::error::DataFusionError;
 use datafusion::execution::runtime_env::RuntimeEnv;
-use datafusion::execution::SessionState;
 use futures::stream::BoxStream;
 use std::sync::Arc;
 use tokio::sync::OnceCell;
@@ -35,18 +35,21 @@ pub struct ArrowFlightEndpoint {
     pub(super) channel_manager: Arc<ChannelManager>,
     pub(super) runtime: Arc<RuntimeEnv>,
     #[allow(clippy::type_complexity)]
-    pub(super) stages: DashMap<StageKey, Arc<OnceCell<(SessionState, Arc<ExecutionStage>)>>>,
+    pub(super) stages: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
     pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
 }

 impl ArrowFlightEndpoint {
-    pub fn new(channel_resolver: impl ChannelResolver + Send + Sync + 'static) -> Self {
-        Self {
+    pub fn new(
+        channel_resolver: impl ChannelResolver + Send + Sync + 'static,
+    ) -> Result<Self, DataFusionError> {
+        let ttl_map = TTLMap::try_new(TTLMapConfig::default())?;
+        Ok(Self {
             channel_manager: Arc::new(ChannelManager::new(channel_resolver)),
             runtime: Arc::new(RuntimeEnv::default()),
-            stages: DashMap::new(),
+            stages: ttl_map,
             session_builder: Arc::new(DefaultSessionBuilder),
-        }
+        })
     }

     pub fn with_session_builder(
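
The substance of the commit is this stages field swap: with a plain DashMap, a stage whose final partitions are never requested (a crashed client, a cancelled query) would pin its entry, and its SessionState, indefinitely, since eviction in do_get only fires on the last partition. A TTL map instead ages such entries out in the background. The core idea, reduced to a self-contained toy (not the crate's implementation, which shards keys into buckets swept by a GC task):

use std::collections::HashMap;
use std::hash::Hash;
use std::time::{Duration, Instant};

// Toy TTL map: every entry remembers its insertion time, and a periodic
// gc() pass drops entries older than the TTL. TTLMap achieves the same
// effect with buckets and a background task instead of full scans.
struct TtlSketch<K, V> {
    ttl: Duration,
    entries: HashMap<K, (Instant, V)>,
}

impl<K: Hash + Eq, V> TtlSketch<K, V> {
    fn new(ttl: Duration) -> Self {
        Self { ttl, entries: HashMap::new() }
    }

    fn insert(&mut self, key: K, value: V) {
        self.entries.insert(key, (Instant::now(), value));
    }

    fn gc(&mut self) {
        let ttl = self.ttl;
        self.entries.retain(|_, (born, _)| born.elapsed() < ttl);
    }
}

This also explains the now-fallible constructor: TTLMap::try_new sets up that GC machinery, which presumably can fail, and ArrowFlightEndpoint::new surfaces the error as a DataFusionError instead of panicking.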

src/test_utils/localhost.rs

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ pub async fn spawn_flight_service(
     session_builder: impl DistributedSessionBuilder + Send + Sync + 'static,
     incoming: TcpListener,
 ) -> Result<(), Box<dyn Error + Send + Sync>> {
-    let mut endpoint = ArrowFlightEndpoint::new(channel_resolver);
+    let mut endpoint = ArrowFlightEndpoint::new(channel_resolver)?;
     endpoint.with_session_builder(session_builder);

     let incoming = tokio_stream::wrappers::TcpListenerStream::new(incoming);
