diff --git a/src/execution_plans/network_coalesce.rs b/src/execution_plans/network_coalesce.rs
index 7e3ce7a..f6653f2 100644
--- a/src/execution_plans/network_coalesce.rs
+++ b/src/execution_plans/network_coalesce.rs
@@ -280,11 +280,11 @@ impl ExecutionPlan for NetworkCoalesceExec {
                 ticket: DoGet {
                     plan_proto: encoded_input_plan.clone(),
                     target_partition: target_partition as u64,
-                    stage_key: Some(StageKey {
-                        query_id: Bytes::from(input_stage.query_id.as_bytes().to_vec()),
-                        stage_id: input_stage.num as u64,
-                        task_number: target_task as u64,
-                    }),
+                    stage_key: Some(StageKey::new(
+                        Bytes::from(input_stage.query_id.as_bytes().to_vec()),
+                        input_stage.num as u64,
+                        target_task as u64,
+                    )),
                     target_task_index: target_task as u64,
                     target_task_count: input_stage.tasks.len() as u64,
                 }
diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs
index 9f18289..93376dd 100644
--- a/src/execution_plans/network_shuffle.rs
+++ b/src/execution_plans/network_shuffle.rs
@@ -334,11 +334,7 @@ impl ExecutionPlan for NetworkShuffleExec {
                 ticket: DoGet {
                     plan_proto: encoded_input_plan.clone(),
                     target_partition: (off + partition) as u64,
-                    stage_key: Some(StageKey {
-                        query_id: query_id.clone(),
-                        stage_id: input_stage_num,
-                        task_number: i as u64,
-                    }),
+                    stage_key: Some(StageKey::new(query_id.clone(), input_stage_num, i as u64)),
                     target_task_index: i as u64,
                     target_task_count: input_task_count as u64,
                 }
diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs
index 485fbee..396e9c7 100644
--- a/src/flight_service/do_get.rs
+++ b/src/flight_service/do_get.rs
@@ -311,23 +311,10 @@ mod tests {
             .encode_to_vec()
             .into();
 
-        let task_keys = [
-            StageKey {
-                query_id: query_id.clone(),
-                stage_id,
-                task_number: 0,
-            },
-            StageKey {
-                query_id: query_id.clone(),
-                stage_id,
-                task_number: 1,
-            },
-            StageKey {
-                query_id: query_id.clone(),
-                stage_id,
-                task_number: 2,
-            },
-        ];
+        let task_keys: Vec<_> = (0..3)
+            .map(|i| StageKey::new(query_id.clone(), stage_id, i))
+            .collect();
+
         let plan_proto_for_closure = plan_proto.clone();
 
        let endpoint_ref = &endpoint;
diff --git a/src/metrics/metrics_collecting_stream.rs b/src/metrics/metrics_collecting_stream.rs
index a60ba25..744a944 100644
--- a/src/metrics/metrics_collecting_stream.rs
+++ b/src/metrics/metrics_collecting_stream.rs
@@ -135,19 +135,9 @@ mod tests {
     // and asserts that the metrics are collected correctly.
     #[tokio::test]
     async fn test_metrics_collecting_stream_extracts_and_removes_metadata() {
-        let stage_keys = vec![
-            StageKey {
-                query_id: Bytes::from("test_query"),
-                stage_id: 1,
-                task_number: 1,
-            },
-            StageKey {
-                query_id: Bytes::from("test_query"),
-                stage_id: 1,
-                task_number: 2,
-            },
-        ];
-
+        let stage_keys: Vec<_> = (1..3)
+            .map(|i| StageKey::new(Bytes::from("test_query"), 1, i))
+            .collect();
         let app_metadatas = stage_keys
             .iter()
             .map(|stage_key| FlightAppMetadata {
diff --git a/src/metrics/task_metrics_rewriter.rs b/src/metrics/task_metrics_rewriter.rs
index e170fe3..5000bbe 100644
--- a/src/metrics/task_metrics_rewriter.rs
+++ b/src/metrics/task_metrics_rewriter.rs
@@ -153,11 +153,11 @@ pub fn stage_metrics_rewriter(
     let mut stage_metrics = MetricsSetProto::new();
 
     for idx in 0..stage.tasks.len() {
-        let stage_key = StageKey {
-            query_id: Bytes::from(stage.query_id.as_bytes().to_vec()),
-            stage_id: stage.num as u64,
-            task_number: idx as u64,
-        };
+        let stage_key = StageKey::new(
+            Bytes::from(stage.query_id.as_bytes().to_vec()),
+            stage.num as u64,
+            idx as u64,
+        );
         match metrics_collection.get(&stage_key) {
             Some(task_metrics) => {
                 if node_idx >= task_metrics.len() {
@@ -353,11 +353,11 @@ mod tests {
         // Generate metrics for each task and store them in the map.
         let mut metrics_collection = HashMap::new();
         for task_id in 0..stage.tasks.len() {
-            let stage_key = StageKey {
-                query_id: Bytes::from(stage.query_id.as_bytes().to_vec()),
-                stage_id: stage.num as u64,
-                task_number: task_id as u64,
-            };
+            let stage_key = StageKey::new(
+                Bytes::from(stage.query_id.as_bytes().to_vec()),
+                stage.num as u64,
+                task_id as u64,
+            );
             let metrics = (0..count_plan_nodes(&plan))
                 .map(|node_id| {
                     make_test_metrics_set_proto_from_seed(
@@ -390,11 +390,11 @@ mod tests {
             .enumerate()
         {
             let expected_task_node_metrics = metrics_collection
-                .get(&StageKey {
-                    query_id: Bytes::from(stage.query_id.as_bytes().to_vec()),
-                    stage_id: stage.num as u64,
-                    task_number: task_id as u64,
-                })
+                .get(&StageKey::new(
+                    Bytes::from(stage.query_id.as_bytes().to_vec()),
+                    stage.num as u64,
+                    task_id as u64,
+                ))
                 .unwrap()[node_id]
                 .clone();
 
diff --git a/src/protobuf/distributed_codec.rs b/src/protobuf/distributed_codec.rs
index ab7070e..29a7173 100644
--- a/src/protobuf/distributed_codec.rs
+++ b/src/protobuf/distributed_codec.rs
@@ -239,7 +239,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
     }
 }
 
-/// A key that uniquely identifies a stage in a query
+/// A key that uniquely identifies a stage in a query.
 #[derive(Clone, Hash, Eq, PartialEq, ::prost::Message)]
 pub struct StageKey {
     /// Our query id
@@ -253,6 +253,17 @@ pub struct StageKey {
     /// The task number within the stage
     pub task_number: u64,
 }
+impl StageKey {
+    /// Creates a new `StageKey`.
+    pub fn new(query_id: Bytes, stage_id: u64, task_number: u64) -> StageKey {
+        Self {
+            query_id,
+            stage_id,
+            task_number,
+        }
+    }
+}
+
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct StageProto {
     /// Our query id
diff --git a/src/test_utils/plans.rs b/src/test_utils/plans.rs
index 6a09a5e..b7df910 100644
--- a/src/test_utils/plans.rs
+++ b/src/test_utils/plans.rs
@@ -1,4 +1,3 @@
-use bytes::Bytes;
 use datafusion::{
     common::{HashMap, HashSet},
     physical_plan::ExecutionPlan,
@@ -51,12 +50,11 @@ pub fn get_stages_and_stage_keys(
 
         // Add each task.
         for j in 0..stage.tasks.len() {
-            let stage_key = StageKey {
-                query_id: Bytes::from(stage.query_id.as_bytes().to_vec()),
-                stage_id: stage.num as u64,
-                task_number: j as u64,
-            };
-            stage_keys.insert(stage_key);
+            stage_keys.insert(StageKey::new(
+                stage.query_id.as_bytes().to_vec().into(),
+                stage.num as u64,
+                j as u64,
+            ));
         }
 
         // Add any child stages
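
As a reviewer aid, a minimal, self-contained sketch of the equivalence this refactor relies on: `StageKey::new` builds exactly the value the replaced struct literals built. The struct is re-declared locally so the snippet compiles on its own; the `::prost::Message` derive and prost field attributes from the real definition in `src/protobuf/distributed_codec.rs` are omitted, and only the `bytes` crate is assumed.

```rust
// Stand-in for the StageKey defined in src/protobuf/distributed_codec.rs.
// Debug replaces the ::prost::Message derive so this compiles standalone.
use bytes::Bytes;

#[derive(Clone, Hash, Eq, PartialEq, Debug)]
pub struct StageKey {
    pub query_id: Bytes,
    pub stage_id: u64,
    pub task_number: u64,
}

impl StageKey {
    pub fn new(query_id: Bytes, stage_id: u64, task_number: u64) -> StageKey {
        Self {
            query_id,
            stage_id,
            task_number,
        }
    }
}

fn main() {
    let query_id = Bytes::from("example_query");

    // The constructor introduced by this patch...
    let via_new = StageKey::new(query_id.clone(), 2, 0);

    // ...produces the same value as the struct literal it replaces,
    // so call sites can be converted mechanically.
    let via_literal = StageKey {
        query_id,
        stage_id: 2,
        task_number: 0,
    };
    assert_eq!(via_new, via_literal);
}
```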