Skip to content

Commit f98c78d

Browse files
add test
1 parent 03452bc commit f98c78d

File tree

3 files changed

+203
-7
lines changed

3 files changed

+203
-7
lines changed

src/common/ttl_map.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,18 @@ where
257257
self.data.remove(&key);
258258
}
259259

260+
/// Returns the number of entries currently stored in the map
261+
#[cfg(test)]
262+
pub fn len(&self) -> usize {
263+
self.data.len()
264+
}
265+
266+
/// Returns an iterator over the keys currently stored in the map
267+
#[cfg(test)]
268+
pub fn keys(&self) -> impl Iterator<Item = K> + '_ {
269+
self.data.iter().map(|entry| entry.key().clone())
270+
}
271+
260272
/// run_gc_loop will continuously clear expired entries from the map, checking every `period`. The
261273
/// function terminates if `shutdown` is signalled.
262274
async fn run_gc_loop(time: Arc<AtomicU64>, period: Duration, buckets: &[Bucket<K>]) {

src/flight_service/do_get.rs

Lines changed: 190 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ pub struct DoGet {
3939
pub stage_key: Option<StageKey>,
4040
}
4141

42-
#[derive(Clone)]
42+
#[derive(Clone, Debug)]
4343
pub struct TaskData {
4444
pub(super) state: SessionState,
4545
pub(super) stage: Arc<ExecutionStage>,
@@ -108,9 +108,9 @@ impl ArrowFlightEndpoint {
108108
.ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
109109
let once_stage = self
110110
.stages
111-
.get_or_init(key.clone(), || OnceCell::<TaskData>::new());
111+
.get_or_init(key.clone(), || Arc::new(OnceCell::<TaskData>::new()));
112112

113-
let mut stage_data = once_stage
113+
let stage_data = once_stage
114114
.get_or_try_init(|| async {
115115
let stage_proto = doget
116116
.stage_proto
@@ -145,19 +145,203 @@ impl ArrowFlightEndpoint {
145145
config.set_extension(stage.clone());
146146
config.set_extension(Arc::new(ContextGrpcMetadata(headers)));
147147

148+
// Initialize partition count to the number of partitions in the stage
149+
let total_partitions = stage.plan.properties().partitioning.partition_count();
148150
Ok::<_, Status>(TaskData {
149151
state,
150152
stage,
151-
partition_count: Arc::new(AtomicUsize::new(0)),
153+
partition_count: Arc::new(AtomicUsize::new(total_partitions)),
152154
})
153155
})
154156
.await?;
155157

156-
stage_data.partition_count.fetch_sub(1, Ordering::SeqCst);
157-
if stage_data.partition_count.load(Ordering::SeqCst) <= 0 {
158+
let remaining_partitions = stage_data.partition_count.fetch_sub(1, Ordering::SeqCst);
159+
if remaining_partitions <= 1 {
158160
self.stages.remove(key.clone());
159161
}
160162

161163
Ok(stage_data.clone())
162164
}
163165
}
166+
167+
#[cfg(test)]
168+
mod tests {
169+
use super::*;
170+
use arrow_flight::Ticket;
171+
use prost::{bytes::Bytes, Message};
172+
use uuid::Uuid;
173+
174+
#[tokio::test]
175+
async fn test_task_data_partition_counting() {
176+
use crate::task::ExecutionTask;
177+
use arrow_flight::Ticket;
178+
use prost::{bytes::Bytes, Message};
179+
use tonic::Request;
180+
use url::Url;
181+
182+
// Create a mock channel resolver for ArrowFlightEndpoint
183+
#[derive(Clone)]
184+
struct MockChannelResolver;
185+
186+
#[async_trait::async_trait]
187+
impl crate::ChannelResolver for MockChannelResolver {
188+
fn get_urls(&self) -> Result<Vec<Url>, datafusion::error::DataFusionError> {
189+
Ok(vec![])
190+
}
191+
192+
async fn get_channel_for_url(
193+
&self,
194+
_url: &Url,
195+
) -> Result<crate::BoxCloneSyncChannel, datafusion::error::DataFusionError>
196+
{
197+
Err(datafusion::error::DataFusionError::NotImplemented(
198+
"Mock resolver".to_string(),
199+
))
200+
}
201+
}
202+
203+
// Create ArrowFlightEndpoint
204+
let endpoint =
205+
ArrowFlightEndpoint::new(MockChannelResolver).expect("Failed to create endpoint");
206+
207+
// Test parameters: 2 stages, each with 3 tasks, each task has 10 partitions
208+
let num_tasks = 3;
209+
let num_partitions_per_task = 3;
210+
let stage_id = 1;
211+
let query_id_uuid = Uuid::new_v4();
212+
let query_id = query_id_uuid.as_bytes().to_vec();
213+
214+
// Create ExecutionStage proto with multiple tasks
215+
let mut tasks = Vec::new();
216+
for i in 0..num_tasks {
217+
tasks.push(ExecutionTask {
218+
url_str: None,
219+
partition_group: vec![i], // Different partition group for each task
220+
});
221+
}
222+
223+
let stage_proto = ExecutionStageProto {
224+
query_id: query_id.clone(),
225+
num: 1,
226+
name: format!("test_stage_{}", 1),
227+
plan: Some(Box::new(create_mock_physical_plan_proto(
228+
num_partitions_per_task,
229+
))),
230+
inputs: vec![],
231+
tasks,
232+
};
233+
234+
let task_keys = vec![
235+
StageKey {
236+
query_id: query_id_uuid.to_string(),
237+
stage_id,
238+
task_number: 0,
239+
},
240+
StageKey {
241+
query_id: query_id_uuid.to_string(),
242+
stage_id,
243+
task_number: 1,
244+
},
245+
StageKey {
246+
query_id: query_id_uuid.to_string(),
247+
stage_id,
248+
task_number: 2,
249+
},
250+
];
251+
252+
let stage_proto_for_closure = stage_proto.clone();
253+
let endpoint_ref = &endpoint;
254+
let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| {
255+
let stage_proto = stage_proto_for_closure.clone();
256+
// Create DoGet message
257+
let doget = DoGet {
258+
stage_proto: Some(stage_proto),
259+
task_number,
260+
partition,
261+
stage_key: Some(stage_key),
262+
};
263+
264+
// Create Flight ticket
265+
let ticket = Ticket {
266+
ticket: Bytes::from(doget.encode_to_vec()),
267+
};
268+
269+
// Call the actual get() method
270+
let request = Request::new(ticket);
271+
endpoint_ref.get(request).await
272+
};
273+
274+
// For each task, call do_get() for each partition except the last.
275+
for task_number in 0..num_tasks {
276+
for partition in 0..num_partitions_per_task - 1 {
277+
let result = do_get(
278+
partition as u64,
279+
task_number,
280+
task_keys[task_number as usize].clone(),
281+
)
282+
.await;
283+
assert!(result.is_ok());
284+
}
285+
}
286+
287+
// Check that the endpoint has not evicted any tasks.
288+
assert_eq!(endpoint.stages.len(), num_tasks as usize);
289+
290+
// Run the last partition of task 0. Any partition number works. Verify that the task state
291+
// is evicted because all partitions have been processed.
292+
let result = do_get(1, 0, task_keys[0].clone()).await;
293+
assert!(result.is_ok());
294+
let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
295+
assert_eq!(stored_stage_keys.len(), 2);
296+
assert!(stored_stage_keys.contains(&task_keys[1]));
297+
assert!(stored_stage_keys.contains(&task_keys[2]));
298+
299+
// Run the last partition of task 1.
300+
let result = do_get(1, 1, task_keys[1].clone()).await;
301+
assert!(result.is_ok());
302+
let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
303+
assert_eq!(stored_stage_keys.len(), 1);
304+
assert!(stored_stage_keys.contains(&task_keys[2]));
305+
306+
// Run the last partition of the last task.
307+
let result = do_get(1, 2, task_keys[2].clone()).await;
308+
assert!(result.is_ok());
309+
let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
310+
assert_eq!(stored_stage_keys.len(), 0);
311+
}
312+
313+
// Helper to create a mock physical plan proto
314+
fn create_mock_physical_plan_proto(
315+
partitions: usize,
316+
) -> datafusion_proto::protobuf::PhysicalPlanNode {
317+
use datafusion_proto::protobuf::partitioning::PartitionMethod;
318+
use datafusion_proto::protobuf::{
319+
Partitioning, PhysicalPlanNode, RepartitionExecNode, Schema,
320+
};
321+
322+
// Create a repartition node that will have the desired partition count
323+
PhysicalPlanNode {
324+
physical_plan_type: Some(
325+
datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Repartition(
326+
Box::new(RepartitionExecNode {
327+
input: Some(Box::new(PhysicalPlanNode {
328+
physical_plan_type: Some(
329+
datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Empty(
330+
datafusion_proto::protobuf::EmptyExecNode {
331+
schema: Some(Schema {
332+
columns: vec![],
333+
metadata: std::collections::HashMap::new(),
334+
})
335+
}
336+
)
337+
),
338+
})),
339+
partitioning: Some(Partitioning {
340+
partition_method: Some(PartitionMethod::RoundRobin(partitions as u64)),
341+
}),
342+
})
343+
)
344+
),
345+
}
346+
}
347+
}

src/flight_service/service.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ pub struct ArrowFlightEndpoint {
3535
pub(super) channel_manager: Arc<ChannelManager>,
3636
pub(super) runtime: Arc<RuntimeEnv>,
3737
#[allow(clippy::type_complexity)]
38-
pub(super) stages: TTLMap<StageKey, OnceCell<TaskData>>,
38+
pub(super) stages: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
3939
pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
4040
}
4141

0 commit comments

Comments
 (0)