Skip to content

Commit 0147b8d

Browse files
execution_plans: add MetricsCollectingStream
This change adds a new type called `MetricsCollectingStream`. It wraps a stream of `FlightData` and collects any metrics that are passed in the `app_metadata`. This change also introduces an `FlightAppMetadata` enum proto which can be used to define our app metadata protocol.
1 parent fbfc512 commit 0147b8d

File tree

8 files changed

+350
-53
lines changed

8 files changed

+350
-53
lines changed

src/execution_plans/metrics.rs

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -268,15 +268,13 @@ impl ExecutionPlan for MetricsWrapperExec {
268268

269269
#[cfg(test)]
270270
mod tests {
271-
use crate::metrics::proto::{
272-
ElapsedCompute, EndTimestamp, MetricProto, MetricValueProto, OutputRows, StartTimestamp,
273-
};
274271

275272
use super::*;
276273
use datafusion::arrow::array::{Int32Array, StringArray};
277274
use datafusion::arrow::record_batch::RecordBatch;
278275

279276
use crate::test_utils::in_memory_channel_resolver::InMemoryChannelResolver;
277+
use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed;
280278
use crate::test_utils::session_context::register_temp_parquet_table;
281279
use crate::DistributedExt;
282280
use crate::DistributedPhysicalOptimizerRule;
@@ -363,47 +361,12 @@ mod tests {
363361
(stage_exec, ctx)
364362
}
365363

366-
/// creates a "distinct" set of metrics from the provided seed
367-
fn make_distinct_metrics_set(seed: u64) -> MetricsSetProto {
368-
const TEST_TIMESTAMP: i64 = 1758200400000000000; // 2025-09-18 13:00:00 UTC
369-
MetricsSetProto {
370-
metrics: vec![
371-
MetricProto {
372-
metric: Some(MetricValueProto::OutputRows(OutputRows { value: seed })),
373-
labels: vec![],
374-
partition: None,
375-
},
376-
MetricProto {
377-
metric: Some(MetricValueProto::ElapsedCompute(ElapsedCompute {
378-
value: seed,
379-
})),
380-
labels: vec![],
381-
partition: None,
382-
},
383-
MetricProto {
384-
metric: Some(MetricValueProto::StartTimestamp(StartTimestamp {
385-
value: Some(TEST_TIMESTAMP + (seed as i64 * 1_000_000_000)),
386-
})),
387-
labels: vec![],
388-
partition: None,
389-
},
390-
MetricProto {
391-
metric: Some(MetricValueProto::EndTimestamp(EndTimestamp {
392-
value: Some(TEST_TIMESTAMP + ((seed as i64 + 1) * 1_000_000_000)),
393-
})),
394-
labels: vec![],
395-
partition: None,
396-
},
397-
],
398-
}
399-
}
400-
401364
#[tokio::test]
402365
#[ignore]
403366
async fn test_metrics_rewriter() {
404367
let (test_stage, _ctx) = make_test_stage_exec_with_5_nodes().await;
405368
let test_metrics_sets = (0..5) // 5 nodes excluding NetworkShuffleExec
406-
.map(|i| make_distinct_metrics_set(i + 10))
369+
.map(|i| make_test_metrics_set_proto_from_seed(i + 10))
407370
.collect::<Vec<MetricsSetProto>>();
408371

409372
let rewriter = TaskMetricsRewriter::new(test_metrics_sets.clone());
@@ -429,7 +392,7 @@ mod tests {
429392
#[tokio::test]
430393
#[ignore]
431394
async fn test_metrics_rewriter_correct_number_of_metrics() {
432-
let test_metrics_set = make_distinct_metrics_set(10);
395+
let test_metrics_set = make_test_metrics_set_proto_from_seed(10);
433396
let (executable_plan, _ctx) = make_test_stage_exec_with_5_nodes().await;
434397
let task_plan = executable_plan
435398
.as_any()
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
use std::pin::Pin;
2+
use std::task::{Context, Poll};
3+
4+
use crate::metrics::proto::MetricsSetProto;
5+
use crate::protobuf::StageKey;
6+
use crate::protobuf::{AppMetadata, FlightAppMetadata};
7+
use arrow_flight::{error::FlightError, FlightData};
8+
use dashmap::DashMap;
9+
use futures::stream::Stream;
10+
use prost::Message;
11+
use std::sync::Arc;
12+
13+
/// MetricsCollectingStream wraps a FlightData stream and extracts metrics from app_metadata
14+
/// while passing through all the other FlightData unchanged.
15+
pub struct MetricsCollectingStream<S>
16+
where
17+
S: Stream<Item = Result<FlightData, FlightError>> + Send + Unpin,
18+
{
19+
inner: S,
20+
metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>,
21+
}
22+
23+
impl<S> MetricsCollectingStream<S>
24+
where
25+
S: Stream<Item = Result<FlightData, FlightError>> + Send + Unpin,
26+
{
27+
#[allow(dead_code)]
28+
pub fn new(
29+
stream: S,
30+
metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>,
31+
) -> Self {
32+
Self {
33+
inner: stream,
34+
metrics_collection,
35+
}
36+
}
37+
38+
fn extract_metrics_from_flight_data(
39+
&self,
40+
flight_data: &mut FlightData,
41+
) -> Result<(), FlightError> {
42+
if flight_data.app_metadata.is_empty() {
43+
return Ok(());
44+
}
45+
46+
let metadata =
47+
FlightAppMetadata::decode(flight_data.app_metadata.as_ref()).map_err(|e| {
48+
FlightError::ProtocolError(format!("failed to decode app_metadata: {}", e))
49+
})?;
50+
51+
let Some(content) = metadata.content else {
52+
return Err(FlightError::ProtocolError(
53+
"expected Some content in app_metadata".to_string(),
54+
));
55+
};
56+
57+
let AppMetadata::MetricsCollection(task_metrics_set) = content;
58+
59+
for task_metrics in task_metrics_set.tasks {
60+
let Some(stage_key) = task_metrics.stage_key else {
61+
return Err(FlightError::ProtocolError(
62+
"expected Some StageKey in MetricsCollectingStream, got None".to_string(),
63+
));
64+
};
65+
self.metrics_collection
66+
.insert(stage_key, task_metrics.metrics);
67+
}
68+
69+
flight_data.app_metadata.clear();
70+
Ok(())
71+
}
72+
}
73+
74+
impl<S> Stream for MetricsCollectingStream<S>
75+
where
76+
S: Stream<Item = Result<FlightData, FlightError>> + Send + Unpin,
77+
{
78+
type Item = Result<FlightData, FlightError>;
79+
80+
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81+
match Pin::new(&mut self.inner).poll_next(cx) {
82+
Poll::Ready(Some(Ok(mut flight_data))) => {
83+
// Extract metrics from app_metadata if present.
84+
match self.extract_metrics_from_flight_data(&mut flight_data) {
85+
Ok(_) => Poll::Ready(Some(Ok(flight_data))),
86+
Err(e) => Poll::Ready(Some(Err(e))),
87+
}
88+
}
89+
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
90+
Poll::Ready(None) => Poll::Ready(None),
91+
Poll::Pending => Poll::Pending,
92+
}
93+
}
94+
}
95+
96+
#[cfg(test)]
97+
mod tests {
98+
use super::*;
99+
use crate::protobuf::{
100+
AppMetadata, FlightAppMetadata, MetricsCollection, StageKey, TaskMetrics,
101+
};
102+
use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed;
103+
use arrow_flight::FlightData;
104+
use futures::stream::{self, StreamExt};
105+
use prost::{bytes::Bytes, Message};
106+
107+
fn assert_protocol_error(result: Result<FlightData, FlightError>, expected_msg: &str) {
108+
let FlightError::ProtocolError(msg) = result.unwrap_err() else {
109+
panic!("expected FlightError::ProtocolError");
110+
};
111+
assert!(msg.contains(expected_msg));
112+
}
113+
114+
fn make_flight_data(data: Vec<u8>, metadata: Option<FlightAppMetadata>) -> FlightData {
115+
let metadata_bytes = match metadata {
116+
Some(metadata) => metadata.encode_to_vec().into(),
117+
None => Bytes::new(),
118+
};
119+
FlightData {
120+
flight_descriptor: None,
121+
data_header: Bytes::new(),
122+
app_metadata: metadata_bytes,
123+
data_body: data.into(),
124+
}
125+
}
126+
127+
#[tokio::test]
128+
async fn test_metrics_collecting_stream_extracts_and_removes_metadata() {
129+
let stage_keys = vec![
130+
StageKey {
131+
query_id: "test_query".to_string(),
132+
stage_id: 1,
133+
task_number: 1,
134+
},
135+
StageKey {
136+
query_id: "test_query".to_string(),
137+
stage_id: 1,
138+
task_number: 2,
139+
},
140+
];
141+
142+
let app_metadatas = stage_keys
143+
.iter()
144+
.map(|stage_key| FlightAppMetadata {
145+
content: Some(AppMetadata::MetricsCollection(MetricsCollection {
146+
tasks: vec![TaskMetrics {
147+
stage_key: Some(stage_key.clone()),
148+
// use the task number to seed the test metrics set for convenience
149+
metrics: vec![make_test_metrics_set_proto_from_seed(stage_key.task_number)],
150+
}],
151+
})),
152+
})
153+
.collect::<Vec<_>>();
154+
155+
// Create test FlightData messages - some with metadata, some without
156+
let flight_messages = vec![
157+
make_flight_data(vec![1], Some(app_metadatas[0].clone())),
158+
make_flight_data(vec![2], None),
159+
make_flight_data(vec![3], Some(app_metadatas[1].clone())),
160+
]
161+
.into_iter()
162+
.map(Ok);
163+
164+
// Collect all messages from the stream. All should have empty app_metadata.
165+
let metrics_collection = Arc::new(DashMap::new());
166+
let input_stream = stream::iter(flight_messages);
167+
let collecting_stream =
168+
MetricsCollectingStream::new(input_stream, metrics_collection.clone());
169+
let collected_messages: Vec<FlightData> = collecting_stream
170+
.map(|result| result.unwrap())
171+
.collect()
172+
.await;
173+
174+
// Assert the data is unchanged and app_metadata is cleared
175+
assert_eq!(collected_messages.len(), 3);
176+
assert!(collected_messages
177+
.iter()
178+
.all(|msg| msg.app_metadata.is_empty()));
179+
180+
// Verify the data in the messages.
181+
assert_eq!(collected_messages[0].data_body, vec![1]);
182+
assert_eq!(collected_messages[1].data_body, vec![2]);
183+
assert_eq!(collected_messages[2].data_body, vec![3]);
184+
185+
// Verify the correct metrics were collected
186+
assert_eq!(metrics_collection.len(), 2);
187+
for stage_key in stage_keys {
188+
let collected_metrics = metrics_collection.get(&stage_key).unwrap();
189+
assert_eq!(collected_metrics.len(), 1);
190+
assert_eq!(
191+
collected_metrics[0],
192+
make_test_metrics_set_proto_from_seed(stage_key.task_number)
193+
);
194+
}
195+
}
196+
197+
#[tokio::test]
198+
async fn test_metrics_collecting_stream_error_missing_stage_key() {
199+
let task_metrics_with_no_stage_key = TaskMetrics {
200+
stage_key: None,
201+
metrics: vec![make_test_metrics_set_proto_from_seed(1)],
202+
};
203+
204+
let invalid_app_metadata = FlightAppMetadata {
205+
content: Some(AppMetadata::MetricsCollection(MetricsCollection {
206+
tasks: vec![task_metrics_with_no_stage_key],
207+
})),
208+
};
209+
210+
let invalid_flight_data = make_flight_data(vec![1], Some(invalid_app_metadata));
211+
212+
let error_stream = stream::iter(vec![Ok(invalid_flight_data)]);
213+
let mut collecting_stream =
214+
MetricsCollectingStream::new(error_stream, Arc::new(DashMap::new()));
215+
216+
let result = collecting_stream.next().await.unwrap();
217+
assert_protocol_error(
218+
result,
219+
"expected Some StageKey in MetricsCollectingStream, got None",
220+
);
221+
}
222+
223+
#[tokio::test]
224+
async fn test_metrics_collecting_stream_error_decoding_metadata() {
225+
let flight_data_with_invalid_metadata = FlightData {
226+
flight_descriptor: None,
227+
data_header: Bytes::new(),
228+
app_metadata: vec![0xFF, 0xFF, 0xFF, 0xFF].into(), // Invalid protobuf data
229+
data_body: vec![4, 5, 6].into(),
230+
};
231+
232+
let error_stream = stream::iter(vec![Ok(flight_data_with_invalid_metadata)]);
233+
let mut collecting_stream =
234+
MetricsCollectingStream::new(error_stream, Arc::new(DashMap::new()));
235+
236+
let result = collecting_stream.next().await.unwrap();
237+
assert_protocol_error(result, "failed to decode app_metadata");
238+
}
239+
240+
#[tokio::test]
241+
async fn test_metrics_collecting_stream_error_propagation() {
242+
let metrics_collection = Arc::new(DashMap::new());
243+
244+
// Create a stream that emits an error - should be propagated through
245+
let stream_error = FlightError::ProtocolError("stream error from inner stream".to_string());
246+
let error_stream = stream::iter(vec![Err(stream_error)]);
247+
let mut collecting_stream = MetricsCollectingStream::new(error_stream, metrics_collection);
248+
249+
let result = collecting_stream.next().await.unwrap();
250+
assert_protocol_error(result, "stream error from inner stream");
251+
}
252+
}

src/execution_plans/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
mod metrics;
2+
mod metrics_collecting_stream;
23
mod network_coalesce;
34
mod network_shuffle;
45
mod partition_isolator;

0 commit comments

Comments
 (0)