Commit 40dbce6

do_get: use TTL map to store task state
The do_get call now evicts task state from the map after N calls, where N is the number of partitions in the task. This is an approximation: we don't track which unique partition ids have been served, so a retried partition can cause early eviction. The TTLMap also GCs orphaned task state after the configured TTL period.
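A minimal sketch of the eviction arithmetic (standalone, with illustrative names; in the commit the counter lives in TaskData next to the cached session state and stage):

use std::sync::atomic::{AtomicUsize, Ordering};

/// Returns true when the task state should be evicted. `remaining` starts at
/// N, the task's partition count. fetch_sub returns the PREVIOUS value, so the
/// call that observes 1 is approximately the last partition. A retried
/// partition decrements twice, hence the possible early eviction; the
/// TTL-based GC cleans up whatever the counter misses.
fn served_one_partition(remaining: &AtomicUsize) -> bool {
    remaining.fetch_sub(1, Ordering::SeqCst) <= 1
}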
1 parent 45eebb8 commit 40dbce6

4 files changed: +250 -22 lines changed


src/common/ttl_map.rs

Lines changed: 16 additions & 4 deletions
@@ -56,7 +56,7 @@ where
     /// new creates a new Bucket
     fn new<V>(data: Arc<DashMap<K, V>>) -> Self
     where
-        V: Default + Clone + Send + Sync + 'static,
+        V: Send + Sync + 'static,
     {
         // TODO: To avoid unbounded growth, consider making this bounded. Alternatively, we can
         // introduce a dynamic GC period to ensure that GC can keep up.
@@ -84,7 +84,7 @@ where
     /// task is responsible for managing a subset of keys in the map.
     async fn task<V>(mut rx: UnboundedReceiver<BucketOp<K>>, data: Arc<DashMap<K, V>>)
     where
-        V: Default + Clone + Send + Sync + 'static,
+        V: Send + Sync + 'static,
     {
         let mut shard = HashSet::new();
         while let Some(op) = rx.recv().await {
@@ -207,7 +207,7 @@

     /// get_or_default executes the provided closure with a reference to the map entry for the given key.
     /// If the key does not exist, it inserts a new entry with the default value.
-    pub fn get_or_init<F>(&self, key: K, f: F) -> V
+    pub fn get_or_init<F>(&self, key: K, init: F) -> V
     where
         F: FnOnce() -> V,
     {
@@ -218,7 +218,7 @@

         let value = match self.data.entry(key.clone()) {
             Entry::Vacant(entry) => {
-                let value = f();
+                let value = init();
                 entry.insert(value.clone());
                 new_entry = true;
                 value
@@ -257,6 +257,18 @@
         self.data.remove(&key);
     }

+    /// Returns the number of entries currently stored in the map
+    #[cfg(test)]
+    pub fn len(&self) -> usize {
+        self.data.len()
+    }
+
+    /// Returns an iterator over the keys currently stored in the map
+    #[cfg(test)]
+    pub fn keys(&self) -> impl Iterator<Item = K> + '_ {
+        self.data.iter().map(|entry| entry.key().clone())
+    }
+
     /// run_gc_loop will continuously clear expired entries from the map, checking every `period`. The
     /// function terminates if `shutdown` is signalled.
     async fn run_gc_loop(time: Arc<AtomicU64>, period: Duration, buckets: &[Bucket<K>]) {
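For context, a short usage sketch of the renamed get_or_init (hypothetical key and value types, and assuming a tokio runtime is available; TTLMap::try_new and TTLMapConfig appear in the service.rs diff below). The init closure runs only when the key is vacant; otherwise the stored value is cloned and returned:

#[tokio::test]
async fn get_or_init_inserts_once() {
    let map: TTLMap<String, u64> =
        TTLMap::try_new(TTLMapConfig::default()).expect("failed to create TTLMap");
    // Vacant key: the closure runs and 42 is stored.
    let v = map.get_or_init("query-1".to_string(), || 42);
    // Occupied key: the closure is ignored and the stored 42 is returned.
    assert_eq!(map.get_or_init("query-1".to_string(), || 7), v);
}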

src/flight_service/do_get.rs

Lines changed: 222 additions & 9 deletions
@@ -14,7 +14,9 @@ use arrow_flight::Ticket;
 use datafusion::execution::SessionState;
 use futures::TryStreamExt;
 use prost::Message;
+use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
+use tokio::sync::OnceCell;
 use tonic::metadata::MetadataMap;
 use tonic::{Request, Response, Status};
@@ -37,6 +39,19 @@ pub struct DoGet {
     pub stage_key: Option<StageKey>,
 }

+#[derive(Clone, Debug)]
+/// TaskData stores state for a single task being executed by this Endpoint. It may be shared
+/// by concurrent requests for the same task which execute separate partitions.
+pub struct TaskData {
+    pub(super) state: SessionState,
+    pub(super) stage: Arc<ExecutionStage>,
+    /// Initialized to the total number of partitions in the task (not only the partitions in
+    /// the partition group). This is decremented for each request to the endpoint for this
+    /// task. Once this count is zero, the task is likely complete. This count does not account
+    /// for retried requests for the same partition.
+    approx_partitions_remaining: Arc<AtomicUsize>,
+}
+
 impl ArrowFlightEndpoint {
     pub(super) async fn get(
         &self,
@@ -50,7 +65,10 @@

         let partition = doget.partition as usize;
         let task_number = doget.task_number as usize;
-        let (mut state, stage) = self.get_state_and_stage(doget, metadata).await?;
+        let task_data = self.get_state_and_stage(doget, metadata).await?;
+
+        let stage = task_data.stage;
+        let mut state = task_data.state;

         // find out which partition group we are executing
         let task = stage
@@ -90,16 +108,15 @@
         &self,
         doget: DoGet,
         metadata_map: MetadataMap,
-    ) -> Result<(SessionState, Arc<ExecutionStage>), Status> {
+    ) -> Result<TaskData, Status> {
         let key = doget
             .stage_key
             .ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
-        let once_stage = {
-            let entry = self.stages.entry(key).or_default();
-            Arc::clone(&entry)
-        };
+        let once_stage = self
+            .stages
+            .get_or_init(key.clone(), || Arc::new(OnceCell::<TaskData>::new()));

-        let (state, stage) = once_stage
+        let stage_data = once_stage
             .get_or_try_init(|| async {
                 let stage_proto = doget
                     .stage_proto
@@ -134,10 +151,206 @@
                 config.set_extension(stage.clone());
                 config.set_extension(Arc::new(ContextGrpcMetadata(headers)));

-                Ok::<_, Status>((state, stage))
+                // Initialize partition count to the number of partitions in the stage
+                let total_partitions = stage.plan.properties().partitioning.partition_count();
+                Ok::<_, Status>(TaskData {
+                    state,
+                    stage,
+                    approx_partitions_remaining: Arc::new(AtomicUsize::new(total_partitions)),
+                })
             })
             .await?;

-        Ok((state.clone(), stage.clone()))
+        // If all the partitions are done, remove the stage from the cache.
+        let remaining_partitions = stage_data
+            .approx_partitions_remaining
+            .fetch_sub(1, Ordering::SeqCst);
+        if remaining_partitions <= 1 {
+            self.stages.remove(key.clone());
+        }
+
+        Ok(stage_data.clone())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow_flight::Ticket;
+    use prost::{bytes::Bytes, Message};
+    use uuid::Uuid;
+
+    #[tokio::test]
+    async fn test_task_data_partition_counting() {
+        use crate::task::ExecutionTask;
+        use arrow_flight::Ticket;
+        use prost::{bytes::Bytes, Message};
+        use tonic::Request;
+        use url::Url;
+
+        // Create a mock channel resolver for ArrowFlightEndpoint
+        #[derive(Clone)]
+        struct MockChannelResolver;
+
+        #[async_trait::async_trait]
+        impl crate::ChannelResolver for MockChannelResolver {
+            fn get_urls(&self) -> Result<Vec<Url>, datafusion::error::DataFusionError> {
+                Ok(vec![])
+            }
+
+            async fn get_channel_for_url(
+                &self,
+                _url: &Url,
+            ) -> Result<crate::BoxCloneSyncChannel, datafusion::error::DataFusionError>
+            {
+                Err(datafusion::error::DataFusionError::NotImplemented(
+                    "Mock resolver".to_string(),
+                ))
+            }
+        }
+
+        // Create ArrowFlightEndpoint
+        let endpoint =
+            ArrowFlightEndpoint::new(MockChannelResolver).expect("Failed to create endpoint");
+
+        // Create 3 tasks with 3 partitions each.
+        let num_tasks = 3;
+        let num_partitions_per_task = 3;
+        let stage_id = 1;
+        let query_id_uuid = Uuid::new_v4();
+        let query_id = query_id_uuid.as_bytes().to_vec();
+
+        // Set up protos.
+        let mut tasks = Vec::new();
+        for i in 0..num_tasks {
+            tasks.push(ExecutionTask {
+                url_str: None,
+                partition_group: vec![i], // Set a random partition in the partition group.
+            });
+        }
+
+        let stage_proto = ExecutionStageProto {
+            query_id: query_id.clone(),
+            num: 1,
+            name: format!("test_stage_{}", 1),
+            plan: Some(Box::new(create_mock_physical_plan_proto(
+                num_partitions_per_task,
+            ))),
+            inputs: vec![],
+            tasks,
+        };
+
+        let task_keys = vec![
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 0,
+            },
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 1,
+            },
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 2,
+            },
+        ];
+
+        let stage_proto_for_closure = stage_proto.clone();
+        let endpoint_ref = &endpoint;
+        let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| {
+            let stage_proto = stage_proto_for_closure.clone();
+            // Create DoGet message
+            let doget = DoGet {
+                stage_proto: Some(stage_proto),
+                task_number,
+                partition,
+                stage_key: Some(stage_key),
+            };
+
+            // Create Flight ticket
+            let ticket = Ticket {
+                ticket: Bytes::from(doget.encode_to_vec()),
+            };
+
+            // Call the actual get() method
+            let request = Request::new(ticket);
+            endpoint_ref.get(request).await
+        };
+
+        // For each task, call do_get() for each partition except the last.
+        for task_number in 0..num_tasks {
+            for partition in 0..num_partitions_per_task - 1 {
+                let result = do_get(
+                    partition as u64,
+                    task_number,
+                    task_keys[task_number as usize].clone(),
+                )
+                .await;
+                assert!(result.is_ok());
+            }
+        }
+
+        // Check that the endpoint has not evicted any task states.
+        assert_eq!(endpoint.stages.len(), num_tasks as usize);
+
+        // Run the last partition of task 0. Any partition number works. Verify that the task state
+        // is evicted because all partitions have been processed.
+        let result = do_get(1, 0, task_keys[0].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 2);
+        assert!(stored_stage_keys.contains(&task_keys[1]));
+        assert!(stored_stage_keys.contains(&task_keys[2]));
+
+        // Run the last partition of task 1.
+        let result = do_get(1, 1, task_keys[1].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 1);
+        assert!(stored_stage_keys.contains(&task_keys[2]));
+
+        // Run the last partition of the last task.
+        let result = do_get(1, 2, task_keys[2].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 0);
+    }
+
+    // Helper to create a mock physical plan proto
+    fn create_mock_physical_plan_proto(
+        partitions: usize,
+    ) -> datafusion_proto::protobuf::PhysicalPlanNode {
+        use datafusion_proto::protobuf::partitioning::PartitionMethod;
+        use datafusion_proto::protobuf::{
+            Partitioning, PhysicalPlanNode, RepartitionExecNode, Schema,
+        };
+
+        // Create a repartition node that will have the desired partition count
+        PhysicalPlanNode {
+            physical_plan_type: Some(
+                datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Repartition(
+                    Box::new(RepartitionExecNode {
+                        input: Some(Box::new(PhysicalPlanNode {
+                            physical_plan_type: Some(
+                                datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Empty(
+                                    datafusion_proto::protobuf::EmptyExecNode {
+                                        schema: Some(Schema {
+                                            columns: vec![],
+                                            metadata: std::collections::HashMap::new(),
+                                        })
+                                    }
+                                )
+                            ),
+                        })),
+                        partitioning: Some(Partitioning {
+                            partition_method: Some(PartitionMethod::RoundRobin(partitions as u64)),
+                        }),
+                    })
+                )
+            ),
+        }
     }
 }
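The tokio::sync::OnceCell stored in the map is what lets concurrent do_get calls for different partitions of the same task share one TaskData: the first caller runs the async initializer, and later callers await it and reuse the result. A standalone sketch of that pattern (illustrative types only, not the endpoint code itself):

use std::sync::Arc;
use tokio::sync::OnceCell;

async fn shared_state(cell: Arc<OnceCell<String>>) -> Result<String, tonic::Status> {
    let value = cell
        .get_or_try_init(|| async {
            // This one-time setup runs at most once per cell, even under
            // concurrent callers; everyone else awaits the same result.
            Ok::<_, tonic::Status>("decoded task state".to_string())
        })
        .await?;
    Ok(value.clone())
}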

src/flight_service/service.rs

Lines changed: 11 additions & 8 deletions
@@ -1,17 +1,17 @@
 use crate::channel_manager::ChannelManager;
+use crate::common::ttl_map::{TTLMap, TTLMapConfig};
+use crate::flight_service::do_get::TaskData;
 use crate::flight_service::session_builder::DefaultSessionBuilder;
 use crate::flight_service::DistributedSessionBuilder;
-use crate::stage::ExecutionStage;
 use crate::ChannelResolver;
 use arrow_flight::flight_service_server::FlightService;
 use arrow_flight::{
     Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
     HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
 };
 use async_trait::async_trait;
-use dashmap::DashMap;
+use datafusion::error::DataFusionError;
 use datafusion::execution::runtime_env::RuntimeEnv;
-use datafusion::execution::SessionState;
 use futures::stream::BoxStream;
 use std::sync::Arc;
 use tokio::sync::OnceCell;
@@ -35,18 +35,21 @@
     pub(super) channel_manager: Arc<ChannelManager>,
     pub(super) runtime: Arc<RuntimeEnv>,
     #[allow(clippy::type_complexity)]
-    pub(super) stages: DashMap<StageKey, Arc<OnceCell<(SessionState, Arc<ExecutionStage>)>>>,
+    pub(super) stages: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
     pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
 }

 impl ArrowFlightEndpoint {
-    pub fn new(channel_resolver: impl ChannelResolver + Send + Sync + 'static) -> Self {
-        Self {
+    pub fn new(
+        channel_resolver: impl ChannelResolver + Send + Sync + 'static,
+    ) -> Result<Self, DataFusionError> {
+        let ttl_map = TTLMap::try_new(TTLMapConfig::default())?;
+        Ok(Self {
             channel_manager: Arc::new(ChannelManager::new(channel_resolver)),
             runtime: Arc::new(RuntimeEnv::default()),
-            stages: DashMap::new(),
+            stages: ttl_map,
             session_builder: Arc::new(DefaultSessionBuilder),
-        }
+        })
     }

     pub fn with_session_builder(

src/test_utils/localhost.rs

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ pub async fn spawn_flight_service(
     session_builder: impl DistributedSessionBuilder + Send + Sync + 'static,
     incoming: TcpListener,
 ) -> Result<(), Box<dyn Error + Send + Sync>> {
-    let mut endpoint = ArrowFlightEndpoint::new(channel_resolver);
+    let mut endpoint = ArrowFlightEndpoint::new(channel_resolver)?;
     endpoint.with_session_builder(session_builder);

     let incoming = tokio_stream::wrappers::TcpListenerStream::new(incoming);
