|
| 1 | +use std::pin::Pin; |
| 2 | +use std::task::{Context, Poll}; |
| 3 | + |
| 4 | +use crate::metrics::proto::MetricsSetProto; |
| 5 | +use crate::protobuf::StageKey; |
| 6 | +use crate::protobuf::{AppMetadata, FlightAppMetadata}; |
| 7 | +use arrow_flight::{FlightData, error::FlightError}; |
| 8 | +use dashmap::DashMap; |
| 9 | +use futures::stream::Stream; |
| 10 | +use pin_project::pin_project; |
| 11 | +use prost::Message; |
| 12 | +use std::sync::Arc; |
| 13 | + |
| 14 | +/// MetricsCollectingStream wraps a FlightData stream and extracts metrics from app_metadata |
| 15 | +/// while passing through all the other FlightData unchanged. |
| 16 | +#[pin_project] |
| 17 | +pub struct MetricsCollectingStream<S> |
| 18 | +where |
| 19 | + S: Stream<Item = Result<FlightData, FlightError>> + Send, |
| 20 | +{ |
| 21 | + #[pin] |
| 22 | + inner: S, |
| 23 | + metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>, |
| 24 | +} |
| 25 | + |
| 26 | +impl<S> MetricsCollectingStream<S> |
| 27 | +where |
| 28 | + S: Stream<Item = Result<FlightData, FlightError>> + Send, |
| 29 | +{ |
| 30 | + #[allow(dead_code)] |
| 31 | + pub fn new( |
| 32 | + stream: S, |
| 33 | + metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>, |
| 34 | + ) -> Self { |
| 35 | + Self { |
| 36 | + inner: stream, |
| 37 | + metrics_collection, |
| 38 | + } |
| 39 | + } |
| 40 | + |
| 41 | + fn extract_metrics_from_flight_data( |
| 42 | + metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>, |
| 43 | + flight_data: &mut FlightData, |
| 44 | + ) -> Result<(), FlightError> { |
| 45 | + if flight_data.app_metadata.is_empty() { |
| 46 | + return Ok(()); |
| 47 | + } |
| 48 | + |
| 49 | + let metadata = |
| 50 | + FlightAppMetadata::decode(flight_data.app_metadata.as_ref()).map_err(|e| { |
| 51 | + FlightError::ProtocolError(format!("failed to decode app_metadata: {}", e)) |
| 52 | + })?; |
| 53 | + |
| 54 | + let Some(content) = metadata.content else { |
| 55 | + return Err(FlightError::ProtocolError( |
| 56 | + "expected Some content in app_metadata".to_string(), |
| 57 | + )); |
| 58 | + }; |
| 59 | + |
| 60 | + let AppMetadata::MetricsCollection(task_metrics_set) = content; |
| 61 | + |
| 62 | + for task_metrics in task_metrics_set.tasks { |
| 63 | + let Some(stage_key) = task_metrics.stage_key else { |
| 64 | + return Err(FlightError::ProtocolError( |
| 65 | + "expected Some StageKey in MetricsCollectingStream, got None".to_string(), |
| 66 | + )); |
| 67 | + }; |
| 68 | + metrics_collection.insert(stage_key, task_metrics.metrics); |
| 69 | + } |
| 70 | + |
| 71 | + flight_data.app_metadata.clear(); |
| 72 | + Ok(()) |
| 73 | + } |
| 74 | +} |
| 75 | + |
| 76 | +impl<S> Stream for MetricsCollectingStream<S> |
| 77 | +where |
| 78 | + S: Stream<Item = Result<FlightData, FlightError>> + Send, |
| 79 | +{ |
| 80 | + type Item = Result<FlightData, FlightError>; |
| 81 | + |
| 82 | + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
| 83 | + let this = self.project(); |
| 84 | + match this.inner.poll_next(cx) { |
| 85 | + Poll::Ready(Some(Ok(mut flight_data))) => { |
| 86 | + // Extract metrics from app_metadata if present. |
| 87 | + match Self::extract_metrics_from_flight_data( |
| 88 | + this.metrics_collection.clone(), |
| 89 | + &mut flight_data, |
| 90 | + ) { |
| 91 | + Ok(_) => Poll::Ready(Some(Ok(flight_data))), |
| 92 | + Err(e) => Poll::Ready(Some(Err(e))), |
| 93 | + } |
| 94 | + } |
| 95 | + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), |
| 96 | + Poll::Ready(None) => Poll::Ready(None), |
| 97 | + Poll::Pending => Poll::Pending, |
| 98 | + } |
| 99 | + } |
| 100 | +} |
| 101 | + |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protobuf::{
        AppMetadata, FlightAppMetadata, MetricsCollection, StageKey, TaskMetrics,
    };
    use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed;
    use arrow_flight::FlightData;
    use futures::stream::{self, StreamExt};
    use prost::{Message, bytes::Bytes};

    /// Panics unless `result` is a `FlightError::ProtocolError` whose message contains
    /// `expected_msg`.
    fn assert_protocol_error(result: Result<FlightData, FlightError>, expected_msg: &str) {
        match result.unwrap_err() {
            FlightError::ProtocolError(msg) => assert!(msg.contains(expected_msg)),
            _ => panic!("expected FlightError::ProtocolError"),
        }
    }

    /// Builds a `FlightData` whose body is `data` and whose `app_metadata` is the encoded
    /// `metadata`, or empty when `metadata` is `None`.
    fn make_flight_data(data: Vec<u8>, metadata: Option<FlightAppMetadata>) -> FlightData {
        let app_metadata = metadata
            .map(|metadata| Bytes::from(metadata.encode_to_vec()))
            .unwrap_or_default();
        FlightData {
            flight_descriptor: None,
            data_header: Bytes::new(),
            app_metadata,
            data_body: Bytes::from(data),
        }
    }

    // Metrics are stored per task. This test builds metrics for two tasks, places them in the
    // app_metadata of FlightData messages (interleaved with a message that has no metadata),
    // streams everything through MetricsCollectingStream, and checks that the payloads come
    // out unchanged, the metadata is stripped, and the metrics land in the shared collection.
    #[tokio::test]
    async fn test_metrics_collecting_stream_extracts_and_removes_metadata() {
        let key_for_task = |task_number| StageKey {
            query_id: "test_query".to_string(),
            stage_id: 1,
            task_number,
        };
        let stage_keys = vec![key_for_task(1), key_for_task(2)];

        let mut app_metadatas = Vec::with_capacity(stage_keys.len());
        for key in &stage_keys {
            // The task number doubles as the seed of the test metrics set for convenience.
            let seeded_metrics = make_test_metrics_set_proto_from_seed(key.task_number);
            app_metadatas.push(FlightAppMetadata {
                content: Some(AppMetadata::MetricsCollection(MetricsCollection {
                    tasks: vec![TaskMetrics {
                        stage_key: Some(key.clone()),
                        metrics: vec![seeded_metrics],
                    }],
                })),
            });
        }

        // Input messages: the first and last carry metadata, the middle one does not.
        let flight_messages: Vec<Result<FlightData, FlightError>> = vec![
            Ok(make_flight_data(vec![1], Some(app_metadatas[0].clone()))),
            Ok(make_flight_data(vec![2], None)),
            Ok(make_flight_data(vec![3], Some(app_metadatas[1].clone()))),
        ];

        // Drain the stream, unwrapping each item.
        let metrics_collection = Arc::new(DashMap::new());
        let mut collecting_stream = MetricsCollectingStream::new(
            stream::iter(flight_messages),
            Arc::clone(&metrics_collection),
        );
        let mut collected_messages: Vec<FlightData> = Vec::new();
        while let Some(result) = collecting_stream.next().await {
            collected_messages.push(result.unwrap());
        }

        // Every message arrives, and none of them still carries app_metadata.
        assert_eq!(collected_messages.len(), 3);
        assert!(
            collected_messages
                .iter()
                .all(|msg| msg.app_metadata.is_empty())
        );

        // The payloads are untouched and in order.
        let expected_bodies = [vec![1u8], vec![2u8], vec![3u8]];
        for (message, expected_body) in collected_messages.iter().zip(expected_bodies) {
            assert_eq!(message.data_body, expected_body);
        }

        // Exactly one metrics set was recorded per stage key, matching its seed.
        assert_eq!(metrics_collection.len(), 2);
        for stage_key in &stage_keys {
            let collected_metrics = metrics_collection.get(stage_key).unwrap();
            assert_eq!(
                *collected_metrics,
                vec![make_test_metrics_set_proto_from_seed(stage_key.task_number)]
            );
        }
    }

    // A task without a stage key must surface as a protocol error on the stream.
    #[tokio::test]
    async fn test_metrics_collecting_stream_error_missing_stage_key() {
        let keyless_task = TaskMetrics {
            stage_key: None,
            metrics: vec![make_test_metrics_set_proto_from_seed(1)],
        };
        let invalid_app_metadata = FlightAppMetadata {
            content: Some(AppMetadata::MetricsCollection(MetricsCollection {
                tasks: vec![keyless_task],
            })),
        };
        let input = vec![Ok(make_flight_data(vec![1], Some(invalid_app_metadata)))];

        let mut collecting_stream =
            MetricsCollectingStream::new(stream::iter(input), Arc::new(DashMap::new()));

        let first_item = collecting_stream.next().await.unwrap();
        assert_protocol_error(
            first_item,
            "expected Some StageKey in MetricsCollectingStream, got None",
        );
    }

    // app_metadata bytes that are not valid protobuf must surface as a protocol error.
    #[tokio::test]
    async fn test_metrics_collecting_stream_error_decoding_metadata() {
        let undecodable_metadata: Vec<u8> = vec![0xFF, 0xFF, 0xFF, 0xFF];
        let malformed_message = FlightData {
            flight_descriptor: None,
            data_header: Bytes::new(),
            app_metadata: Bytes::from(undecodable_metadata),
            data_body: Bytes::from(vec![4u8, 5, 6]),
        };
        let input = vec![Ok(malformed_message)];

        let mut collecting_stream =
            MetricsCollectingStream::new(stream::iter(input), Arc::new(DashMap::new()));

        let first_item = collecting_stream.next().await.unwrap();
        assert_protocol_error(first_item, "failed to decode app_metadata");
    }

    // An error produced by the inner stream must reach the consumer unchanged.
    #[tokio::test]
    async fn test_metrics_collecting_stream_error_propagation() {
        let inner_error = FlightError::ProtocolError("stream error from inner stream".to_string());
        let input = vec![Err(inner_error)];

        let metrics_collection = Arc::new(DashMap::new());
        let mut collecting_stream =
            MetricsCollectingStream::new(stream::iter(input), metrics_collection);

        let first_item = collecting_stream.next().await.unwrap();
        assert_protocol_error(first_item, "stream error from inner stream");
    }
}
0 commit comments