
Commit ed67281 (parent e773d74)

do_get: use TTL map to store task state

The do_get call now evicts task state from the map after N calls, where N is the number of partitions. This is an approximation: we don't track which unique partition ids have been served, so a retried partition can cause early eviction. The TTLMap will also GC orphaned task state after the configured TTL period.
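In outline, the scheme is a shared countdown per task. Below is a minimal sketch of the idea only — the `Stages`/`record_call` names are hypothetical; the real implementation in src/flight_service/do_get.rs further down keys a TTLMap by StageKey and stores the counter inside TaskData:

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Hypothetical stand-in for the endpoint's task-state map.
struct Stages {
    // key -> remaining expected do_get calls for that task
    map: HashMap<String, Arc<AtomicUsize>>,
}

impl Stages {
    /// Decrements the countdown for `key` and evicts the entry once N calls
    /// have been observed, where N is the partition count the counter was
    /// seeded with at insert time.
    fn record_call(&mut self, key: &str) {
        let Some(remaining) = self.map.get(key).cloned() else {
            return;
        };
        // fetch_sub returns the previous value, so `<= 1` means this call
        // was the last one expected for the task.
        if remaining.fetch_sub(1, Ordering::SeqCst) <= 1 {
            self.map.remove(key);
        }
    }
}
```

Because only the number of calls is counted, a retried partition burns an extra decrement and can trigger the early eviction described above; the TTL-based GC is the backstop for counters that never reach zero.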

File tree: 4 files changed (+241 −22 lines)


src/common/ttl_map.rs

Lines changed: 16 additions & 4 deletions
@@ -56,7 +56,7 @@ where
     /// new creates a new Bucket
     fn new<V>(data: Arc<DashMap<K, V>>) -> Self
     where
-        V: Default + Clone + Send + Sync + 'static,
+        V: Send + Sync + 'static,
     {
         // TODO: To avoid unbounded growth, consider making this bounded. Alternatively, we can
         // introduce a dynamic GC period to ensure that GC can keep up.
@@ -84,7 +84,7 @@
     /// task is responsible for managing a subset of keys in the map.
     async fn task<V>(mut rx: UnboundedReceiver<BucketOp<K>>, data: Arc<DashMap<K, V>>)
     where
-        V: Default + Clone + Send + Sync + 'static,
+        V: Send + Sync + 'static,
     {
         let mut shard = HashSet::new();
         while let Some(op) = rx.recv().await {
@@ -207,7 +207,7 @@
 
     /// get_or_default executes the provided closure with a reference to the map entry for the given key.
     /// If the key does not exist, it inserts a new entry with the default value.
-    pub fn get_or_init<F>(&self, key: K, f: F) -> V
+    pub fn get_or_init<F>(&self, key: K, init: F) -> V
     where
         F: FnOnce() -> V,
     {
@@ -218,7 +218,7 @@
 
         let value = match self.data.entry(key.clone()) {
             Entry::Vacant(entry) => {
-                let value = f();
+                let value = init();
                 entry.insert(value.clone());
                 new_entry = true;
                 value
@@ -257,6 +257,18 @@
         self.data.remove(&key);
     }
 
+    /// Returns the number of entries currently stored in the map
+    #[cfg(test)]
+    pub fn len(&self) -> usize {
+        self.data.len()
+    }
+
+    /// Returns an iterator over the keys currently stored in the map
+    #[cfg(test)]
+    pub fn keys(&self) -> impl Iterator<Item = K> + '_ {
+        self.data.iter().map(|entry| entry.key().clone())
+    }
+
     /// run_gc_loop will continuously clear expired entries from the map, checking every `period`. The
     /// function terminates if `shutdown` is signalled.
     async fn run_gc_loop(time: Arc<AtomicU64>, period: Duration, buckets: &[Bucket<K>]) {
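The `f` → `init` rename above makes the contract of `get_or_init` easier to read: the closure is a lazy initializer that runs only when the key is vacant, and otherwise the stored value is cloned and returned. A usage sketch based only on the signatures visible in this diff (the `String`/`u64` type parameters are illustrative, and a tokio runtime is assumed to be active since the map drives a background GC loop):

```rust
use crate::common::ttl_map::{TTLMap, TTLMapConfig};
use datafusion::error::DataFusionError;

fn example() -> Result<(), DataFusionError> {
    // Fallible constructor, as used by ArrowFlightEndpoint::new below.
    let map: TTLMap<String, u64> = TTLMap::try_new(TTLMapConfig::default())?;

    // The init closure runs because "stage-0" is vacant...
    let v = map.get_or_init("stage-0".to_string(), || 42);
    assert_eq!(v, 42);

    // ...and is skipped on the second call: the stored value wins.
    let v = map.get_or_init("stage-0".to_string(), || 7);
    assert_eq!(v, 42);
    Ok(())
}
```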

src/flight_service/do_get.rs

Lines changed: 213 additions & 9 deletions
@@ -14,7 +14,9 @@ use arrow_flight::Ticket;
 use datafusion::execution::SessionState;
 use futures::TryStreamExt;
 use prost::Message;
+use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
+use tokio::sync::OnceCell;
 use tonic::metadata::MetadataMap;
 use tonic::{Request, Response, Status};
 
@@ -37,6 +39,13 @@ pub struct DoGet {
     pub stage_key: Option<StageKey>,
 }
 
+#[derive(Clone, Debug)]
+pub struct TaskData {
+    pub(super) state: SessionState,
+    pub(super) stage: Arc<ExecutionStage>,
+    partition_count: Arc<AtomicUsize>,
+}
+
 impl ArrowFlightEndpoint {
     pub(super) async fn get(
         &self,
@@ -50,7 +59,10 @@
 
         let partition = doget.partition as usize;
         let task_number = doget.task_number as usize;
-        let (mut state, stage) = self.get_state_and_stage(doget, metadata).await?;
+        let task_data = self.get_state_and_stage(doget, metadata).await?;
+
+        let stage = task_data.stage;
+        let mut state = task_data.state;
 
         // find out which partition group we are executing
         let task = stage
@@ -90,16 +102,15 @@
         &self,
         doget: DoGet,
         metadata_map: MetadataMap,
-    ) -> Result<(SessionState, Arc<ExecutionStage>), Status> {
+    ) -> Result<TaskData, Status> {
         let key = doget
             .stage_key
            .ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
-        let once_stage = {
-            let entry = self.stages.entry(key).or_default();
-            Arc::clone(&entry)
-        };
+        let once_stage = self
+            .stages
+            .get_or_init(key.clone(), || Arc::new(OnceCell::<TaskData>::new()));
 
-        let (state, stage) = once_stage
+        let stage_data = once_stage
             .get_or_try_init(|| async {
                 let stage_proto = doget
                     .stage_proto
@@ -134,10 +145,203 @@
                 config.set_extension(stage.clone());
                 config.set_extension(Arc::new(ContextGrpcMetadata(headers)));
 
-                Ok::<_, Status>((state, stage))
+                // Initialize partition count to the number of partitions in the stage
+                let total_partitions = stage.plan.properties().partitioning.partition_count();
+                Ok::<_, Status>(TaskData {
+                    state,
+                    stage,
+                    partition_count: Arc::new(AtomicUsize::new(total_partitions)),
+                })
             })
             .await?;
 
-        Ok((state.clone(), stage.clone()))
+        let remaining_partitions = stage_data.partition_count.fetch_sub(1, Ordering::SeqCst);
+        if remaining_partitions <= 1 {
+            self.stages.remove(key.clone());
+        }
+
+        Ok(stage_data.clone())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow_flight::Ticket;
+    use prost::{bytes::Bytes, Message};
+    use uuid::Uuid;
+
+    #[tokio::test]
+    async fn test_task_data_partition_counting() {
+        use crate::task::ExecutionTask;
+        use arrow_flight::Ticket;
+        use prost::{bytes::Bytes, Message};
+        use tonic::Request;
+        use url::Url;
+
+        // Create a mock channel resolver for ArrowFlightEndpoint
+        #[derive(Clone)]
+        struct MockChannelResolver;
+
+        #[async_trait::async_trait]
+        impl crate::ChannelResolver for MockChannelResolver {
+            fn get_urls(&self) -> Result<Vec<Url>, datafusion::error::DataFusionError> {
+                Ok(vec![])
+            }
+
+            async fn get_channel_for_url(
+                &self,
+                _url: &Url,
+            ) -> Result<crate::BoxCloneSyncChannel, datafusion::error::DataFusionError> {
+                Err(datafusion::error::DataFusionError::NotImplemented(
+                    "Mock resolver".to_string(),
+                ))
+            }
+        }
+
+        // Create ArrowFlightEndpoint
+        let endpoint =
+            ArrowFlightEndpoint::new(MockChannelResolver).expect("Failed to create endpoint");
+
+        // Test parameters: one stage with 3 tasks, each task covering 3 partitions
+        let num_tasks = 3;
+        let num_partitions_per_task = 3;
+        let stage_id = 1;
+        let query_id_uuid = Uuid::new_v4();
+        let query_id = query_id_uuid.as_bytes().to_vec();
+
+        // Create ExecutionStage proto with multiple tasks
+        let mut tasks = Vec::new();
+        for i in 0..num_tasks {
+            tasks.push(ExecutionTask {
+                url_str: None,
+                partition_group: vec![i], // Different partition group for each task
+            });
+        }
+
+        let stage_proto = ExecutionStageProto {
+            query_id: query_id.clone(),
+            num: 1,
+            name: format!("test_stage_{}", 1),
+            plan: Some(Box::new(create_mock_physical_plan_proto(
+                num_partitions_per_task,
+            ))),
+            inputs: vec![],
+            tasks,
+        };
+
+        let task_keys = vec![
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 0,
+            },
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 1,
+            },
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 2,
+            },
+        ];
+
+        let stage_proto_for_closure = stage_proto.clone();
+        let endpoint_ref = &endpoint;
+        let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| {
+            let stage_proto = stage_proto_for_closure.clone();
+            // Create DoGet message
+            let doget = DoGet {
+                stage_proto: Some(stage_proto),
+                task_number,
+                partition,
+                stage_key: Some(stage_key),
+            };
+
+            // Create Flight ticket
+            let ticket = Ticket {
+                ticket: Bytes::from(doget.encode_to_vec()),
+            };
+
+            // Call the actual get() method
+            let request = Request::new(ticket);
+            endpoint_ref.get(request).await
+        };
+
+        // For each task, call do_get() for each partition except the last.
+        for task_number in 0..num_tasks {
+            for partition in 0..num_partitions_per_task - 1 {
+                let result = do_get(
+                    partition as u64,
+                    task_number,
+                    task_keys[task_number as usize].clone(),
+                )
+                .await;
+                assert!(result.is_ok());
+            }
+        }
+
+        // Check that the endpoint has not evicted any tasks.
+        assert_eq!(endpoint.stages.len(), num_tasks as usize);
+
+        // Run the last partition of task 0. Any partition number works. Verify that the task state
+        // is evicted because all partitions have been processed.
+        let result = do_get(1, 0, task_keys[0].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 2);
+        assert!(stored_stage_keys.contains(&task_keys[1]));
+        assert!(stored_stage_keys.contains(&task_keys[2]));
+
+        // Run the last partition of task 1.
+        let result = do_get(1, 1, task_keys[1].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 1);
+        assert!(stored_stage_keys.contains(&task_keys[2]));
+
+        // Run the last partition of the last task.
+        let result = do_get(1, 2, task_keys[2].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 0);
+    }
+
+    // Helper to create a mock physical plan proto
+    fn create_mock_physical_plan_proto(
+        partitions: usize,
+    ) -> datafusion_proto::protobuf::PhysicalPlanNode {
+        use datafusion_proto::protobuf::partitioning::PartitionMethod;
+        use datafusion_proto::protobuf::{
+            Partitioning, PhysicalPlanNode, RepartitionExecNode, Schema,
+        };
+
+        // Create a repartition node that will have the desired partition count
+        PhysicalPlanNode {
+            physical_plan_type: Some(
+                datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Repartition(
+                    Box::new(RepartitionExecNode {
+                        input: Some(Box::new(PhysicalPlanNode {
+                            physical_plan_type: Some(
+                                datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Empty(
+                                    datafusion_proto::protobuf::EmptyExecNode {
+                                        schema: Some(Schema {
+                                            columns: vec![],
+                                            metadata: std::collections::HashMap::new(),
+                                        })
+                                    }
+                                )
+                            ),
+                        })),
+                        partitioning: Some(Partitioning {
+                            partition_method: Some(PartitionMethod::RoundRobin(partitions as u64)),
+                        }),
+                    })
+                )
+            ),
+        }
     }
 }

src/flight_service/service.rs

Lines changed: 11 additions & 8 deletions
@@ -1,17 +1,17 @@
 use crate::channel_manager::ChannelManager;
+use crate::common::ttl_map::{TTLMap, TTLMapConfig};
+use crate::flight_service::do_get::TaskData;
 use crate::flight_service::session_builder::DefaultSessionBuilder;
 use crate::flight_service::DistributedSessionBuilder;
-use crate::stage::ExecutionStage;
 use crate::ChannelResolver;
 use arrow_flight::flight_service_server::FlightService;
 use arrow_flight::{
     Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
     HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
 };
 use async_trait::async_trait;
-use dashmap::DashMap;
+use datafusion::error::DataFusionError;
 use datafusion::execution::runtime_env::RuntimeEnv;
-use datafusion::execution::SessionState;
 use futures::stream::BoxStream;
 use std::sync::Arc;
 use tokio::sync::OnceCell;
@@ -35,18 +35,21 @@ pub struct ArrowFlightEndpoint {
     pub(super) channel_manager: Arc<ChannelManager>,
     pub(super) runtime: Arc<RuntimeEnv>,
     #[allow(clippy::type_complexity)]
-    pub(super) stages: DashMap<StageKey, Arc<OnceCell<(SessionState, Arc<ExecutionStage>)>>>,
+    pub(super) stages: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
     pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
 }
 
 impl ArrowFlightEndpoint {
-    pub fn new(channel_resolver: impl ChannelResolver + Send + Sync + 'static) -> Self {
-        Self {
+    pub fn new(
+        channel_resolver: impl ChannelResolver + Send + Sync + 'static,
+    ) -> Result<Self, DataFusionError> {
+        let ttl_map = TTLMap::try_new(TTLMapConfig::default())?;
+        Ok(Self {
             channel_manager: Arc::new(ChannelManager::new(channel_resolver)),
             runtime: Arc::new(RuntimeEnv::default()),
-            stages: DashMap::new(),
+            stages: ttl_map,
             session_builder: Arc::new(DefaultSessionBuilder),
-        }
+        })
     }
 
     pub fn with_session_builder(

src/test_utils/localhost.rs

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ pub async fn spawn_flight_service(
     session_builder: impl DistributedSessionBuilder + Send + Sync + 'static,
     incoming: TcpListener,
 ) -> Result<(), Box<dyn Error + Send + Sync>> {
-    let mut endpoint = ArrowFlightEndpoint::new(channel_resolver);
+    let mut endpoint = ArrowFlightEndpoint::new(channel_resolver)?;
     endpoint.with_session_builder(session_builder);
 
     let incoming = tokio_stream::wrappers::TcpListenerStream::new(incoming);
