Skip to content

Commit c3ef7b8

Browse files
JSOD11jayshrivastava
authored andcommitted
Create stage key constructor.
1 parent dbf80f9 commit c3ef7b8

File tree

7 files changed

+39
-62
lines changed

7 files changed

+39
-62
lines changed

src/execution_plans/network_coalesce.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -280,11 +280,11 @@ impl ExecutionPlan for NetworkCoalesceExec {
280280
ticket: DoGet {
281281
plan_proto: encoded_input_plan.clone(),
282282
target_partition: target_partition as u64,
283-
stage_key: Some(StageKey {
284-
query_id: Bytes::from(input_stage.query_id.as_bytes().to_vec()),
285-
stage_id: input_stage.num as u64,
286-
task_number: target_task as u64,
287-
}),
283+
stage_key: Some(StageKey::new(
284+
Bytes::from(input_stage.query_id.as_bytes().to_vec()),
285+
input_stage.num as u64,
286+
target_task as u64,
287+
)),
288288
target_task_index: target_task as u64,
289289
target_task_count: input_stage.tasks.len() as u64,
290290
}

src/execution_plans/network_shuffle.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -334,11 +334,7 @@ impl ExecutionPlan for NetworkShuffleExec {
334334
ticket: DoGet {
335335
plan_proto: encoded_input_plan.clone(),
336336
target_partition: (off + partition) as u64,
337-
stage_key: Some(StageKey {
338-
query_id: query_id.clone(),
339-
stage_id: input_stage_num,
340-
task_number: i as u64,
341-
}),
337+
stage_key: Some(StageKey::new(query_id.clone(), input_stage_num, i as u64)),
342338
target_task_index: i as u64,
343339
target_task_count: input_task_count as u64,
344340
}

src/flight_service/do_get.rs

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -311,23 +311,10 @@ mod tests {
311311
.encode_to_vec()
312312
.into();
313313

314-
let task_keys = [
315-
StageKey {
316-
query_id: query_id.clone(),
317-
stage_id,
318-
task_number: 0,
319-
},
320-
StageKey {
321-
query_id: query_id.clone(),
322-
stage_id,
323-
task_number: 1,
324-
},
325-
StageKey {
326-
query_id: query_id.clone(),
327-
stage_id,
328-
task_number: 2,
329-
},
330-
];
314+
let task_keys: Vec<_> = (0..3)
315+
.map(|i| StageKey::new(query_id.clone(), stage_id, i))
316+
.collect();
317+
331318
let plan_proto_for_closure = plan_proto.clone();
332319
let endpoint_ref = &endpoint;
333320

src/metrics/metrics_collecting_stream.rs

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -135,19 +135,9 @@ mod tests {
135135
// and asserts that the metrics are collected correctly.
136136
#[tokio::test]
137137
async fn test_metrics_collecting_stream_extracts_and_removes_metadata() {
138-
let stage_keys = vec![
139-
StageKey {
140-
query_id: Bytes::from("test_query"),
141-
stage_id: 1,
142-
task_number: 1,
143-
},
144-
StageKey {
145-
query_id: Bytes::from("test_query"),
146-
stage_id: 1,
147-
task_number: 2,
148-
},
149-
];
150-
138+
let stage_keys: Vec<_> = (1..3)
139+
.map(|i| StageKey::new(Bytes::from("test_query"), 1, i))
140+
.collect();
151141
let app_metadatas = stage_keys
152142
.iter()
153143
.map(|stage_key| FlightAppMetadata {

src/metrics/task_metrics_rewriter.rs

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,7 @@ pub fn stage_metrics_rewriter(
153153
let mut stage_metrics = MetricsSetProto::new();
154154

155155
for idx in 0..stage.tasks.len() {
156-
let stage_key = StageKey {
157-
query_id: Bytes::from(stage.query_id.as_bytes().to_vec()),
158-
stage_id: stage.num as u64,
159-
task_number: idx as u64,
160-
};
156+
let stage_key = StageKey::new(Bytes::from(stage.query_id.as_bytes().to_vec()), stage.num as u64, idx as u64);
161157
match metrics_collection.get(&stage_key) {
162158
Some(task_metrics) => {
163159
if node_idx >= task_metrics.len() {
@@ -353,11 +349,11 @@ mod tests {
353349
// Generate metrics for each task and store them in the map.
354350
let mut metrics_collection = HashMap::new();
355351
for task_id in 0..stage.tasks.len() {
356-
let stage_key = StageKey {
357-
query_id: Bytes::from(stage.query_id.as_bytes().to_vec()),
358-
stage_id: stage.num as u64,
359-
task_number: task_id as u64,
360-
};
352+
let stage_key = StageKey::new(
353+
Bytes::from(stage.query_id.as_bytes().to_vec()),
354+
stage.num as u64,
355+
task_id as u64,
356+
);
361357
let metrics = (0..count_plan_nodes(&plan))
362358
.map(|node_id| {
363359
make_test_metrics_set_proto_from_seed(
@@ -390,11 +386,11 @@ mod tests {
390386
.enumerate()
391387
{
392388
let expected_task_node_metrics = metrics_collection
393-
.get(&StageKey {
394-
query_id: Bytes::from(stage.query_id.as_bytes().to_vec()),
395-
stage_id: stage.num as u64,
396-
task_number: task_id as u64,
397-
})
389+
.get(&StageKey::new(
390+
Bytes::from(stage.query_id.as_bytes().to_vec()),
391+
stage.num as u64,
392+
task_id as u64,
393+
))
398394
.unwrap()[node_id]
399395
.clone();
400396

src/protobuf/distributed_codec.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,16 @@ pub struct StageKey {
253253
pub task_number: u64,
254254
}
255255

256+
impl StageKey {
257+
pub fn new(query_id: Bytes, stage_id: u64, task_number: u64) -> StageKey {
258+
Self {
259+
query_id,
260+
stage_id,
261+
task_number,
262+
}
263+
}
264+
}
265+
256266
#[derive(Clone, PartialEq, ::prost::Message)]
257267
pub struct StageProto {
258268
/// Our query id

src/test_utils/plans.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use bytes::Bytes;
21
use datafusion::{
32
common::{HashMap, HashSet},
43
physical_plan::ExecutionPlan,
@@ -51,12 +50,11 @@ pub fn get_stages_and_stage_keys(
5150

5251
// Add each task.
5352
for j in 0..stage.tasks.len() {
54-
let stage_key = StageKey {
55-
query_id: Bytes::from(stage.query_id.as_bytes().to_vec()),
56-
stage_id: stage.num as u64,
57-
task_number: j as u64,
58-
};
59-
stage_keys.insert(stage_key);
53+
stage_keys.insert(StageKey::new(
54+
stage.query_id.as_bytes().to_vec().into(),
55+
stage.num as u64,
56+
j as u64,
57+
));
6058
}
6159

6260
// Add any child stages

0 commit comments

Comments
 (0)