Commit 42b4426

do_get: use TTL map to store task state
The do_get call now evicts task state from the map after N calls, where N is the number of partitions in the task. This is an approximation: we don't check that each call carries a distinct partition id, so retries can cause early eviction. The TTLMap will also GC orphaned task state after the configured TTL period.
1 parent 45eebb8 commit 42b4426
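
The eviction rule is a per-task countdown: each do_get decrements a shared counter initialized to the task's partition count, and the decrement that observes 1 removes the entry. Below is a minimal standalone sketch of that pattern; the TaskStates/record_call names are illustrative, not part of this crate.

use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Illustrative stand-in for the endpoint's per-task state map.
struct TaskStates {
    entries: HashMap<u64, Arc<AtomicUsize>>,
}

impl TaskStates {
    /// Called once per do_get request for `task_id`; `partitions` is the
    /// task's total partition count. Returns true if the entry was evicted.
    fn record_call(&mut self, task_id: u64, partitions: usize) -> bool {
        let remaining = self
            .entries
            .entry(task_id)
            .or_insert_with(|| Arc::new(AtomicUsize::new(partitions)))
            .clone();
        // fetch_sub returns the previous value, so seeing 1 means this is the
        // N-th call. A retried partition decrements twice, which is why the
        // commit message calls this an approximation that can evict early.
        if remaining.fetch_sub(1, Ordering::SeqCst) <= 1 {
            self.entries.remove(&task_id);
            return true;
        }
        false
    }
}

State whose counter never reaches zero (for example, a query abandoned mid-stage) is the orphaned case that the TTL-based GC cleans up.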

File tree

4 files changed (+249, −22 lines)

src/common/ttl_map.rs

Lines changed: 16 additions & 4 deletions
@@ -56,7 +56,7 @@ where
     /// new creates a new Bucket
     fn new<V>(data: Arc<DashMap<K, V>>) -> Self
     where
-        V: Default + Clone + Send + Sync + 'static,
+        V: Send + Sync + 'static,
     {
         // TODO: To avoid unbounded growth, consider making this bounded. Alternatively, we can
         // introduce a dynamic GC period to ensure that GC can keep up.
@@ -84,7 +84,7 @@ where
     /// task is responsible for managing a subset of keys in the map.
     async fn task<V>(mut rx: UnboundedReceiver<BucketOp<K>>, data: Arc<DashMap<K, V>>)
     where
-        V: Default + Clone + Send + Sync + 'static,
+        V: Send + Sync + 'static,
     {
         let mut shard = HashSet::new();
         while let Some(op) = rx.recv().await {
@@ -207,7 +207,7 @@ where

     /// get_or_init returns the value stored for the given key. If the key does not
     /// exist, it inserts and returns a new entry produced by the `init` closure.
-    pub fn get_or_init<F>(&self, key: K, f: F) -> V
+    pub fn get_or_init<F>(&self, key: K, init: F) -> V
     where
         F: FnOnce() -> V,
     {
@@ -218,7 +218,7 @@ where

         let value = match self.data.entry(key.clone()) {
             Entry::Vacant(entry) => {
-                let value = f();
+                let value = init();
                 entry.insert(value.clone());
                 new_entry = true;
                 value
@@ -257,6 +257,18 @@ where
         self.data.remove(&key);
     }

+    /// Returns the number of entries currently stored in the map
+    #[cfg(test)]
+    pub fn len(&self) -> usize {
+        self.data.len()
+    }
+
+    /// Returns an iterator over the keys currently stored in the map
+    #[cfg(test)]
+    pub fn keys(&self) -> impl Iterator<Item = K> + '_ {
+        self.data.iter().map(|entry| entry.key().clone())
+    }
+
     /// run_gc_loop will continuously clear expired entries from the map, checking every `period`. The
     /// function terminates if `shutdown` is signalled.
     async fn run_gc_loop(time: Arc<AtomicU64>, period: Duration, buckets: &[Bucket<K>]) {
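
For orientation, the get_or_init path above runs its closure only on a vacant entry and clones the stored value out, so cheap-to-clone values such as Arc<OnceCell<…>> are a natural fit. A minimal usage sketch, assuming the TTLMap API shown in this diff (the key and value types here are illustrative):

use std::sync::Arc;
use tokio::sync::OnceCell;

use crate::common::ttl_map::{TTLMap, TTLMapConfig};
use datafusion::error::DataFusionError;

fn example() -> Result<(), DataFusionError> {
    // A map whose entries are GC'd after the configured TTL.
    let map: TTLMap<String, Arc<OnceCell<u64>>> = TTLMap::try_new(TTLMapConfig::default())?;

    // Vacant key: `init` runs and its result is stored. Occupied key (within
    // the TTL): the stored Arc is cloned out and `init` is never called.
    let cell = map.get_or_init("stage-0".to_string(), || Arc::new(OnceCell::new()));
    assert!(cell.get().is_none());
    Ok(())
}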

src/flight_service/do_get.rs

Lines changed: 221 additions & 9 deletions
@@ -14,7 +14,9 @@ use arrow_flight::Ticket;
 use datafusion::execution::SessionState;
 use futures::TryStreamExt;
 use prost::Message;
+use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
+use tokio::sync::OnceCell;
 use tonic::metadata::MetadataMap;
 use tonic::{Request, Response, Status};
@@ -37,6 +39,18 @@ pub struct DoGet {
     pub stage_key: Option<StageKey>,
 }

+#[derive(Clone, Debug)]
+/// TaskData stores state for a single task being executed by this Endpoint. It may be shared
+/// by concurrent requests for the same task which execute separate partitions.
+pub struct TaskData {
+    pub(super) state: SessionState,
+    pub(super) stage: Arc<ExecutionStage>,
+    /// Initialized to the total number of partitions in the task (not only the partitions in
+    /// the partition group). This is decremented for each request to the endpoint for this
+    /// task. Once this count is zero, the task is likely complete.
+    approx_partitions_remaining: Arc<AtomicUsize>,
+}
+
 impl ArrowFlightEndpoint {
     pub(super) async fn get(
         &self,
@@ -50,7 +64,10 @@ impl ArrowFlightEndpoint {

         let partition = doget.partition as usize;
         let task_number = doget.task_number as usize;
-        let (mut state, stage) = self.get_state_and_stage(doget, metadata).await?;
+        let task_data = self.get_state_and_stage(doget, metadata).await?;
+
+        let stage = task_data.stage;
+        let mut state = task_data.state;

         // find out which partition group we are executing
         let task = stage
@@ -90,16 +107,15 @@ impl ArrowFlightEndpoint {
         &self,
         doget: DoGet,
         metadata_map: MetadataMap,
-    ) -> Result<(SessionState, Arc<ExecutionStage>), Status> {
+    ) -> Result<TaskData, Status> {
         let key = doget
             .stage_key
             .ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
-        let once_stage = {
-            let entry = self.stages.entry(key).or_default();
-            Arc::clone(&entry)
-        };
+        let once_stage = self
+            .stages
+            .get_or_init(key.clone(), || Arc::new(OnceCell::<TaskData>::new()));

-        let (state, stage) = once_stage
+        let stage_data = once_stage
             .get_or_try_init(|| async {
                 let stage_proto = doget
                     .stage_proto
@@ -134,10 +150,206 @@ impl ArrowFlightEndpoint {
                 config.set_extension(stage.clone());
                 config.set_extension(Arc::new(ContextGrpcMetadata(headers)));

-                Ok::<_, Status>((state, stage))
+                // Initialize partition count to the number of partitions in the stage
+                let total_partitions = stage.plan.properties().partitioning.partition_count();
+                Ok::<_, Status>(TaskData {
+                    state,
+                    stage,
+                    approx_partitions_remaining: Arc::new(AtomicUsize::new(total_partitions)),
+                })
             })
             .await?;

-        Ok((state.clone(), stage.clone()))
+        // If all the partitions are done, remove the stage from the cache.
+        let remaining_partitions = stage_data
+            .approx_partitions_remaining
+            .fetch_sub(1, Ordering::SeqCst);
+        if remaining_partitions <= 1 {
+            self.stages.remove(key.clone());
+        }
+
+        Ok(stage_data.clone())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow_flight::Ticket;
+    use prost::{bytes::Bytes, Message};
+    use uuid::Uuid;
+
+    #[tokio::test]
+    async fn test_task_data_partition_counting() {
+        use crate::task::ExecutionTask;
+        use arrow_flight::Ticket;
+        use prost::{bytes::Bytes, Message};
+        use tonic::Request;
+        use url::Url;
+
+        // Create a mock channel resolver for ArrowFlightEndpoint
+        #[derive(Clone)]
+        struct MockChannelResolver;
+
+        #[async_trait::async_trait]
+        impl crate::ChannelResolver for MockChannelResolver {
+            fn get_urls(&self) -> Result<Vec<Url>, datafusion::error::DataFusionError> {
+                Ok(vec![])
+            }
+
+            async fn get_channel_for_url(
+                &self,
+                _url: &Url,
+            ) -> Result<crate::BoxCloneSyncChannel, datafusion::error::DataFusionError> {
+                Err(datafusion::error::DataFusionError::NotImplemented(
+                    "Mock resolver".to_string(),
+                ))
+            }
+        }
+
+        // Create ArrowFlightEndpoint
+        let endpoint =
+            ArrowFlightEndpoint::new(MockChannelResolver).expect("Failed to create endpoint");
+
+        // Create 3 tasks with 3 partitions each.
+        let num_tasks = 3;
+        let num_partitions_per_task = 3;
+        let stage_id = 1;
+        let query_id_uuid = Uuid::new_v4();
+        let query_id = query_id_uuid.as_bytes().to_vec();
+
+        // Set up protos.
+        let mut tasks = Vec::new();
+        for i in 0..num_tasks {
+            tasks.push(ExecutionTask {
+                url_str: None,
+                partition_group: vec![i], // Assign one partition to each task's partition group.
+            });
+        }
+
+        let stage_proto = ExecutionStageProto {
+            query_id: query_id.clone(),
+            num: 1,
+            name: format!("test_stage_{}", 1),
+            plan: Some(Box::new(create_mock_physical_plan_proto(
+                num_partitions_per_task,
+            ))),
+            inputs: vec![],
+            tasks,
+        };
+
+        let task_keys = vec![
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 0,
+            },
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 1,
+            },
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 2,
+            },
+        ];
+
+        let stage_proto_for_closure = stage_proto.clone();
+        let endpoint_ref = &endpoint;
+        let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| {
+            let stage_proto = stage_proto_for_closure.clone();
+            // Create DoGet message
+            let doget = DoGet {
+                stage_proto: Some(stage_proto),
+                task_number,
+                partition,
+                stage_key: Some(stage_key),
+            };
+
+            // Create Flight ticket
+            let ticket = Ticket {
+                ticket: Bytes::from(doget.encode_to_vec()),
+            };
+
+            // Call the actual get() method
+            let request = Request::new(ticket);
+            endpoint_ref.get(request).await
+        };
+
+        // For each task, call do_get() for each partition except the last.
+        for task_number in 0..num_tasks {
+            for partition in 0..num_partitions_per_task - 1 {
+                let result = do_get(
+                    partition as u64,
+                    task_number,
+                    task_keys[task_number as usize].clone(),
+                )
+                .await;
+                assert!(result.is_ok());
+            }
+        }
+
+        // Check that the endpoint has not evicted any task states.
+        assert_eq!(endpoint.stages.len(), num_tasks as usize);
+
+        // Run the last partition of task 0. Any partition number works. Verify that the task state
+        // is evicted because all partitions have been processed.
+        let result = do_get(1, 0, task_keys[0].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 2);
+        assert!(stored_stage_keys.contains(&task_keys[1]));
+        assert!(stored_stage_keys.contains(&task_keys[2]));
+
+        // Run the last partition of task 1.
+        let result = do_get(1, 1, task_keys[1].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 1);
+        assert!(stored_stage_keys.contains(&task_keys[2]));
+
+        // Run the last partition of the last task.
+        let result = do_get(1, 2, task_keys[2].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 0);
+    }
+
+    // Helper to create a mock physical plan proto
+    fn create_mock_physical_plan_proto(
+        partitions: usize,
+    ) -> datafusion_proto::protobuf::PhysicalPlanNode {
+        use datafusion_proto::protobuf::partitioning::PartitionMethod;
+        use datafusion_proto::protobuf::{
+            Partitioning, PhysicalPlanNode, RepartitionExecNode, Schema,
+        };
+
+        // Create a repartition node that will have the desired partition count
+        PhysicalPlanNode {
+            physical_plan_type: Some(
+                datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Repartition(
+                    Box::new(RepartitionExecNode {
+                        input: Some(Box::new(PhysicalPlanNode {
+                            physical_plan_type: Some(
+                                datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Empty(
+                                    datafusion_proto::protobuf::EmptyExecNode {
+                                        schema: Some(Schema {
+                                            columns: vec![],
+                                            metadata: std::collections::HashMap::new(),
+                                        }),
+                                    },
+                                ),
+                            ),
+                        })),
+                        partitioning: Some(Partitioning {
+                            partition_method: Some(PartitionMethod::RoundRobin(partitions as u64)),
+                        }),
+                    }),
+                ),
+            ),
+        }
     }
 }

src/flight_service/service.rs

Lines changed: 11 additions & 8 deletions
@@ -1,17 +1,17 @@
 use crate::channel_manager::ChannelManager;
+use crate::common::ttl_map::{TTLMap, TTLMapConfig};
+use crate::flight_service::do_get::TaskData;
 use crate::flight_service::session_builder::DefaultSessionBuilder;
 use crate::flight_service::DistributedSessionBuilder;
-use crate::stage::ExecutionStage;
 use crate::ChannelResolver;
 use arrow_flight::flight_service_server::FlightService;
 use arrow_flight::{
     Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
     HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
 };
 use async_trait::async_trait;
-use dashmap::DashMap;
+use datafusion::error::DataFusionError;
 use datafusion::execution::runtime_env::RuntimeEnv;
-use datafusion::execution::SessionState;
 use futures::stream::BoxStream;
 use std::sync::Arc;
 use tokio::sync::OnceCell;
@@ -35,18 +35,21 @@ pub struct ArrowFlightEndpoint {
     pub(super) channel_manager: Arc<ChannelManager>,
     pub(super) runtime: Arc<RuntimeEnv>,
     #[allow(clippy::type_complexity)]
-    pub(super) stages: DashMap<StageKey, Arc<OnceCell<(SessionState, Arc<ExecutionStage>)>>>,
+    pub(super) stages: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
     pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
 }

 impl ArrowFlightEndpoint {
-    pub fn new(channel_resolver: impl ChannelResolver + Send + Sync + 'static) -> Self {
-        Self {
+    pub fn new(
+        channel_resolver: impl ChannelResolver + Send + Sync + 'static,
+    ) -> Result<Self, DataFusionError> {
+        let ttl_map = TTLMap::try_new(TTLMapConfig::default())?;
+        Ok(Self {
             channel_manager: Arc::new(ChannelManager::new(channel_resolver)),
             runtime: Arc::new(RuntimeEnv::default()),
-            stages: DashMap::new(),
+            stages: ttl_map,
             session_builder: Arc::new(DefaultSessionBuilder),
-        }
+        })
     }

     pub fn with_session_builder(

src/test_utils/localhost.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ pub async fn spawn_flight_service(
106106
session_builder: impl DistributedSessionBuilder + Send + Sync + 'static,
107107
incoming: TcpListener,
108108
) -> Result<(), Box<dyn Error + Send + Sync>> {
109-
let mut endpoint = ArrowFlightEndpoint::new(channel_resolver);
109+
let mut endpoint = ArrowFlightEndpoint::new(channel_resolver)?;
110110
endpoint.with_session_builder(session_builder);
111111

112112
let incoming = tokio_stream::wrappers::TcpListenerStream::new(incoming);
