Skip to content

Commit 30f44db

Browse files
execution_plans: add MetricsCollectingStream
This change adds a new type called `MetricsCollectingStream`. It wraps a stream of `FlightData` and collects any metrics that are passed in the `app_metadata`, clearing that field before forwarding each message. This change also introduces a `FlightAppMetadata` proto message (wrapping the `AppMetadata` oneof enum), which can be used to define our app metadata protocol.
1 parent e428e7f commit 30f44db

File tree

4 files changed

+354
-0
lines changed

4 files changed

+354
-0
lines changed
Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
use std::pin::Pin;
2+
use std::task::{Context, Poll};
3+
4+
use crate::metrics::proto::MetricsSetProto;
5+
use crate::protobuf::StageKey;
6+
use crate::protobuf::{AppMetadata, FlightAppMetadata};
7+
use arrow_flight::{error::FlightError, FlightData};
8+
use dashmap::DashMap;
9+
use futures::stream::Stream;
10+
use prost::Message;
11+
use std::sync::Arc;
12+
13+
/// MetricsCollectingStream wraps a FlightData stream and extracts metrics from app_metadata
14+
/// while passing through all the other FlightData unchanged.
15+
pub struct MetricsCollectingStream<S>
16+
where
17+
S: Stream<Item = Result<FlightData, FlightError>> + Send + Unpin,
18+
{
19+
inner: S,
20+
metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>,
21+
}
22+
23+
impl<S> MetricsCollectingStream<S>
24+
where
25+
S: Stream<Item = Result<FlightData, FlightError>> + Send + Unpin,
26+
{
27+
#[allow(dead_code)]
28+
pub fn new(
29+
stream: S,
30+
metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>,
31+
) -> Self {
32+
Self {
33+
inner: stream,
34+
metrics_collection,
35+
}
36+
}
37+
38+
fn extract_metrics_from_flight_data(
39+
&self,
40+
flight_data: &mut FlightData,
41+
) -> Result<(), FlightError> {
42+
if !flight_data.app_metadata.is_empty() {
43+
return match FlightAppMetadata::decode(flight_data.app_metadata.as_ref()) {
44+
Ok(metadata) => {
45+
if let Some(content) = metadata.content {
46+
match content {
47+
AppMetadata::MetricsCollection(task_metrics_set) => {
48+
for task_metrics in task_metrics_set.tasks {
49+
if let Some(stage_key) = task_metrics.stage_key {
50+
self.metrics_collection
51+
.insert(stage_key, task_metrics.metrics);
52+
} else {
53+
return Err(FlightError::ProtocolError("expected Some StageKey in MetricsCollectingStream, got None".to_string()));
54+
}
55+
}
56+
}
57+
}
58+
}
59+
flight_data.app_metadata.clear();
60+
Ok(())
61+
}
62+
Err(e) => Err(FlightError::ProtocolError(format!(
63+
"failed to decode app_metadata: {}",
64+
e
65+
))),
66+
};
67+
}
68+
Ok(())
69+
}
70+
}
71+
72+
impl<S> Stream for MetricsCollectingStream<S>
73+
where
74+
S: Stream<Item = Result<FlightData, FlightError>> + Send + Unpin,
75+
{
76+
type Item = Result<FlightData, FlightError>;
77+
78+
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
79+
match Pin::new(&mut self.inner).poll_next(cx) {
80+
Poll::Ready(Some(Ok(mut flight_data))) => {
81+
// Extract metrics from app_metadata if present.
82+
match self.extract_metrics_from_flight_data(&mut flight_data) {
83+
Ok(_) => Poll::Ready(Some(Ok(flight_data))),
84+
Err(e) => Poll::Ready(Some(Err(e))),
85+
}
86+
}
87+
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
88+
Poll::Ready(None) => Poll::Ready(None),
89+
Poll::Pending => Poll::Pending,
90+
}
91+
}
92+
}
93+
94+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::proto::{
        MetricProto, MetricValueProto, MetricsSetProto, NamedCount, NamedGauge,
    };
    use crate::protobuf::{
        AppMetadata, FlightAppMetadata, MetricsCollection, StageKey, TaskMetrics,
    };
    use arrow_flight::FlightData;
    use futures::stream::{self, StreamExt};
    use prost::{bytes::Bytes, Message};

    /// The `test_count` = 42 counter metric shared by the fixtures.
    fn count_metric() -> MetricProto {
        MetricProto {
            metric: Some(MetricValueProto::Count(NamedCount {
                name: "test_count".to_string(),
                value: 42,
            })),
            labels: vec![],
            partition: Some(0),
        }
    }

    /// The `test_gauge` = 99 gauge metric shared by the fixtures.
    fn gauge_metric() -> MetricProto {
        MetricProto {
            metric: Some(MetricValueProto::Gauge(NamedGauge {
                name: "test_gauge".to_string(),
                value: 99,
            })),
            labels: vec![],
            partition: Some(0),
        }
    }

    /// Encodes a metrics-collection app metadata message holding the single `task`.
    fn encode_task_metrics(task: TaskMetrics) -> Bytes {
        FlightAppMetadata {
            content: Some(AppMetadata::MetricsCollection(MetricsCollection {
                tasks: vec![task],
            })),
        }
        .encode_to_vec()
        .into()
    }

    /// Builds a `FlightData` with the given metadata and body and nothing else set.
    fn flight_data(app_metadata: Bytes, body: Vec<u8>) -> FlightData {
        FlightData {
            flight_descriptor: None,
            data_header: Bytes::new(),
            app_metadata,
            data_body: body.into(),
        }
    }

    /// Asserts that `result` is a `ProtocolError` whose message contains `expected`;
    /// panics with `otherwise` for any other outcome.
    fn assert_protocol_error(
        result: Result<FlightData, FlightError>,
        expected: &str,
        otherwise: &str,
    ) {
        assert!(result.is_err());
        match result {
            Err(FlightError::ProtocolError(msg)) => assert!(msg.contains(expected)),
            _ => panic!("{}", otherwise),
        }
    }

    #[tokio::test]
    async fn test_metrics_collecting_stream_extracts_and_removes_metadata() {
        let test_metrics_set = MetricsSetProto {
            metrics: vec![count_metric(), gauge_metric()],
        };

        let stage_keys = vec![
            StageKey {
                query_id: "test_query".to_string(),
                stage_id: 1,
                task_number: 1,
            },
            StageKey {
                query_id: "test_query_2".to_string(),
                stage_id: 2,
                task_number: 2,
            },
        ];

        let encoded_metadata: Vec<Bytes> = stage_keys
            .iter()
            .map(|stage_key| {
                encode_task_metrics(TaskMetrics {
                    stage_key: Some(stage_key.clone()),
                    metrics: vec![test_metrics_set.clone()],
                })
            })
            .collect();

        // Three messages: with metadata, without metadata, with metadata.
        let input_stream = stream::iter(vec![
            Ok(flight_data(encoded_metadata[0].clone(), vec![1, 2, 3])),
            Ok(flight_data(Bytes::new(), vec![4, 5, 6])),
            Ok(flight_data(encoded_metadata[1].clone(), vec![7, 8, 9])),
        ]);

        let metrics_collection = Arc::new(DashMap::new());
        let collecting_stream =
            MetricsCollectingStream::new(input_stream, metrics_collection.clone());

        // Drain the stream; every message must come out with empty app_metadata.
        let collected_messages: Vec<FlightData> = collecting_stream
            .map(|result| result.unwrap())
            .collect()
            .await;
        assert_eq!(collected_messages.len(), 3);
        for msg in &collected_messages {
            assert!(
                msg.app_metadata.is_empty(),
                "app_metadata should be empty after collection"
            );
        }

        // The payloads must be forwarded untouched and in order.
        assert_eq!(collected_messages[0].data_body, vec![1, 2, 3]);
        assert_eq!(collected_messages[1].data_body, vec![4, 5, 6]);
        assert_eq!(collected_messages[2].data_body, vec![7, 8, 9]);

        // Both stage keys must have had their metrics recorded.
        assert_eq!(metrics_collection.len(), 2);
        for stage_key in &stage_keys {
            let collected_metrics = metrics_collection.get(stage_key).unwrap();
            assert_eq!(collected_metrics.len(), 1);

            let metrics = &collected_metrics[0].metrics;
            assert_eq!(metrics.len(), 2); // Count and Gauge

            match &metrics[0].metric {
                Some(MetricValueProto::Count(count)) => {
                    assert_eq!(count.name, "test_count");
                    assert_eq!(count.value, 42);
                }
                _ => panic!("expected Count metric"),
            }

            match &metrics[1].metric {
                Some(MetricValueProto::Gauge(gauge)) => {
                    assert_eq!(gauge.name, "test_gauge");
                    assert_eq!(gauge.value, 99);
                }
                _ => panic!("expected Gauge metric"),
            }
        }
    }

    #[tokio::test]
    async fn test_metrics_collecting_stream_error_missing_stage_key() {
        let metrics_collection = Arc::new(DashMap::new());

        let task_metrics_with_no_stage_key = TaskMetrics {
            stage_key: None,
            metrics: vec![MetricsSetProto {
                metrics: vec![count_metric()],
            }],
        };
        let invalid_flight_data = flight_data(
            encode_task_metrics(task_metrics_with_no_stage_key),
            vec![1, 2, 3],
        );

        let error_stream = stream::iter(vec![Ok(invalid_flight_data)]);
        let mut collecting_stream = MetricsCollectingStream::new(error_stream, metrics_collection);

        let result = collecting_stream.next().await.unwrap();
        assert_protocol_error(
            result,
            "expected Some StageKey in MetricsCollectingStream, got None",
            "expected FlightError::ProtocolError with stage key error",
        );
    }

    #[tokio::test]
    async fn test_metrics_collecting_stream_error_invalid_metadata() {
        let metrics_collection = Arc::new(DashMap::new());

        // Four 0xFF bytes are not a valid protobuf encoding.
        let flight_data_with_invalid_metadata =
            flight_data(Bytes::from(vec![0xFF, 0xFF, 0xFF, 0xFF]), vec![4, 5, 6]);

        let error_stream = stream::iter(vec![Ok(flight_data_with_invalid_metadata)]);
        let mut collecting_stream = MetricsCollectingStream::new(error_stream, metrics_collection);

        let result = collecting_stream.next().await.unwrap();
        assert_protocol_error(
            result,
            "failed to decode app_metadata",
            "expected FlightError::ProtocolError with decode error",
        );
    }

    #[tokio::test]
    async fn test_metrics_collecting_stream_error_propagation() {
        let metrics_collection = Arc::new(DashMap::new());

        // An error emitted by the inner stream must surface unchanged.
        let stream_error = FlightError::ProtocolError("stream error from inner stream".to_string());
        let error_stream = stream::iter(vec![Err(stream_error)]);
        let mut collecting_stream = MetricsCollectingStream::new(error_stream, metrics_collection);

        let result = collecting_stream.next().await.unwrap();
        assert_protocol_error(
            result,
            "stream error from inner stream",
            "expected FlightError::ProtocolError with inner stream error",
        );
    }
}

