diff --git a/src/common/ttl_map.rs b/src/common/ttl_map.rs index 568b1f5..4bebcd1 100644 --- a/src/common/ttl_map.rs +++ b/src/common/ttl_map.rs @@ -56,7 +56,7 @@ where /// new creates a new Bucket fn new(data: Arc>) -> Self where - V: Default + Clone + Send + Sync + 'static, + V: Send + Sync + 'static, { // TODO: To avoid unbounded growth, consider making this bounded. Alternatively, we can // introduce a dynamic GC period to ensure that GC can keep up. @@ -84,7 +84,7 @@ where /// task is responsible for managing a subset of keys in the map. async fn task(mut rx: UnboundedReceiver>, data: Arc>) where - V: Default + Clone + Send + Sync + 'static, + V: Send + Sync + 'static, { let mut shard = HashSet::new(); while let Some(op) = rx.recv().await { @@ -207,7 +207,7 @@ where /// get_or_default executes the provided closure with a reference to the map entry for the given key. /// If the key does not exist, it inserts a new entry with the default value. - pub fn get_or_init(&self, key: K, f: F) -> V + pub fn get_or_init(&self, key: K, init: F) -> V where F: FnOnce() -> V, { @@ -218,7 +218,7 @@ where let value = match self.data.entry(key.clone()) { Entry::Vacant(entry) => { - let value = f(); + let value = init(); entry.insert(value.clone()); new_entry = true; value @@ -250,6 +250,25 @@ where value } + /// Removes the key from the map. + /// TODO: Consider removing the key from the time bucket as well. We would need to know which + /// bucket the key was in to do this. One idea is to store the bucket idx in the map value. + pub fn remove(&self, key: K) { + self.data.remove(&key); + } + + /// Returns the number of entries currently stored in the map + #[cfg(test)] + pub fn len(&self) -> usize { + self.data.len() + } + + /// Returns an iterator over the keys currently stored in the map + #[cfg(test)] + pub fn keys(&self) -> impl Iterator + '_ { + self.data.iter().map(|entry| entry.key().clone()) + } + /// run_gc_loop will continuously clear expired entries from the map, checking every `period`. The /// function terminates if `shutdown` is signalled. async fn run_gc_loop(time: Arc, period: Duration, buckets: &[Bucket]) { @@ -349,6 +368,7 @@ mod tests { // All entries expired } + // assert_eventually checks a condition every 10ms for a maximum of timeout async fn assert_eventually(assertion: F, timeout: Duration) where F: Fn() -> bool, @@ -358,12 +378,12 @@ mod tests { if assertion() { return; } - tokio::time::sleep(Duration::from_millis(50)).await; + tokio::time::sleep(Duration::from_millis(10)).await; } panic!("Assertion failed within {:?}", timeout); } - #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + #[tokio::test(flavor = "multi_thread", worker_threads = 8)] async fn test_concurrent_gc_and_access() { let ttl_map = TTLMap::::try_new(TTLMapConfig { ttl: Duration::from_millis(10), @@ -377,22 +397,30 @@ mod tests { // Spawn 5 concurrent tasks let mut handles = Vec::new(); - for task_id in 0..5 { + for task_id in 0..10 { let map = Arc::clone(&ttl_map); - let handle = tokio::spawn(async move { - for i in 0..20 { - let key = format!("task{}_key{}", task_id, i % 4); - map.get_or_init(key, || task_id * 100 + i); - sleep(Duration::from_millis(1)).await; + handles.push(tokio::spawn(async move { + for i in 0..100 { + let key = format!("task{}_key{}", task_id, i % 10); + map.get_or_init(key.clone(), || task_id * 100 + i); } - }); - handles.push(handle); + })); + let map2 = Arc::clone(&ttl_map); + handles.push(tokio::spawn(async move { + // Remove some keys which may or may not exist. 
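+ // The removed keys cycle through i % 15 while the writer task above only inserts i % 10 keys,
+ // so removals race with concurrent inserts and some target keys that are never inserted.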
+ for i in 0..50 {
+ let key = format!("task{}_key{}", task_id, i % 15);
+ map2.remove(key);
+ }
+ }));
}

// Wait for all tasks to complete
for handle in handles {
handle.await.unwrap();
}
+
+ assert_eventually(|| ttl_map.data.len() == 0, Duration::from_millis(20)).await;
}

#[tokio::test]
@@ -475,4 +503,41 @@ mod tests {
ttl_map.metrics.ttl_accounting_time.load(Ordering::SeqCst) / 1_000_000
);
}
+
+ #[tokio::test]
+ async fn test_remove_with_manual_gc() {
+ let ttl_map = TTLMap::::_new(TTLMapConfig {
+ ttl: Duration::from_millis(50),
+ tick: Duration::from_millis(10),
+ });
+
+ ttl_map.get_or_init("key1".to_string(), || 100);
+ ttl_map.get_or_init("key2".to_string(), || 200);
+ ttl_map.get_or_init("key3".to_string(), || 300);
+ assert_eq!(ttl_map.data.len(), 3);
+
+ // Remove key2 and verify the others remain.
+ ttl_map.remove("key2".to_string());
+ assert_eq!(ttl_map.data.len(), 2);
+ let val1 = ttl_map.get_or_init("key1".to_string(), || 999);
+ assert_eq!(val1, 100);
+ let val3 = ttl_map.get_or_init("key3".to_string(), || 999);
+ assert_eq!(val3, 300);
+
+ // key2 should be recreated with a new value since it was removed.
+ let val2 = ttl_map.get_or_init("key2".to_string(), || 999);
+ assert_eq!(val2, 999);
+ assert_eq!(ttl_map.data.len(), 3);
+ let val2_again = ttl_map.get_or_init("key2".to_string(), || 200);
+ assert_eq!(val2_again, 999);
+
+ // Remove key1 before GCing.
+ ttl_map.remove("key1".to_string());
+
+ // Run GC and verify the map is empty.
+ for _ in 0..5 {
+ TTLMap::::gc(ttl_map.time.clone(), &ttl_map.buckets);
+ }
+ assert_eventually(|| ttl_map.data.len() == 0, Duration::from_millis(100)).await;
+ }
}
diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs
index adc7d74..a738986 100644
--- a/src/flight_service/do_get.rs
+++ b/src/flight_service/do_get.rs
@@ -14,7 +14,9 @@
use arrow_flight::Ticket;
use datafusion::execution::SessionState;
use futures::TryStreamExt;
use prost::Message;
+use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
+use tokio::sync::OnceCell;
use tonic::metadata::MetadataMap;
use tonic::{Request, Response, Status};
@@ -37,6 +39,20 @@ pub struct DoGet {
pub stage_key: Option<StageKey>,
}
+#[derive(Clone, Debug)]
+/// TaskData stores state for a single task being executed by this endpoint. It may be shared
+/// by concurrent requests for the same task which execute separate partitions.
+pub struct TaskData {
+ pub(super) state: SessionState,
+ pub(super) stage: Arc<ExecutionStage>,
+ /// num_partitions_remaining is initialized to the total number of partitions in the task (not
+ /// only the partitions in the partition group). It is decremented for each request to the
+ /// endpoint for this task. Once this count reaches zero, the task is likely complete. The task
+ /// may not actually be complete, because the same partition may have been retried and this
+ /// count decremented more than once for that partition.
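+ /// If an entry is evicted early because of such a retry, a later request for the same task
+ /// simply re-creates it through the endpoint's stages map.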
+ num_partitions_remaining: Arc<AtomicUsize>,
+}
+
impl ArrowFlightEndpoint {
pub(super) async fn get(
&self,
@@ -50,7 +66,10 @@
let partition = doget.partition as usize;
let task_number = doget.task_number as usize;

- let (mut state, stage) = self.get_state_and_stage(doget, metadata).await?;
+ let task_data = self.get_state_and_stage(doget, metadata).await?;
+
+ let stage = task_data.stage;
+ let mut state = task_data.state;

// find out which partition group we are executing
let task = stage
@@ -90,16 +109,15 @@
&self,
doget: DoGet,
metadata_map: MetadataMap,
- ) -> Result<(SessionState, Arc<ExecutionStage>), Status> {
+ ) -> Result<TaskData, Status> {
let key = doget
.stage_key
.ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
- let once_stage = {
- let entry = self.stages.entry(key).or_default();
- Arc::clone(&entry)
- };
+ let once_stage = self
+ .stages
+ .get_or_init(key.clone(), || Arc::new(OnceCell::<TaskData>::new()));

- let (state, stage) = once_stage
+ let stage_data = once_stage
.get_or_try_init(|| async {
let stage_proto = doget
.stage_proto
@@ -133,10 +151,183 @@
config.set_extension(stage.clone());
config.set_extension(Arc::new(ContextGrpcMetadata(headers)));

- Ok::<_, Status>((state, stage))
+ // Initialize the partition count to the number of partitions in the stage.
+ let total_partitions = stage.plan.properties().partitioning.partition_count();
+ Ok::<_, Status>(TaskData {
+ state,
+ stage,
+ num_partitions_remaining: Arc::new(AtomicUsize::new(total_partitions)),
+ })
})
.await?;

- Ok((state.clone(), stage.clone()))
+ // If all the partitions are done, remove the stage from the cache.
+ let remaining_partitions = stage_data
+ .num_partitions_remaining
+ .fetch_sub(1, Ordering::SeqCst);
+ if remaining_partitions <= 1 {
+ self.stages.remove(key.clone());
+ }
+
+ Ok(stage_data.clone())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use uuid::Uuid;
+
+ #[tokio::test]
+ async fn test_task_data_partition_counting() {
+ use crate::flight_service::session_builder::DefaultSessionBuilder;
+ use crate::task::ExecutionTask;
+ use arrow_flight::Ticket;
+ use prost::{bytes::Bytes, Message};
+ use tonic::Request;
+
+ // Create ArrowFlightEndpoint with DefaultSessionBuilder
+ let endpoint =
+ ArrowFlightEndpoint::new(DefaultSessionBuilder).expect("Failed to create endpoint");
+
+ // Create 3 tasks with 3 partitions each.
+ let num_tasks = 3;
+ let num_partitions_per_task = 3;
+ let stage_id = 1;
+ let query_id_uuid = Uuid::new_v4();
+ let query_id = query_id_uuid.as_bytes().to_vec();
+
+ // Set up protos.
+ let mut tasks = Vec::new();
+ for i in 0..num_tasks {
+ tasks.push(ExecutionTask {
+ url_str: None,
+ partition_group: vec![i], // Each task's partition group contains the single partition i.
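+ // The group contents are irrelevant to this test: the endpoint counts down the total
+ // partition count of the stage plan, not the size of the partition group.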
+ }); + } + + let stage_proto = ExecutionStageProto { + query_id: query_id.clone(), + num: 1, + name: format!("test_stage_{}", 1), + plan: Some(Box::new(create_mock_physical_plan_proto( + num_partitions_per_task, + ))), + inputs: vec![], + tasks, + }; + + let task_keys = vec![ + StageKey { + query_id: query_id_uuid.to_string(), + stage_id, + task_number: 0, + }, + StageKey { + query_id: query_id_uuid.to_string(), + stage_id, + task_number: 1, + }, + StageKey { + query_id: query_id_uuid.to_string(), + stage_id, + task_number: 2, + }, + ]; + + let stage_proto_for_closure = stage_proto.clone(); + let endpoint_ref = &endpoint; + let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| { + let stage_proto = stage_proto_for_closure.clone(); + // Create DoGet message + let doget = DoGet { + stage_proto: Some(stage_proto), + task_number, + partition, + stage_key: Some(stage_key), + }; + + // Create Flight ticket + let ticket = Ticket { + ticket: Bytes::from(doget.encode_to_vec()), + }; + + // Call the actual get() method + let request = Request::new(ticket); + endpoint_ref.get(request).await + }; + + // For each task, call do_get() for each partition except the last. + for task_number in 0..num_tasks { + for partition in 0..num_partitions_per_task - 1 { + let result = do_get( + partition as u64, + task_number, + task_keys[task_number as usize].clone(), + ) + .await; + assert!(result.is_ok()); + } + } + + // Check that the endpoint has not evicted any task states. + assert_eq!(endpoint.stages.len(), num_tasks as usize); + + // Run the last partition of task 0. Any partition number works. Verify that the task state + // is evicted because all partitions have been processed. + let result = do_get(1, 0, task_keys[0].clone()).await; + assert!(result.is_ok()); + let stored_stage_keys = endpoint.stages.keys().collect::>(); + assert_eq!(stored_stage_keys.len(), 2); + assert!(stored_stage_keys.contains(&task_keys[1])); + assert!(stored_stage_keys.contains(&task_keys[2])); + + // Run the last partition of task 1. + let result = do_get(1, 1, task_keys[1].clone()).await; + assert!(result.is_ok()); + let stored_stage_keys = endpoint.stages.keys().collect::>(); + assert_eq!(stored_stage_keys.len(), 1); + assert!(stored_stage_keys.contains(&task_keys[2])); + + // Run the last partition of the last task. 
+ let result = do_get(1, 2, task_keys[2].clone()).await;
+ assert!(result.is_ok());
+ let stored_stage_keys = endpoint.stages.keys().collect::<Vec<_>>();
+ assert_eq!(stored_stage_keys.len(), 0);
+ }
+
+ // Helper to create a mock physical plan proto
+ fn create_mock_physical_plan_proto(
+ partitions: usize,
+ ) -> datafusion_proto::protobuf::PhysicalPlanNode {
+ use datafusion_proto::protobuf::partitioning::PartitionMethod;
+ use datafusion_proto::protobuf::{
+ Partitioning, PhysicalPlanNode, RepartitionExecNode, Schema,
+ };
+
+ // Create a repartition node that will have the desired partition count
+ PhysicalPlanNode {
+ physical_plan_type: Some(
+ datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Repartition(
+ Box::new(RepartitionExecNode {
+ input: Some(Box::new(PhysicalPlanNode {
+ physical_plan_type: Some(
+ datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Empty(
+ datafusion_proto::protobuf::EmptyExecNode {
+ schema: Some(Schema {
+ columns: vec![],
+ metadata: std::collections::HashMap::new(),
+ })
+ }
+ )
+ ),
+ })),
+ partitioning: Some(Partitioning {
+ partition_method: Some(PartitionMethod::RoundRobin(partitions as u64)),
+ }),
+ })
+ )
+ ),
+ }
}
}
diff --git a/src/flight_service/service.rs b/src/flight_service/service.rs
index 403ec20..510627d 100644
--- a/src/flight_service/service.rs
+++ b/src/flight_service/service.rs
@@ -1,14 +1,14 @@
+use crate::common::ttl_map::{TTLMap, TTLMapConfig};
+use crate::flight_service::do_get::TaskData;
use crate::flight_service::DistributedSessionBuilder;
-use crate::stage::ExecutionStage;
use arrow_flight::flight_service_server::FlightService;
use arrow_flight::{
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
};
use async_trait::async_trait;
-use dashmap::DashMap;
+use datafusion::error::DataFusionError;
use datafusion::execution::runtime_env::RuntimeEnv;
-use datafusion::execution::SessionState;
use futures::stream::BoxStream;
use std::sync::Arc;
use tokio::sync::OnceCell;
@@ -31,17 +31,20 @@ pub struct StageKey {
pub struct ArrowFlightEndpoint {
pub(super) runtime: Arc<RuntimeEnv>,
#[allow(clippy::type_complexity)]
- pub(super) stages: DashMap<StageKey, Arc<OnceCell<(SessionState, Arc<ExecutionStage>)>>>,
+ pub(super) stages: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
}
impl ArrowFlightEndpoint {
- pub fn new(session_builder: impl DistributedSessionBuilder + Send + Sync + 'static) -> Self {
- Self {
+ pub fn new(
+ session_builder: impl DistributedSessionBuilder + Send + Sync + 'static,
+ ) -> Result<Self, DataFusionError> {
+ let ttl_map = TTLMap::try_new(TTLMapConfig::default())?;
+ Ok(Self {
runtime: Arc::new(RuntimeEnv::default()),
- stages: DashMap::new(),
+ stages: ttl_map,
session_builder: Arc::new(session_builder),
- }
+ })
}
}
diff --git a/src/test_utils/localhost.rs b/src/test_utils/localhost.rs
index c20472d..2dcda1b 100644
--- a/src/test_utils/localhost.rs
+++ b/src/test_utils/localhost.rs
@@ -109,7 +109,7 @@
pub async fn spawn_flight_service(
session_builder: impl DistributedSessionBuilder + Send + Sync + 'static,
incoming: TcpListener,
) -> Result<(), Box> {
- let endpoint = ArrowFlightEndpoint::new(session_builder);
+ let endpoint = ArrowFlightEndpoint::new(session_builder)?;
let incoming = tokio_stream::wrappers::TcpListenerStream::new(incoming);
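A note on the eviction check in get_state_and_stage: AtomicUsize::fetch_sub returns the counter's value from before the subtraction, which is why the entry is evicted when the returned value is <= 1 rather than == 0. Below is a minimal, self-contained sketch of that decrement-and-evict pattern (illustrative only; the TaskTracker name and the partition counts are made up for the example and are not part of this change):

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Hypothetical stand-in for TaskData's partition counter.
struct TaskTracker {
    remaining: Arc<AtomicUsize>,
}

impl TaskTracker {
    fn new(total_partitions: usize) -> Self {
        Self {
            remaining: Arc::new(AtomicUsize::new(total_partitions)),
        }
    }

    /// Records one partition request and returns true when this call consumed the last
    /// remaining partition, i.e. when the caller should evict the cached task state.
    fn record_request(&self) -> bool {
        // fetch_sub returns the previous value: 1 means this request was the last one.
        self.remaining.fetch_sub(1, Ordering::SeqCst) <= 1
    }
}

fn main() {
    let tracker = TaskTracker::new(3);
    assert!(!tracker.record_request()); // 3 -> 2
    assert!(!tracker.record_request()); // 2 -> 1
    assert!(tracker.record_request()); // 1 -> 0: last partition, evict the cached entry
}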