diff --git a/Cargo.lock b/Cargo.lock index 675beb6..00881dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1123,6 +1123,7 @@ dependencies = [ "itertools", "object_store", "parquet", + "pin-project", "prost", "rand 0.8.5", "structopt", diff --git a/Cargo.toml b/Cargo.toml index f05cbb0..0381b34 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ parquet = { version = "55.2.0", optional = true } arrow = { version = "55.2.0", optional = true } tokio-stream = { version = "0.1.17", optional = true } hyper-util = { version = "0.1.16", optional = true } +pin-project = "1.1.10" [features] integration = [ diff --git a/src/execution_plans/metrics.rs b/src/execution_plans/metrics.rs index f772881..4a1ff59 100644 --- a/src/execution_plans/metrics.rs +++ b/src/execution_plans/metrics.rs @@ -272,9 +272,6 @@ impl ExecutionPlan for MetricsWrapperExec { #[cfg(test)] mod tests { - use crate::metrics::proto::{ - ElapsedCompute, EndTimestamp, MetricProto, MetricValueProto, OutputRows, StartTimestamp, - }; use super::*; use datafusion::arrow::array::{Int32Array, StringArray}; @@ -283,6 +280,7 @@ mod tests { use crate::DistributedExt; use crate::DistributedPhysicalOptimizerRule; use crate::test_utils::in_memory_channel_resolver::InMemoryChannelResolver; + use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed; use crate::test_utils::session_context::register_temp_parquet_table; use datafusion::execution::{SessionStateBuilder, context::SessionContext}; use datafusion::physical_plan::metrics::MetricValue; @@ -367,47 +365,12 @@ mod tests { (stage_exec, ctx) } - /// creates a "distinct" set of metrics from the provided seed - fn make_distinct_metrics_set(seed: u64) -> MetricsSetProto { - const TEST_TIMESTAMP: i64 = 1758200400000000000; // 2025-09-18 13:00:00 UTC - MetricsSetProto { - metrics: vec![ - MetricProto { - metric: Some(MetricValueProto::OutputRows(OutputRows { value: seed })), - labels: vec![], - partition: None, - }, - MetricProto { - metric: Some(MetricValueProto::ElapsedCompute(ElapsedCompute { - value: seed, - })), - labels: vec![], - partition: None, - }, - MetricProto { - metric: Some(MetricValueProto::StartTimestamp(StartTimestamp { - value: Some(TEST_TIMESTAMP + (seed as i64 * 1_000_000_000)), - })), - labels: vec![], - partition: None, - }, - MetricProto { - metric: Some(MetricValueProto::EndTimestamp(EndTimestamp { - value: Some(TEST_TIMESTAMP + ((seed as i64 + 1) * 1_000_000_000)), - })), - labels: vec![], - partition: None, - }, - ], - } - } - #[tokio::test] #[ignore] async fn test_metrics_rewriter() { let (test_stage, _ctx) = make_test_stage_exec_with_5_nodes().await; let test_metrics_sets = (0..5) // 5 nodes excluding NetworkShuffleExec - .map(|i| make_distinct_metrics_set(i + 10)) + .map(|i| make_test_metrics_set_proto_from_seed(i + 10)) .collect::>(); let rewriter = TaskMetricsRewriter::new(test_metrics_sets.clone()); @@ -433,7 +396,7 @@ mod tests { #[tokio::test] #[ignore] async fn test_metrics_rewriter_correct_number_of_metrics() { - let test_metrics_set = make_distinct_metrics_set(10); + let test_metrics_set = make_test_metrics_set_proto_from_seed(10); let (executable_plan, _ctx) = make_test_stage_exec_with_5_nodes().await; let task_plan = executable_plan .as_any() diff --git a/src/execution_plans/metrics_collecting_stream.rs b/src/execution_plans/metrics_collecting_stream.rs new file mode 100644 index 0000000..1321f63 --- /dev/null +++ b/src/execution_plans/metrics_collecting_stream.rs @@ -0,0 +1,264 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::metrics::proto::MetricsSetProto; +use crate::protobuf::StageKey; +use crate::protobuf::{AppMetadata, FlightAppMetadata}; +use arrow_flight::{FlightData, error::FlightError}; +use dashmap::DashMap; +use futures::stream::Stream; +use pin_project::pin_project; +use prost::Message; +use std::sync::Arc; + +/// MetricsCollectingStream wraps a FlightData stream and extracts metrics from app_metadata +/// while passing through all the other FlightData unchanged. +#[pin_project] +pub struct MetricsCollectingStream +where + S: Stream> + Send, +{ + #[pin] + inner: S, + metrics_collection: Arc>>, +} + +impl MetricsCollectingStream +where + S: Stream> + Send, +{ + #[allow(dead_code)] + pub fn new( + stream: S, + metrics_collection: Arc>>, + ) -> Self { + Self { + inner: stream, + metrics_collection, + } + } + + fn extract_metrics_from_flight_data( + metrics_collection: Arc>>, + flight_data: &mut FlightData, + ) -> Result<(), FlightError> { + if flight_data.app_metadata.is_empty() { + return Ok(()); + } + + let metadata = + FlightAppMetadata::decode(flight_data.app_metadata.as_ref()).map_err(|e| { + FlightError::ProtocolError(format!("failed to decode app_metadata: {}", e)) + })?; + + let Some(content) = metadata.content else { + return Err(FlightError::ProtocolError( + "expected Some content in app_metadata".to_string(), + )); + }; + + let AppMetadata::MetricsCollection(task_metrics_set) = content; + + for task_metrics in task_metrics_set.tasks { + let Some(stage_key) = task_metrics.stage_key else { + return Err(FlightError::ProtocolError( + "expected Some StageKey in MetricsCollectingStream, got None".to_string(), + )); + }; + metrics_collection.insert(stage_key, task_metrics.metrics); + } + + flight_data.app_metadata.clear(); + Ok(()) + } +} + +impl Stream for MetricsCollectingStream +where + S: Stream> + Send, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + match this.inner.poll_next(cx) { + Poll::Ready(Some(Ok(mut flight_data))) => { + // Extract metrics from app_metadata if present. + match Self::extract_metrics_from_flight_data( + this.metrics_collection.clone(), + &mut flight_data, + ) { + Ok(_) => Poll::Ready(Some(Ok(flight_data))), + Err(e) => Poll::Ready(Some(Err(e))), + } + } + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protobuf::{ + AppMetadata, FlightAppMetadata, MetricsCollection, StageKey, TaskMetrics, + }; + use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed; + use arrow_flight::FlightData; + use futures::stream::{self, StreamExt}; + use prost::{Message, bytes::Bytes}; + + fn assert_protocol_error(result: Result, expected_msg: &str) { + let FlightError::ProtocolError(msg) = result.unwrap_err() else { + panic!("expected FlightError::ProtocolError"); + }; + assert!(msg.contains(expected_msg)); + } + + fn make_flight_data(data: Vec, metadata: Option) -> FlightData { + let metadata_bytes = match metadata { + Some(metadata) => metadata.encode_to_vec().into(), + None => Bytes::new(), + }; + FlightData { + flight_descriptor: None, + data_header: Bytes::new(), + app_metadata: metadata_bytes, + data_body: data.into(), + } + } + + // Tests that metrics are extracted from FlightData. Metrics are always stored per task, so this + // tests creates some random metrics and tasks and puts them in the app_metadata field of + // FlightData message. Then, it streams these messages through the MetricsCollectingStream + // and asserts that the metrics are collected correctly. + #[tokio::test] + async fn test_metrics_collecting_stream_extracts_and_removes_metadata() { + let stage_keys = vec![ + StageKey { + query_id: "test_query".to_string(), + stage_id: 1, + task_number: 1, + }, + StageKey { + query_id: "test_query".to_string(), + stage_id: 1, + task_number: 2, + }, + ]; + + let app_metadatas = stage_keys + .iter() + .map(|stage_key| FlightAppMetadata { + content: Some(AppMetadata::MetricsCollection(MetricsCollection { + tasks: vec![TaskMetrics { + stage_key: Some(stage_key.clone()), + // use the task number to seed the test metrics set for convenience + metrics: vec![make_test_metrics_set_proto_from_seed(stage_key.task_number)], + }], + })), + }) + .collect::>(); + + // Create test FlightData messages - some with metadata, some without + let flight_messages = vec![ + make_flight_data(vec![1], Some(app_metadatas[0].clone())), + make_flight_data(vec![2], None), + make_flight_data(vec![3], Some(app_metadatas[1].clone())), + ] + .into_iter() + .map(Ok); + + // Collect all messages from the stream. All should have empty app_metadata. + let metrics_collection = Arc::new(DashMap::new()); + let input_stream = stream::iter(flight_messages); + let collecting_stream = + MetricsCollectingStream::new(input_stream, metrics_collection.clone()); + let collected_messages: Vec = collecting_stream + .map(|result| result.unwrap()) + .collect() + .await; + + // Assert the data is unchanged and app_metadata is cleared + assert_eq!(collected_messages.len(), 3); + assert!( + collected_messages + .iter() + .all(|msg| msg.app_metadata.is_empty()) + ); + + // Verify the data in the messages. + assert_eq!(collected_messages[0].data_body, vec![1]); + assert_eq!(collected_messages[1].data_body, vec![2]); + assert_eq!(collected_messages[2].data_body, vec![3]); + + // Verify the correct metrics were collected + assert_eq!(metrics_collection.len(), 2); + for stage_key in stage_keys { + let collected_metrics = metrics_collection.get(&stage_key).unwrap(); + assert_eq!(collected_metrics.len(), 1); + assert_eq!( + collected_metrics[0], + make_test_metrics_set_proto_from_seed(stage_key.task_number) + ); + } + } + + #[tokio::test] + async fn test_metrics_collecting_stream_error_missing_stage_key() { + let task_metrics_with_no_stage_key = TaskMetrics { + stage_key: None, + metrics: vec![make_test_metrics_set_proto_from_seed(1)], + }; + + let invalid_app_metadata = FlightAppMetadata { + content: Some(AppMetadata::MetricsCollection(MetricsCollection { + tasks: vec![task_metrics_with_no_stage_key], + })), + }; + + let invalid_flight_data = make_flight_data(vec![1], Some(invalid_app_metadata)); + + let error_stream = stream::iter(vec![Ok(invalid_flight_data)]); + let mut collecting_stream = + MetricsCollectingStream::new(error_stream, Arc::new(DashMap::new())); + + let result = collecting_stream.next().await.unwrap(); + assert_protocol_error( + result, + "expected Some StageKey in MetricsCollectingStream, got None", + ); + } + + #[tokio::test] + async fn test_metrics_collecting_stream_error_decoding_metadata() { + let flight_data_with_invalid_metadata = FlightData { + flight_descriptor: None, + data_header: Bytes::new(), + app_metadata: vec![0xFF, 0xFF, 0xFF, 0xFF].into(), // Invalid protobuf data + data_body: vec![4, 5, 6].into(), + }; + + let error_stream = stream::iter(vec![Ok(flight_data_with_invalid_metadata)]); + let mut collecting_stream = + MetricsCollectingStream::new(error_stream, Arc::new(DashMap::new())); + + let result = collecting_stream.next().await.unwrap(); + assert_protocol_error(result, "failed to decode app_metadata"); + } + + #[tokio::test] + async fn test_metrics_collecting_stream_error_propagation() { + let metrics_collection = Arc::new(DashMap::new()); + + // Create a stream that emits an error - should be propagated through + let stream_error = FlightError::ProtocolError("stream error from inner stream".to_string()); + let error_stream = stream::iter(vec![Err(stream_error)]); + let mut collecting_stream = MetricsCollectingStream::new(error_stream, metrics_collection); + + let result = collecting_stream.next().await.unwrap(); + assert_protocol_error(result, "stream error from inner stream"); + } +} diff --git a/src/execution_plans/mod.rs b/src/execution_plans/mod.rs index 0b36606..37e13e3 100644 --- a/src/execution_plans/mod.rs +++ b/src/execution_plans/mod.rs @@ -1,4 +1,5 @@ mod metrics; +mod metrics_collecting_stream; mod network_coalesce; mod network_shuffle; mod partition_isolator; diff --git a/src/metrics/proto.rs b/src/metrics/proto.rs index df086ed..4222d53 100644 --- a/src/metrics/proto.rs +++ b/src/metrics/proto.rs @@ -27,7 +27,7 @@ pub struct MetricsSetProto { } /// MetricValueProto is a protobuf mirror of the [datafusion::physical_plan::metrics::MetricValue] enum. -#[derive(Clone, PartialEq, ::prost::Oneof)] +#[derive(Clone, PartialEq, Eq, ::prost::Oneof)] pub enum MetricValueProto { #[prost(message, tag = "1")] OutputRows(OutputRows), @@ -53,43 +53,43 @@ pub enum MetricValueProto { EndTimestamp(EndTimestamp), } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, ::prost::Message)] pub struct OutputRows { #[prost(uint64, tag = "1")] pub value: u64, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, ::prost::Message)] pub struct ElapsedCompute { #[prost(uint64, tag = "1")] pub value: u64, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, ::prost::Message)] pub struct SpillCount { #[prost(uint64, tag = "1")] pub value: u64, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, ::prost::Message)] pub struct SpilledBytes { #[prost(uint64, tag = "1")] pub value: u64, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, ::prost::Message)] pub struct SpilledRows { #[prost(uint64, tag = "1")] pub value: u64, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, ::prost::Message)] pub struct CurrentMemoryUsage { #[prost(uint64, tag = "1")] pub value: u64, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, ::prost::Message)] pub struct NamedCount { #[prost(string, tag = "1")] pub name: String, @@ -97,7 +97,7 @@ pub struct NamedCount { pub value: u64, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, ::prost::Message)] pub struct NamedGauge { #[prost(string, tag = "1")] pub name: String, @@ -105,7 +105,7 @@ pub struct NamedGauge { pub value: u64, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, ::prost::Message)] pub struct NamedTime { #[prost(string, tag = "1")] pub name: String, @@ -113,20 +113,20 @@ pub struct NamedTime { pub value: u64, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, ::prost::Message)] pub struct StartTimestamp { #[prost(int64, optional, tag = "1")] pub value: Option, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, ::prost::Message)] pub struct EndTimestamp { #[prost(int64, optional, tag = "1")] pub value: Option, } /// A ProtoLabel mirrors [datafusion::physical_plan::metrics::Label]. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, ::prost::Message)] pub struct ProtoLabel { #[prost(string, tag = "1")] pub name: String, @@ -390,11 +390,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - None => Ok(Arc::new(Metric::new_with_labels( - MetricValue::StartTimestamp(Timestamp::new()), - partition, - labels, - ))), + None => internal_err!("encountered invalid start timestamp metric with no value"), }, Some(MetricValueProto::EndTimestamp(end_ts)) => match end_ts.value { Some(value) => { @@ -406,11 +402,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - None => Ok(Arc::new(Metric::new_with_labels( - MetricValue::EndTimestamp(Timestamp::new()), - partition, - labels, - ))), + None => internal_err!("encountered invalid end timestamp metric with no value"), }, None => internal_err!("proto metric is missing the metric field"), } @@ -419,9 +411,11 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion #[cfg(test)] mod tests { use super::*; + use datafusion::physical_plan::metrics::CustomMetricValue; use datafusion::physical_plan::metrics::{Count, Gauge, Label, MetricsSet, Time, Timestamp}; use datafusion::physical_plan::metrics::{Metric, MetricValue}; use std::borrow::Cow; + use std::sync::Arc; fn test_roundtrip_helper(metrics_set: MetricsSet, test_name: &str) { // Serialize and deserialize the metrics set. @@ -783,10 +777,6 @@ mod tests { #[test] fn test_custom_metrics_filtering() { - use datafusion::physical_plan::metrics::{Count, CustomMetricValue, MetricsSet}; - use std::sync::Arc; - - // Create a simple custom metric value implementation for testing #[derive(Debug)] struct TestCustomMetric; @@ -845,9 +835,6 @@ mod tests { #[test] fn test_unrepresentable_timestamp_error() { - use datafusion::physical_plan::metrics::{MetricsSet, Timestamp}; - use std::sync::Arc; - // Use a timestamp that is beyond the range that timestamp_nanos_opt() can handle. let mut metrics_set = MetricsSet::new(); let timestamp = Timestamp::new(); @@ -863,16 +850,21 @@ mod tests { proto_result.is_err(), "should return error for unrepresentable timestamp" ); + } - let error = proto_result.unwrap_err(); - if let DataFusionError::Internal(msg) = error { - assert!( - msg.contains("cannot be represented via a nanosecond timestamp"), - "should be timestamp conversion error, got: {}", - msg - ); - } else { - panic!("expected internal error, got: {:?}", error); - } + #[test] + fn test_invalid_proto_timestamp_error() { + // Create a MetricProto with EndTimestamp that has no value (None) + let invalid_end_timestamp_proto = MetricProto { + metric: Some(MetricValueProto::EndTimestamp(EndTimestamp { value: None })), + labels: vec![], + partition: Some(0), + }; + + let result = metric_proto_to_df(invalid_end_timestamp_proto); + assert!( + result.is_err(), + "should return error for invalid end timestamp with no value" + ); } } diff --git a/src/protobuf/app_metadata.rs b/src/protobuf/app_metadata.rs new file mode 100644 index 0000000..ee74221 --- /dev/null +++ b/src/protobuf/app_metadata.rs @@ -0,0 +1,40 @@ +use crate::metrics::proto::MetricsSetProto; +use crate::protobuf::StageKey; + +/// A collection of metrics for a set of tasks in an ExecutionPlan. each +/// entry should have a distinct [StageKey]. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct MetricsCollection { + #[prost(message, repeated, tag = "1")] + pub tasks: Vec, +} + +/// TaskMetrics represents the metrics for a single task. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TaskMetrics { + /// stage_key uniquely identifies this task. + /// + /// This field is always present. It's marked optional due to protobuf rules. + #[prost(message, optional, tag = "1")] + pub stage_key: Option, + /// metrics[i] is the set of metrics for plan node `i` where plan nodes are in pre-order + /// traversal order. + #[prost(message, repeated, tag = "2")] + pub metrics: Vec, +} + +// FlightAppMetadata represents all types of app_metadata which we use in the distributed execution. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FlightAppMetadata { + // content should always be Some, but it is optional due to protobuf rules. + #[prost(oneof = "AppMetadata", tags = "1")] + pub content: Option, +} + +#[derive(Clone, PartialEq, ::prost::Oneof)] +pub enum AppMetadata { + #[prost(message, tag = "1")] + MetricsCollection(MetricsCollection), + // Note: For every additional enum variant, ensure to add tags to [FlightAppMetadata]. ex. `#[prost(oneof = "AppMetadata", tags = "1,2,3")]` etc. + // If you don't the proto will compile but you may encounter errors during serialization/deserialization. +} diff --git a/src/protobuf/mod.rs b/src/protobuf/mod.rs index 342eb03..0bfa785 100644 --- a/src/protobuf/mod.rs +++ b/src/protobuf/mod.rs @@ -1,8 +1,11 @@ +mod app_metadata; mod distributed_codec; mod errors; mod stage_proto; mod user_codec; +#[allow(unused_imports)] +pub(crate) use app_metadata::{AppMetadata, FlightAppMetadata, MetricsCollection, TaskMetrics}; pub(crate) use distributed_codec::DistributedCodec; pub(crate) use errors::{ datafusion_error_to_tonic_status, map_flight_to_datafusion_error, diff --git a/src/test_utils/metrics.rs b/src/test_utils/metrics.rs new file mode 100644 index 0000000..4335cb7 --- /dev/null +++ b/src/test_utils/metrics.rs @@ -0,0 +1,37 @@ +use crate::metrics::proto::{ElapsedCompute, EndTimestamp, OutputRows, StartTimestamp}; +use crate::metrics::proto::{MetricProto, MetricValueProto, MetricsSetProto}; + +/// creates a "distinct" set of metrics from the provided seed +pub fn make_test_metrics_set_proto_from_seed(seed: u64) -> MetricsSetProto { + const TEST_TIMESTAMP: i64 = 1758200400000000000; // 2025-09-18 13:00:00 UTC + MetricsSetProto { + metrics: vec![ + MetricProto { + metric: Some(MetricValueProto::OutputRows(OutputRows { value: seed })), + labels: vec![], + partition: None, + }, + MetricProto { + metric: Some(MetricValueProto::ElapsedCompute(ElapsedCompute { + value: seed, + })), + labels: vec![], + partition: None, + }, + MetricProto { + metric: Some(MetricValueProto::StartTimestamp(StartTimestamp { + value: Some(TEST_TIMESTAMP + (seed as i64 * 1_000_000_000)), + })), + labels: vec![], + partition: None, + }, + MetricProto { + metric: Some(MetricValueProto::EndTimestamp(EndTimestamp { + value: Some(TEST_TIMESTAMP + ((seed as i64 + 1) * 1_000_000_000)), + })), + labels: vec![], + partition: None, + }, + ], + } +} diff --git a/src/test_utils/mod.rs b/src/test_utils/mod.rs index 4dd283c..140bd63 100644 --- a/src/test_utils/mod.rs +++ b/src/test_utils/mod.rs @@ -1,6 +1,7 @@ pub mod in_memory_channel_resolver; pub mod insta; pub mod localhost; +pub mod metrics; pub mod mock_exec; pub mod parquet; pub mod session_context;