Skip to content

Commit 147579f

Browse files
execution_plans: add MetricsCollectingStream
This change adds a new type called `MetricsCollectingStream`. It wraps a stream of `FlightData` and collects any metrics that are passed in the `app_metadata`. This change also introduces an `FlightAppMetadata` enum proto which can be used to define our app metadata protocol.
1 parent 63570d3 commit 147579f

File tree

10 files changed

+366
-63
lines changed

10 files changed

+366
-63
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ parquet = { version = "55.2.0", optional = true }
3939
arrow = { version = "55.2.0", optional = true }
4040
tokio-stream = { version = "0.1.17", optional = true }
4141
hyper-util = { version = "0.1.16", optional = true }
42+
pin-project = "1.1.10"
4243

4344
[features]
4445
integration = [

src/execution_plans/metrics.rs

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,6 @@ impl ExecutionPlan for MetricsWrapperExec {
272272

273273
#[cfg(test)]
274274
mod tests {
275-
use crate::metrics::proto::{
276-
ElapsedCompute, EndTimestamp, MetricProto, MetricValueProto, OutputRows, StartTimestamp,
277-
};
278275

279276
use super::*;
280277
use datafusion::arrow::array::{Int32Array, StringArray};
@@ -283,6 +280,7 @@ mod tests {
283280
use crate::DistributedExt;
284281
use crate::DistributedPhysicalOptimizerRule;
285282
use crate::test_utils::in_memory_channel_resolver::InMemoryChannelResolver;
283+
use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed;
286284
use crate::test_utils::session_context::register_temp_parquet_table;
287285
use datafusion::execution::{SessionStateBuilder, context::SessionContext};
288286
use datafusion::physical_plan::metrics::MetricValue;
@@ -367,47 +365,12 @@ mod tests {
367365
(stage_exec, ctx)
368366
}
369367

370-
/// creates a "distinct" set of metrics from the provided seed
371-
fn make_distinct_metrics_set(seed: u64) -> MetricsSetProto {
372-
const TEST_TIMESTAMP: i64 = 1758200400000000000; // 2025-09-18 13:00:00 UTC
373-
MetricsSetProto {
374-
metrics: vec![
375-
MetricProto {
376-
metric: Some(MetricValueProto::OutputRows(OutputRows { value: seed })),
377-
labels: vec![],
378-
partition: None,
379-
},
380-
MetricProto {
381-
metric: Some(MetricValueProto::ElapsedCompute(ElapsedCompute {
382-
value: seed,
383-
})),
384-
labels: vec![],
385-
partition: None,
386-
},
387-
MetricProto {
388-
metric: Some(MetricValueProto::StartTimestamp(StartTimestamp {
389-
value: Some(TEST_TIMESTAMP + (seed as i64 * 1_000_000_000)),
390-
})),
391-
labels: vec![],
392-
partition: None,
393-
},
394-
MetricProto {
395-
metric: Some(MetricValueProto::EndTimestamp(EndTimestamp {
396-
value: Some(TEST_TIMESTAMP + ((seed as i64 + 1) * 1_000_000_000)),
397-
})),
398-
labels: vec![],
399-
partition: None,
400-
},
401-
],
402-
}
403-
}
404-
405368
#[tokio::test]
406369
#[ignore]
407370
async fn test_metrics_rewriter() {
408371
let (test_stage, _ctx) = make_test_stage_exec_with_5_nodes().await;
409372
let test_metrics_sets = (0..5) // 5 nodes excluding NetworkShuffleExec
410-
.map(|i| make_distinct_metrics_set(i + 10))
373+
.map(|i| make_test_metrics_set_proto_from_seed(i + 10))
411374
.collect::<Vec<MetricsSetProto>>();
412375

413376
let rewriter = TaskMetricsRewriter::new(test_metrics_sets.clone());
@@ -433,7 +396,7 @@ mod tests {
433396
#[tokio::test]
434397
#[ignore]
435398
async fn test_metrics_rewriter_correct_number_of_metrics() {
436-
let test_metrics_set = make_distinct_metrics_set(10);
399+
let test_metrics_set = make_test_metrics_set_proto_from_seed(10);
437400
let (executable_plan, _ctx) = make_test_stage_exec_with_5_nodes().await;
438401
let task_plan = executable_plan
439402
.as_any()
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
use std::pin::Pin;
2+
use std::task::{Context, Poll};
3+
4+
use crate::metrics::proto::MetricsSetProto;
5+
use crate::protobuf::StageKey;
6+
use crate::protobuf::{AppMetadata, FlightAppMetadata};
7+
use arrow_flight::{FlightData, error::FlightError};
8+
use dashmap::DashMap;
9+
use futures::stream::Stream;
10+
use pin_project::pin_project;
11+
use prost::Message;
12+
use std::sync::Arc;
13+
14+
/// MetricsCollectingStream wraps a FlightData stream and extracts metrics from app_metadata
15+
/// while passing through all the other FlightData unchanged.
16+
#[pin_project]
17+
pub struct MetricsCollectingStream<S>
18+
where
19+
S: Stream<Item = Result<FlightData, FlightError>> + Send,
20+
{
21+
#[pin]
22+
inner: S,
23+
metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>,
24+
}
25+
26+
impl<S> MetricsCollectingStream<S>
27+
where
28+
S: Stream<Item = Result<FlightData, FlightError>> + Send,
29+
{
30+
#[allow(dead_code)]
31+
pub fn new(
32+
stream: S,
33+
metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>,
34+
) -> Self {
35+
Self {
36+
inner: stream,
37+
metrics_collection,
38+
}
39+
}
40+
41+
fn extract_metrics_from_flight_data(
42+
metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>,
43+
flight_data: &mut FlightData,
44+
) -> Result<(), FlightError> {
45+
if flight_data.app_metadata.is_empty() {
46+
return Ok(());
47+
}
48+
49+
let metadata =
50+
FlightAppMetadata::decode(flight_data.app_metadata.as_ref()).map_err(|e| {
51+
FlightError::ProtocolError(format!("failed to decode app_metadata: {}", e))
52+
})?;
53+
54+
let Some(content) = metadata.content else {
55+
return Err(FlightError::ProtocolError(
56+
"expected Some content in app_metadata".to_string(),
57+
));
58+
};
59+
60+
let AppMetadata::MetricsCollection(task_metrics_set) = content;
61+
62+
for task_metrics in task_metrics_set.tasks {
63+
let Some(stage_key) = task_metrics.stage_key else {
64+
return Err(FlightError::ProtocolError(
65+
"expected Some StageKey in MetricsCollectingStream, got None".to_string(),
66+
));
67+
};
68+
metrics_collection.insert(stage_key, task_metrics.metrics);
69+
}
70+
71+
flight_data.app_metadata.clear();
72+
Ok(())
73+
}
74+
}
75+
76+
impl<S> Stream for MetricsCollectingStream<S>
77+
where
78+
S: Stream<Item = Result<FlightData, FlightError>> + Send,
79+
{
80+
type Item = Result<FlightData, FlightError>;
81+
82+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
83+
let this = self.project();
84+
match this.inner.poll_next(cx) {
85+
Poll::Ready(Some(Ok(mut flight_data))) => {
86+
// Extract metrics from app_metadata if present.
87+
match Self::extract_metrics_from_flight_data(
88+
this.metrics_collection.clone(),
89+
&mut flight_data,
90+
) {
91+
Ok(_) => Poll::Ready(Some(Ok(flight_data))),
92+
Err(e) => Poll::Ready(Some(Err(e))),
93+
}
94+
}
95+
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
96+
Poll::Ready(None) => Poll::Ready(None),
97+
Poll::Pending => Poll::Pending,
98+
}
99+
}
100+
}
101+
102+
#[cfg(test)]
103+
mod tests {
104+
use super::*;
105+
use crate::protobuf::{
106+
AppMetadata, FlightAppMetadata, MetricsCollection, StageKey, TaskMetrics,
107+
};
108+
use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed;
109+
use arrow_flight::FlightData;
110+
use futures::stream::{self, StreamExt};
111+
use prost::{Message, bytes::Bytes};
112+
113+
fn assert_protocol_error(result: Result<FlightData, FlightError>, expected_msg: &str) {
114+
let FlightError::ProtocolError(msg) = result.unwrap_err() else {
115+
panic!("expected FlightError::ProtocolError");
116+
};
117+
assert!(msg.contains(expected_msg));
118+
}
119+
120+
fn make_flight_data(data: Vec<u8>, metadata: Option<FlightAppMetadata>) -> FlightData {
121+
let metadata_bytes = match metadata {
122+
Some(metadata) => metadata.encode_to_vec().into(),
123+
None => Bytes::new(),
124+
};
125+
FlightData {
126+
flight_descriptor: None,
127+
data_header: Bytes::new(),
128+
app_metadata: metadata_bytes,
129+
data_body: data.into(),
130+
}
131+
}
132+
133+
// Tests that metrics are extracted from FlightData. Metrics are always stored per task, so this
134+
// tests creates some random metrics and tasks and puts them in the app_metadata field of
135+
// FlightData message. Then, it streams these messages through the MetricsCollectingStream
136+
// and asserts that the metrics are collected correctly.
137+
#[tokio::test]
138+
async fn test_metrics_collecting_stream_extracts_and_removes_metadata() {
139+
let stage_keys = vec![
140+
StageKey {
141+
query_id: "test_query".to_string(),
142+
stage_id: 1,
143+
task_number: 1,
144+
},
145+
StageKey {
146+
query_id: "test_query".to_string(),
147+
stage_id: 1,
148+
task_number: 2,
149+
},
150+
];
151+
152+
let app_metadatas = stage_keys
153+
.iter()
154+
.map(|stage_key| FlightAppMetadata {
155+
content: Some(AppMetadata::MetricsCollection(MetricsCollection {
156+
tasks: vec![TaskMetrics {
157+
stage_key: Some(stage_key.clone()),
158+
// use the task number to seed the test metrics set for convenience
159+
metrics: vec![make_test_metrics_set_proto_from_seed(stage_key.task_number)],
160+
}],
161+
})),
162+
})
163+
.collect::<Vec<_>>();
164+
165+
// Create test FlightData messages - some with metadata, some without
166+
let flight_messages = vec![
167+
make_flight_data(vec![1], Some(app_metadatas[0].clone())),
168+
make_flight_data(vec![2], None),
169+
make_flight_data(vec![3], Some(app_metadatas[1].clone())),
170+
]
171+
.into_iter()
172+
.map(Ok);
173+
174+
// Collect all messages from the stream. All should have empty app_metadata.
175+
let metrics_collection = Arc::new(DashMap::new());
176+
let input_stream = stream::iter(flight_messages);
177+
let collecting_stream =
178+
MetricsCollectingStream::new(input_stream, metrics_collection.clone());
179+
let collected_messages: Vec<FlightData> = collecting_stream
180+
.map(|result| result.unwrap())
181+
.collect()
182+
.await;
183+
184+
// Assert the data is unchanged and app_metadata is cleared
185+
assert_eq!(collected_messages.len(), 3);
186+
assert!(
187+
collected_messages
188+
.iter()
189+
.all(|msg| msg.app_metadata.is_empty())
190+
);
191+
192+
// Verify the data in the messages.
193+
assert_eq!(collected_messages[0].data_body, vec![1]);
194+
assert_eq!(collected_messages[1].data_body, vec![2]);
195+
assert_eq!(collected_messages[2].data_body, vec![3]);
196+
197+
// Verify the correct metrics were collected
198+
assert_eq!(metrics_collection.len(), 2);
199+
for stage_key in stage_keys {
200+
let collected_metrics = metrics_collection.get(&stage_key).unwrap();
201+
assert_eq!(collected_metrics.len(), 1);
202+
assert_eq!(
203+
collected_metrics[0],
204+
make_test_metrics_set_proto_from_seed(stage_key.task_number)
205+
);
206+
}
207+
}
208+
209+
#[tokio::test]
210+
async fn test_metrics_collecting_stream_error_missing_stage_key() {
211+
let task_metrics_with_no_stage_key = TaskMetrics {
212+
stage_key: None,
213+
metrics: vec![make_test_metrics_set_proto_from_seed(1)],
214+
};
215+
216+
let invalid_app_metadata = FlightAppMetadata {
217+
content: Some(AppMetadata::MetricsCollection(MetricsCollection {
218+
tasks: vec![task_metrics_with_no_stage_key],
219+
})),
220+
};
221+
222+
let invalid_flight_data = make_flight_data(vec![1], Some(invalid_app_metadata));
223+
224+
let error_stream = stream::iter(vec![Ok(invalid_flight_data)]);
225+
let mut collecting_stream =
226+
MetricsCollectingStream::new(error_stream, Arc::new(DashMap::new()));
227+
228+
let result = collecting_stream.next().await.unwrap();
229+
assert_protocol_error(
230+
result,
231+
"expected Some StageKey in MetricsCollectingStream, got None",
232+
);
233+
}
234+
235+
#[tokio::test]
236+
async fn test_metrics_collecting_stream_error_decoding_metadata() {
237+
let flight_data_with_invalid_metadata = FlightData {
238+
flight_descriptor: None,
239+
data_header: Bytes::new(),
240+
app_metadata: vec![0xFF, 0xFF, 0xFF, 0xFF].into(), // Invalid protobuf data
241+
data_body: vec![4, 5, 6].into(),
242+
};
243+
244+
let error_stream = stream::iter(vec![Ok(flight_data_with_invalid_metadata)]);
245+
let mut collecting_stream =
246+
MetricsCollectingStream::new(error_stream, Arc::new(DashMap::new()));
247+
248+
let result = collecting_stream.next().await.unwrap();
249+
assert_protocol_error(result, "failed to decode app_metadata");
250+
}
251+
252+
#[tokio::test]
253+
async fn test_metrics_collecting_stream_error_propagation() {
254+
let metrics_collection = Arc::new(DashMap::new());
255+
256+
// Create a stream that emits an error - should be propagated through
257+
let stream_error = FlightError::ProtocolError("stream error from inner stream".to_string());
258+
let error_stream = stream::iter(vec![Err(stream_error)]);
259+
let mut collecting_stream = MetricsCollectingStream::new(error_stream, metrics_collection);
260+
261+
let result = collecting_stream.next().await.unwrap();
262+
assert_protocol_error(result, "stream error from inner stream");
263+
}
264+
}

src/execution_plans/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
mod metrics;
2+
mod metrics_collecting_stream;
23
mod network_coalesce;
34
mod network_shuffle;
45
mod partition_isolator;

0 commit comments

Comments
 (0)