src/execution_plans/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
mod arrow_flight_read;
22
mod metrics;
3+
mod metrics_collecting_stream;
34
mod partition_isolator;
45
mod stage;
56

src/protobuf/app_metadata.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
use crate::metrics::proto::MetricsSetProto;
2+
use crate::protobuf::StageKey;
3+
4+
/// A collection of metrics for a set of tasks in an ExecutionPlan. Each
5+
/// entry should have a distinct [StageKey].
///
/// This is the payload carried by the [AppMetadata::MetricsCollection] variant.
6+
#[derive(Clone, PartialEq, ::prost::Message)]
7+
pub struct MetricsCollection {
8+
/// The per-task metrics; one [TaskMetrics] entry per task.
#[prost(message, repeated, tag = "1")]
9+
pub tasks: Vec<TaskMetrics>,
10+
}
11+
12+
/// TaskMetrics represents the metrics for a single task. It is the element type of
/// [MetricsCollection::tasks].
13+
#[derive(Clone, PartialEq, ::prost::Message)]
14+
pub struct TaskMetrics {
15+
/// stage_key uniquely identifies this task.
16+
///
17+
/// This field is always present. It's marked optional due to protobuf rules
/// (message-typed fields are decoded as `Option`).
18+
#[prost(message, optional, tag = "1")]
19+
pub stage_key: Option<StageKey>,
20+
/// metrics[i] is the set of metrics for plan node `i`, where plan nodes are in pre-order
21+
/// traversal order.
22+
#[prost(message, repeated, tag = "2")]
23+
pub metrics: Vec<MetricsSetProto>,
24+
}
25+
26+
/// FlightAppMetadata represents all types of app_metadata which we use in the distributed execution.
///
/// It is a thin wrapper message around the [AppMetadata] oneof.
27+
#[derive(Clone, PartialEq, ::prost::Message)]
28+
pub struct FlightAppMetadata {
29+
/// The concrete kind of metadata carried, if any. The `tags` list in the attribute
/// below must contain the tag of every [AppMetadata] variant.
#[prost(oneof = "AppMetadata", tags = "1")]
30+
pub content: Option<AppMetadata>,
31+
}
32+
33+
/// The kinds of application metadata that can be carried in a [FlightAppMetadata] message.
#[derive(Clone, PartialEq, ::prost::Oneof)]
34+
pub enum AppMetadata {
35+
/// Metrics collected for a set of tasks.
#[prost(message, tag = "1")]
36+
MetricsCollection(MetricsCollection),
37+
// Note: For every additional enum variant, make sure to add its tag to [FlightAppMetadata], e.g. `#[prost(oneof = "AppMetadata", tags = "1,2,3")]`.
38+
// If you don't, the proto will compile but you may encounter errors during serialization/deserialization.
39+
}

src/protobuf/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
mod app_metadata;
12
mod distributed_codec;
23
mod stage_proto;
34
mod user_codec;
45

6+
#[allow(unused_imports)]
7+
pub(crate) use app_metadata::{AppMetadata, FlightAppMetadata, MetricsCollection, TaskMetrics};
58
pub(crate) use distributed_codec::DistributedCodec;
69
pub(crate) use stage_proto::{proto_from_stage, stage_from_proto, StageExecProto, StageKey};
710
pub(crate) use user_codec::{get_distributed_user_codec, set_distributed_user_codec};

0 commit comments

Comments
 (0)