diff --git a/src/execution_plans/metrics.rs b/src/execution_plans/metrics.rs index 4a1ff59..b424c67 100644 --- a/src/execution_plans/metrics.rs +++ b/src/execution_plans/metrics.rs @@ -1,198 +1,16 @@ -use crate::execution_plans::{NetworkCoalesceExec, NetworkShuffleExec, StageExec}; -use crate::metrics::proto::{MetricsSetProto, metrics_set_proto_to_df}; -use crate::protobuf::StageKey; -use datafusion::common::internal_err; -use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter}; -use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::SendableRecordBatchStream; -use datafusion::execution::TaskContext; -use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::metrics::MetricsSet; +use std::sync::Arc; + +use datafusion::error::Result; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, PlanProperties}; use std::any::Any; -use std::collections::HashMap; use std::fmt::{Debug, Formatter}; -use std::sync::Arc; - -/// TaskMetricsCollector is used to collect metrics from a task. It implements [TreeNodeRewriter]. -/// Note: TaskMetricsCollector is not a [datafusion::physical_plan::ExecutionPlanVisitor] to keep -/// parity between [TaskMetricsCollector] and [TaskMetricsRewriter]. -pub struct TaskMetricsCollector { - /// metrics contains the metrics for the current task. - task_metrics: Vec, - /// child_task_metrics contains metrics for tasks from child [StageExec]s if they were - /// collected. - child_task_metrics: HashMap>, -} - -/// MetricsCollectorResult is the result of collecting metrics from a task. -#[allow(dead_code)] -pub struct MetricsCollectorResult { - // metrics is a collection of metrics for a task ordered using a pre-order traversal of the task's plan. - task_metrics: Vec, - // child_task_metrics contains metrics for child tasks if they were collected. - child_task_metrics: HashMap>, -} - -impl TreeNodeRewriter for TaskMetricsCollector { - type Node = Arc; - - fn f_down(&mut self, plan: Self::Node) -> Result> { - // If the plan is an NetworkShuffleExec, assume it has collected metrics already - // from child tasks. - let metrics_collection = - if let Some(node) = plan.as_any().downcast_ref::() { - let NetworkShuffleExec::Ready(ready) = node else { - return internal_err!( - "unexpected NetworkShuffleExec::Pending during metrics collection" - ); - }; - Some(Arc::clone(&ready.metrics_collection)) - } else if let Some(node) = plan.as_any().downcast_ref::() { - let NetworkCoalesceExec::Ready(ready) = node else { - return internal_err!( - "unexpected NetworkCoalesceExec::Pending during metrics collection" - ); - }; - Some(Arc::clone(&ready.metrics_collection)) - } else { - None - }; - - if let Some(metrics_collection) = metrics_collection { - for mut entry in metrics_collection.iter_mut() { - let stage_key = entry.key().clone(); - let task_metrics = std::mem::take(entry.value_mut()); // Avoid copy. - match self.child_task_metrics.get(&stage_key) { - // There should never be two NetworkShuffleExec with metrics for the same stage_key. - // By convention, the NetworkShuffleExec which runs the last partition in a task should be - // sent metrics (the NetworkShuffleExec tracks it for us). 
- Some(_) => { - return internal_err!( - "duplicate task metrics for key {} during metrics collection", - stage_key - ); - } - None => { - self.child_task_metrics - .insert(stage_key.clone(), task_metrics); - } - } - } - // Skip the subtree of the NetworkShuffleExec. - return Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)); - } - - // For plan nodes in this task, collect metrics. - match plan.metrics() { - Some(metrics) => self.task_metrics.push(metrics.clone()), - None => { - // TODO: Consider using a more efficent encoding scheme to avoid empty slots in the vec. - self.task_metrics.push(MetricsSet::new()) - } - } - Ok(Transformed::new(plan, false, TreeNodeRecursion::Continue)) - } -} - -impl TaskMetricsCollector { - #[allow(dead_code)] - pub fn new() -> Self { - Self { - task_metrics: Vec::new(), - child_task_metrics: HashMap::new(), - } - } - - /// collect metrics from a StageExec plan and any child tasks. - /// Returns - /// - a vec representing the metrics for the current task (ordered using a pre-order traversal) - /// - a map representing the metrics for some subset of child tasks collected from NetworkShuffleExec leaves - #[allow(dead_code)] - pub fn collect(mut self, stage: &StageExec) -> Result { - stage.plan.clone().rewrite(&mut self)?; - Ok(MetricsCollectorResult { - task_metrics: self.task_metrics, - child_task_metrics: self.child_task_metrics, - }) - } -} - -/// TaskMetricsRewriter is used to enrich a task with metrics by re-writing the plan using [MetricsWrapperExec] nodes. -/// -/// Ex. for a plan with the form -/// AggregateExec -/// └── ProjectionExec -/// └── NetworkShuffleExec -/// -/// the task will be rewritten as -/// -/// MetricsWrapperExec (wrapped: AggregateExec) -/// └── MetricsWrapperExec (wrapped: ProjectionExec) -/// └── NetworkShuffleExec -/// (Note that the NetworkShuffleExec node is not wrapped) -pub struct TaskMetricsRewriter { - metrics: Vec, - idx: usize, -} - -impl TaskMetricsRewriter { - /// Create a new TaskMetricsRewriter. The provided metrics will be used to enrich the plan. - #[allow(dead_code)] - pub fn new(metrics: Vec) -> Self { - Self { metrics, idx: 0 } - } - - /// enrich_task_with_metrics rewrites the plan by wrapping nodes. If the length of the provided metrics set vec does not - /// match the number of nodes in the plan, an error will be returned. 
- #[allow(dead_code)] - pub fn enrich_task_with_metrics( - mut self, - plan: Arc, - ) -> Result> { - let transformed = plan.rewrite(&mut self)?; - if self.idx != self.metrics.len() { - return internal_err!( - "too many metrics sets provided to rewrite task: {} metrics sets provided, {} nodes in plan", - self.metrics.len(), - self.idx - ); - } - Ok(transformed.data) - } -} - -impl TreeNodeRewriter for TaskMetricsRewriter { - type Node = Arc; - - fn f_down(&mut self, plan: Self::Node) -> Result> { - if plan.as_any().is::() { - return Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)); - } - if plan.as_any().is::() { - return Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)); - } - if self.idx >= self.metrics.len() { - return internal_err!( - "not enough metrics provided to rewrite task: {} metrics provided", - self.metrics.len() - ); - } - let proto_metrics = &self.metrics[self.idx]; - - let wrapped_plan_node: Arc = Arc::new(MetricsWrapperExec::new( - plan.clone(), - metrics_set_proto_to_df(proto_metrics)?, - )); - let result = Transformed::new(wrapped_plan_node, true, TreeNodeRecursion::Continue); - self.idx += 1; - Ok(result) - } -} /// A transparent wrapper that delegates all execution to its child but returns custom metrics. This node is invisible during display. /// The structure of a plan tree is closely tied to the [TaskMetricsRewriter]. -struct MetricsWrapperExec { +pub struct MetricsWrapperExec { inner: Arc, /// metrics for this plan node. metrics: MetricsSet, @@ -269,220 +87,3 @@ impl ExecutionPlan for MetricsWrapperExec { Some(self.metrics.clone()) } } - -#[cfg(test)] -mod tests { - - use super::*; - use datafusion::arrow::array::{Int32Array, StringArray}; - use datafusion::arrow::record_batch::RecordBatch; - - use crate::DistributedExt; - use crate::DistributedPhysicalOptimizerRule; - use crate::test_utils::in_memory_channel_resolver::InMemoryChannelResolver; - use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed; - use crate::test_utils::session_context::register_temp_parquet_table; - use datafusion::execution::{SessionStateBuilder, context::SessionContext}; - use datafusion::physical_plan::metrics::MetricValue; - use datafusion::prelude::SessionConfig; - use datafusion::{ - arrow::datatypes::{DataType, Field, Schema}, - physical_plan::display::DisplayableExecutionPlan, - }; - use std::sync::Arc; - - /// Creates a stage with the following structure: - /// - /// SortPreservingMergeExec - /// SortExec - /// ProjectionExec - /// AggregateExec - /// CoalesceBatchesExec - /// NetworkShuffleExec - /// - /// ... (for the purposes of these tests, we don't care about child stages). 
- async fn make_test_stage_exec_with_5_nodes() -> (StageExec, SessionContext) { - // Create distributed session state with in-memory channel resolver - let config = SessionConfig::new().with_target_partitions(2); - - let state = SessionStateBuilder::new() - .with_default_features() - .with_config(config) - .with_distributed_channel_resolver(InMemoryChannelResolver::new()) - .with_physical_optimizer_rule(Arc::new( - DistributedPhysicalOptimizerRule::default() - .with_network_coalesce_tasks(2) - .with_network_shuffle_tasks(2), - )) - .build(); - - let ctx = SessionContext::from(state); - - // Create test data - let schema = Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, false), - ])); - - let batches = vec![ - RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3])), - Arc::new(StringArray::from(vec!["a", "b", "c"])), - ], - ) - .unwrap(), - RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![4, 5, 6])), - Arc::new(StringArray::from(vec!["d", "e", "f"])), - ], - ) - .unwrap(), - ]; - - // Register the test data as a parquet table - let _ = register_temp_parquet_table("test_table", schema.clone(), batches, &ctx) - .await - .unwrap(); - - let df = ctx - .sql("SELECT id, COUNT(*) as count FROM test_table WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10") - .await - .unwrap(); - let physical_distributed = df.create_physical_plan().await.unwrap(); - - let stage_exec = match physical_distributed.as_any().downcast_ref::() { - Some(stage_exec) => stage_exec.clone(), - None => panic!( - "Expected StageExec from distributed optimization, got: {}", - physical_distributed.name() - ), - }; - - (stage_exec, ctx) - } - - #[tokio::test] - #[ignore] - async fn test_metrics_rewriter() { - let (test_stage, _ctx) = make_test_stage_exec_with_5_nodes().await; - let test_metrics_sets = (0..5) // 5 nodes excluding NetworkShuffleExec - .map(|i| make_test_metrics_set_proto_from_seed(i + 10)) - .collect::>(); - - let rewriter = TaskMetricsRewriter::new(test_metrics_sets.clone()); - let plan_with_metrics = rewriter - .enrich_task_with_metrics(test_stage.plan.clone()) - .unwrap(); - - let plan_str = - DisplayableExecutionPlan::with_full_metrics(plan_with_metrics.as_ref()).indent(true); - // Expected distributed plan output with metrics - let expected = [ - r"SortPreservingMergeExec: [id@0 ASC NULLS LAST], fetch=10, metrics=[output_rows=10, elapsed_compute=10ns, start_timestamp=2025-09-18 13:00:10 UTC, end_timestamp=2025-09-18 13:00:11 UTC]", - r" SortExec: TopK(fetch=10), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true], metrics=[output_rows=11, elapsed_compute=11ns, start_timestamp=2025-09-18 13:00:11 UTC, end_timestamp=2025-09-18 13:00:12 UTC]", - r" ProjectionExec: expr=[id@0 as id, count(Int64(1))@1 as count], metrics=[output_rows=12, elapsed_compute=12ns, start_timestamp=2025-09-18 13:00:12 UTC, end_timestamp=2025-09-18 13:00:13 UTC]", - r" AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[count(Int64(1))], metrics=[output_rows=13, elapsed_compute=13ns, start_timestamp=2025-09-18 13:00:13 UTC, end_timestamp=2025-09-18 13:00:14 UTC]", - r" CoalesceBatchesExec: target_batch_size=8192, metrics=[output_rows=14, elapsed_compute=14ns, start_timestamp=2025-09-18 13:00:14 UTC, end_timestamp=2025-09-18 13:00:15 UTC]", - r" NetworkShuffleExec, metrics=[]", - "" // trailing newline - ].join("\n"); - assert_eq!(expected, plan_str.to_string()); - } - - #[tokio::test] - #[ignore] - async 
fn test_metrics_rewriter_correct_number_of_metrics() { - let test_metrics_set = make_test_metrics_set_proto_from_seed(10); - let (executable_plan, _ctx) = make_test_stage_exec_with_5_nodes().await; - let task_plan = executable_plan - .as_any() - .downcast_ref::() - .unwrap() - .plan - .clone(); - - // Too few metrics sets. - let rewriter = TaskMetricsRewriter::new(vec![test_metrics_set.clone()]); - let result = rewriter.enrich_task_with_metrics(task_plan.clone()); - assert!(result.is_err()); - - // Too many metrics sets. - let rewriter = TaskMetricsRewriter::new(vec![ - test_metrics_set.clone(), - test_metrics_set.clone(), - test_metrics_set.clone(), - test_metrics_set.clone(), - ]); - let result = rewriter.enrich_task_with_metrics(task_plan.clone()); - assert!(result.is_err()); - } - - #[tokio::test] - #[ignore] - async fn test_metrics_collection() { - let (stage_exec, ctx) = make_test_stage_exec_with_5_nodes().await; - - // Execute the plan to completion. - let task_ctx = ctx.task_ctx(); - let stream = stage_exec.execute(0, task_ctx).unwrap(); - - use futures::StreamExt; - let mut stream = stream; - while let Some(_batch) = stream.next().await {} - - let collector = TaskMetricsCollector::new(); - let result = collector.collect(&stage_exec).unwrap(); - - // With the distributed optimizer, we get a much more complex plan structure - // The exact number of metrics sets depends on the plan optimization, so be flexible - assert_eq!(result.task_metrics.len(), 5); - - let expected_metrics_count = [4, 10, 8, 16, 8]; - for (node_idx, metrics_set) in result.task_metrics.iter().enumerate() { - let metrics_count = metrics_set.iter().count(); - assert_eq!(metrics_count, expected_metrics_count[node_idx]); - - // Each node should have basic metrics: ElapsedCompute, OutputRows, StartTimestamp, EndTimestamp. - let mut has_start_timestamp = false; - let mut has_end_timestamp = false; - let mut has_elapsed_compute = false; - let mut has_output_rows = false; - - for metric in metrics_set.iter() { - let metric_name = metric.value().name(); - let metric_value = metric.value(); - - match metric_value { - MetricValue::StartTimestamp(_) if metric_name == "start_timestamp" => { - has_start_timestamp = true; - } - MetricValue::EndTimestamp(_) if metric_name == "end_timestamp" => { - has_end_timestamp = true; - } - MetricValue::ElapsedCompute(_) if metric_name == "elapsed_compute" => { - has_elapsed_compute = true; - } - MetricValue::OutputRows(_) if metric_name == "output_rows" => { - has_output_rows = true; - } - _ => { - // Other metrics are fine, we just validate the core ones - } - } - } - - // Each node should have the four basic metrics - assert!(has_start_timestamp); - assert!(has_end_timestamp); - assert!(has_elapsed_compute); - assert!(has_output_rows); - } - - // TODO: once we propagate metrics from child stages, we can assert this. 
- assert_eq!(0, result.child_task_metrics.len()); - } -} diff --git a/src/execution_plans/mod.rs b/src/execution_plans/mod.rs index 551b36e..bf421e6 100644 --- a/src/execution_plans/mod.rs +++ b/src/execution_plans/mod.rs @@ -1,10 +1,10 @@ mod metrics; -mod metrics_collecting_stream; mod network_coalesce; mod network_shuffle; mod partition_isolator; mod stage; +pub use metrics::MetricsWrapperExec; pub use network_coalesce::{NetworkCoalesceExec, NetworkCoalesceReady}; pub use network_shuffle::{NetworkShuffleExec, NetworkShuffleReadyExec}; pub use partition_isolator::PartitionIsolatorExec; diff --git a/src/execution_plans/network_coalesce.rs b/src/execution_plans/network_coalesce.rs index 5b5ca6d..d2b4972 100644 --- a/src/execution_plans/network_coalesce.rs +++ b/src/execution_plans/network_coalesce.rs @@ -5,6 +5,7 @@ use crate::config_extension_ext::ContextGrpcMetadata; use crate::distributed_physical_optimizer_rule::{NetworkBoundary, limit_tasks_err}; use crate::execution_plans::{DistributedTaskContext, StageExec}; use crate::flight_service::DoGet; +use crate::metrics::MetricsCollectingStream; use crate::metrics::proto::MetricsSetProto; use crate::protobuf::{DistributedCodec, StageKey, proto_from_input_stage}; use crate::protobuf::{map_flight_to_datafusion_error, map_status_to_datafusion_error}; @@ -282,6 +283,7 @@ impl ExecutionPlan for NetworkCoalesceExec { return internal_err!("NetworkCoalesceExec: task is unassigned, cannot proceed"); }; + let metrics_collection_capture = self_ready.metrics_collection.clone(); let stream = async move { let channel = channel_resolver.get_channel_for_url(&url).await?; let stream = FlightServiceClient::new(channel) @@ -291,8 +293,13 @@ impl ExecutionPlan for NetworkCoalesceExec { .into_inner() .map_err(|err| FlightError::Tonic(Box::new(err))); - Ok(FlightRecordBatchStream::new_from_flight_data(stream) - .map_err(map_flight_to_datafusion_error)) + let metrics_collecting_stream = + MetricsCollectingStream::new(stream, metrics_collection_capture); + + Ok( + FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream) + .map_err(map_flight_to_datafusion_error), + ) } .try_flatten_stream(); diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs index c8adf68..0ecbad4 100644 --- a/src/execution_plans/network_shuffle.rs +++ b/src/execution_plans/network_shuffle.rs @@ -5,6 +5,7 @@ use crate::config_extension_ext::ContextGrpcMetadata; use crate::distributed_physical_optimizer_rule::NetworkBoundary; use crate::execution_plans::{DistributedTaskContext, StageExec}; use crate::flight_service::DoGet; +use crate::metrics::MetricsCollectingStream; use crate::metrics::proto::MetricsSetProto; use crate::protobuf::{DistributedCodec, StageKey, proto_from_input_stage}; use crate::protobuf::{map_flight_to_datafusion_error, map_status_to_datafusion_error}; @@ -330,6 +331,7 @@ impl ExecutionPlan for NetworkShuffleExec { }, ); + let metrics_collection_capture = self_ready.metrics_collection.clone(); async move { let url = task.url.ok_or(internal_datafusion_err!( "NetworkShuffleExec: task is unassigned, cannot proceed" @@ -343,8 +345,13 @@ impl ExecutionPlan for NetworkShuffleExec { .into_inner() .map_err(|err| FlightError::Tonic(Box::new(err))); - Ok(FlightRecordBatchStream::new_from_flight_data(stream) - .map_err(map_flight_to_datafusion_error)) + let metrics_collecting_stream = + MetricsCollectingStream::new(stream, metrics_collection_capture); + + Ok( + FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream) + 
.map_err(map_flight_to_datafusion_error),
+            )
         }
         .try_flatten_stream()
         .boxed()
diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs
index ece336a..363b533 100644
--- a/src/flight_service/do_get.rs
+++ b/src/flight_service/do_get.rs
@@ -3,9 +3,17 @@ use crate::config_extension_ext::ContextGrpcMetadata;
 use crate::execution_plans::{DistributedTaskContext, StageExec};
 use crate::flight_service::service::ArrowFlightEndpoint;
 use crate::flight_service::session_builder::DistributedSessionBuilderContext;
+use crate::flight_service::trailing_flight_data_stream::TrailingFlightDataStream;
+use crate::metrics::TaskMetricsCollector;
+use crate::metrics::proto::df_metrics_set_to_proto;
 use crate::protobuf::{
-    DistributedCodec, StageKey, datafusion_error_to_tonic_status, stage_from_proto,
+    AppMetadata, DistributedCodec, FlightAppMetadata, MetricsCollection, StageKey, TaskMetrics,
+    datafusion_error_to_tonic_status, stage_from_proto,
 };
+use arrow::array::RecordBatch;
+use arrow::datatypes::SchemaRef;
+use arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
+use arrow_flight::FlightData;
 use arrow_flight::Ticket;
 use arrow_flight::encode::FlightDataEncoderBuilder;
 use arrow_flight::error::FlightError;
@@ -15,6 +23,7 @@ use datafusion::common::exec_datafusion_err;
 use datafusion::execution::SendableRecordBatchStream;
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use futures::TryStreamExt;
+use futures::{Stream, stream};
 use prost::Message;
 use std::sync::Arc;
 use std::sync::atomic::{AtomicUsize, Ordering};
@@ -97,12 +106,6 @@ impl ArrowFlightEndpoint {
             })
             .await?;
         let stage = Arc::clone(&stage_data.stage);
-        let num_partitions_remaining = Arc::clone(&stage_data.num_partitions_remaining);
-
-        // If all the partitions are done, remove the stage from the cache.
-        if num_partitions_remaining.fetch_sub(1, Ordering::SeqCst) <= 1 {
-            self.task_data_entries.remove(key);
-        }
 
         // Find out which partition group we are executing
         let cfg = session_state.config_mut();
@@ -130,16 +133,29 @@ impl ArrowFlightEndpoint {
             .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
 
         let schema = stream.schema();
+
+        // TODO: We don't need to do this since the stage / plan is captured again by the
+        // TrailingFlightDataStream. However, we will eventually only use the TrailingFlightDataStream
+        // if we are running an `explain (analyze)` command. We should update this section
+        // to only use one or the other - not both.
+        let plan_capture = stage.plan.clone();
         let stream = with_callback(stream, move |_| {
             // We need to hold a reference to the plan for at least as long as the stream is
             // execution. Some plans might store state necessary for the stream to work, and
             // dropping the plan early could drop this state too soon.
-            let _ = stage.plan;
+            let _ = plan_capture;
         });
 
-        Ok(record_batch_stream_to_response(Box::pin(
-            RecordBatchStreamAdapter::new(schema, stream),
-        )))
+        let record_batch_stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream));
+        let task_data_capture = self.task_data_entries.clone();
+        Ok(flight_stream_from_record_batch_stream(
+            key.clone(),
+            stage_data.clone(),
+            move || {
+                task_data_capture.remove(key.clone());
+            },
+            record_batch_stream,
+        ))
     }
 }
@@ -147,7 +163,12 @@ fn missing(field: &'static str) -> impl FnOnce() -> Status {
     move || Status::invalid_argument(format!("Missing field '{field}'"))
 }
 
-fn record_batch_stream_to_response(
+/// Creates a tonic response from a stream of record batches. Handles:
+/// - RecordBatch to flight conversion, partition tracking, stage eviction, and trailing metrics.
+fn flight_stream_from_record_batch_stream(
+    stage_key: StageKey,
+    stage_data: TaskData,
+    evict_stage: impl FnOnce() + Send + 'static,
     stream: SendableRecordBatchStream,
 ) -> Response<::DoGetStream> {
     let flight_data_stream =
@@ -157,12 +178,109 @@ fn record_batch_stream_to_response(
                 FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
             }));
 
-    Response::new(Box::pin(flight_data_stream.map_err(|err| match err {
+    let trailing_metrics_stream = TrailingFlightDataStream::new(
+        move || {
+            if stage_data
+                .num_partitions_remaining
+                .fetch_sub(1, Ordering::SeqCst)
+                == 1
+            {
+                evict_stage();
+
+                let metrics_stream =
+                    collect_and_create_metrics_flight_data(stage_key, stage_data.stage).map_err(
+                        |err| {
+                            Status::internal(format!(
+                                "error collecting metrics in arrow flight endpoint: {err}"
+                            ))
+                        },
+                    )?;
+
+                return Ok(Some(metrics_stream));
+            }
+
+            Ok(None)
+        },
+        flight_data_stream,
+    );
+
+    Response::new(Box::pin(trailing_metrics_stream.map_err(|err| match err {
         FlightError::Tonic(status) => *status,
         _ => Status::internal(format!("Error during flight stream: {err}")),
     })))
 }
 
+// Collects metrics from the provided stage and encodes them into a stream of flight data using
+// the schema of the stage.
+fn collect_and_create_metrics_flight_data(
+    stage_key: StageKey,
+    stage: Arc,
+) -> Result> + Send + 'static, FlightError> {
+    // Get the metrics for the task executed on this worker. Separately, collect metrics for child tasks.
+    let mut result = TaskMetricsCollector::new()
+        .collect(stage.plan.clone())
+        .map_err(|err| FlightError::ProtocolError(err.to_string()))?;
+
+    // Add the metrics for this task into the collection of task metrics.
+    // Skip any metrics that can't be converted to proto (unsupported types)
+    let proto_task_metrics = result
+        .task_metrics
+        .iter()
+        .map(|metrics| {
+            df_metrics_set_to_proto(metrics)
+                .map_err(|err| FlightError::ProtocolError(err.to_string()))
+        })
+        .collect::, _>>()?;
+    result
+        .input_task_metrics
+        .insert(stage_key, proto_task_metrics);
+
+    // Serialize the metrics for all tasks.
+    let mut task_metrics_set = vec![];
+    for (stage_key, metrics) in result.input_task_metrics.into_iter() {
+        task_metrics_set.push(TaskMetrics {
+            stage_key: Some(stage_key),
+            metrics,
+        });
+    }
+
+    let flight_app_metadata = FlightAppMetadata {
+        content: Some(AppMetadata::MetricsCollection(MetricsCollection {
+            tasks: task_metrics_set,
+        })),
+    };
+
+    let metrics_flight_data =
+        empty_flight_data_with_app_metadata(flight_app_metadata, stage.plan.schema())?;
+    Ok(Box::pin(stream::once(
+        async move { Ok(metrics_flight_data) },
+    )))
+}
+
+/// Creates a FlightData with the given app_metadata and empty RecordBatch using the provided schema.
+/// We don't use [arrow_flight::encode::FlightDataEncoder] (and by extension, the [arrow_flight::encode::FlightDataEncoderBuilder])
+/// since they skip messages with empty RecordBatch data.
+pub fn empty_flight_data_with_app_metadata( + metadata: FlightAppMetadata, + schema: SchemaRef, +) -> Result { + let mut buf = vec![]; + metadata + .encode(&mut buf) + .map_err(|err| FlightError::ProtocolError(err.to_string()))?; + + let empty_batch = RecordBatch::new_empty(schema); + let options = IpcWriteOptions::default(); + let data_gen = IpcDataGenerator::default(); + let mut dictionary_tracker = DictionaryTracker::new(true); + let (_, encoded_data) = data_gen + .encoded_batch(&empty_batch, &mut dictionary_tracker, &options) + .map_err(|e| { + FlightError::ProtocolError(format!("Failed to create empty batch FlightData: {e}")) + })?; + Ok(FlightData::from(encoded_data).with_app_metadata(buf)) +} + #[cfg(test)] mod tests { use super::*; @@ -228,9 +346,9 @@ mod tests { let stage_proto = proto_from_stage(&stage, &DefaultPhysicalExtensionCodec {}).unwrap(); let stage_proto_for_closure = stage_proto.clone(); let endpoint_ref = &endpoint; + let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| { let stage_proto = stage_proto_for_closure.clone(); - // Create DoGet message let doget = DoGet { stage_proto: stage_proto.encode_to_vec().into(), target_task_index: task_number, @@ -238,14 +356,17 @@ mod tests { stage_key: Some(stage_key), }; - // Create Flight ticket let ticket = Ticket { ticket: Bytes::from(doget.encode_to_vec()), }; - // Call the actual get() method let request = Request::new(ticket); - endpoint_ref.get(request).await + let response = endpoint_ref.get(request).await?; + let mut stream = response.into_inner(); + + // Consume the stream. + while let Some(_flight_data) = stream.try_next().await? {} + Ok::<(), Status>(()) }; // For each task, call do_get() for each partition except the last. @@ -261,7 +382,7 @@ mod tests { // Run the last partition of task 0. Any partition number works. Verify that the task state // is evicted because all partitions have been processed. - let result = do_get(1, 0, task_keys[0].clone()).await; + let result = do_get(2, 0, task_keys[0].clone()).await; assert!(result.is_ok()); let stored_stage_keys = endpoint.task_data_entries.keys().collect::>(); assert_eq!(stored_stage_keys.len(), 2); @@ -269,14 +390,14 @@ mod tests { assert!(stored_stage_keys.contains(&task_keys[2])); // Run the last partition of task 1. - let result = do_get(1, 1, task_keys[1].clone()).await; + let result = do_get(2, 1, task_keys[1].clone()).await; assert!(result.is_ok()); let stored_stage_keys = endpoint.task_data_entries.keys().collect::>(); assert_eq!(stored_stage_keys.len(), 1); assert!(stored_stage_keys.contains(&task_keys[2])); // Run the last partition of the last task. 
- let result = do_get(1, 2, task_keys[2].clone()).await; + let result = do_get(2, 2, task_keys[2].clone()).await; assert!(result.is_ok()); let stored_stage_keys = endpoint.task_data_entries.keys().collect::>(); assert_eq!(stored_stage_keys.len(), 0); diff --git a/src/flight_service/mod.rs b/src/flight_service/mod.rs index fdc99c4..db3bd91 100644 --- a/src/flight_service/mod.rs +++ b/src/flight_service/mod.rs @@ -1,7 +1,7 @@ mod do_get; mod service; mod session_builder; -mod trailing_flight_data_stream; +pub(super) mod trailing_flight_data_stream; pub(crate) use do_get::DoGet; pub use service::ArrowFlightEndpoint; diff --git a/src/flight_service/service.rs b/src/flight_service/service.rs index 675cd51..df01be7 100644 --- a/src/flight_service/service.rs +++ b/src/flight_service/service.rs @@ -17,7 +17,7 @@ use tonic::{Request, Response, Status, Streaming}; pub struct ArrowFlightEndpoint { pub(super) runtime: Arc, - pub(super) task_data_entries: TTLMap>>, + pub(super) task_data_entries: Arc>>>, pub(super) session_builder: Arc, } @@ -28,7 +28,7 @@ impl ArrowFlightEndpoint { let ttl_map = TTLMap::try_new(TTLMapConfig::default())?; Ok(Self { runtime: Arc::new(RuntimeEnv::default()), - task_data_entries: ttl_map, + task_data_entries: Arc::new(ttl_map), session_builder: Arc::new(session_builder), }) } diff --git a/src/flight_service/trailing_flight_data_stream.rs b/src/flight_service/trailing_flight_data_stream.rs index db85903..36e6778 100644 --- a/src/flight_service/trailing_flight_data_stream.rs +++ b/src/flight_service/trailing_flight_data_stream.rs @@ -8,25 +8,25 @@ use tokio::pin; /// TrailingFlightDataStream - wraps a FlightData stream. It calls the `on_complete` closure when the stream is finished. /// If the closure returns a new stream, it will be appended to the original stream and consumed. 
#[pin_project] -pub struct TrailingFlightDataStream +pub struct TrailingFlightDataStream where S: Stream> + Send, - F: FnOnce() -> Result, FlightError>, + T: Stream> + Send, + F: FnOnce() -> Result, FlightError>, { #[pin] inner: S, on_complete: Option, #[pin] - trailing_stream: Option, + trailing_stream: Option, } -impl TrailingFlightDataStream +impl TrailingFlightDataStream where S: Stream> + Send, - F: FnOnce() -> Result, FlightError>, + T: Stream> + Send, + F: FnOnce() -> Result, FlightError>, { - // TODO: remove - #[allow(dead_code)] pub fn new(on_complete: F, inner: S) -> Self { Self { inner, @@ -36,10 +36,11 @@ where } } -impl Stream for TrailingFlightDataStream +impl Stream for TrailingFlightDataStream where S: Stream> + Send, - F: FnOnce() -> Result, FlightError>, + T: Stream> + Send, + F: FnOnce() -> Result, FlightError>, { type Item = Result; @@ -74,7 +75,7 @@ mod tests { use arrow::record_batch::RecordBatch; use arrow_flight::FlightData; use arrow_flight::decode::FlightRecordBatchStream; - use arrow_flight::encode::FlightDataEncoderBuilder; + use arrow_flight::encode::{FlightDataEncoder, FlightDataEncoderBuilder}; use futures::stream::{self, StreamExt}; use std::sync::Arc; @@ -186,7 +187,7 @@ mod tests { )))), ]; let inner_stream = stream::iter(data); - let on_complete = || Ok(None); + let on_complete = || Ok(None::); let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream); let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream) .collect::>>() @@ -202,8 +203,7 @@ mod tests { let name_array = StringArray::from(vec!["item1"]); let value_array = Int32Array::from(vec![1]); let inner_stream = create_flight_data_stream(name_array, value_array); - - let on_complete = || -> Result, FlightError> { + let on_complete = || -> Result, FlightError> { Err(FlightError::ExternalError(Box::new(std::io::Error::new( std::io::ErrorKind::Other, "callback error", @@ -225,7 +225,7 @@ mod tests { StringArray::from(vec!["item1"] as Vec<&str>), Int32Array::from(vec![1] as Vec), ); - let on_complete = || Ok(None); + let on_complete = || Ok(None::); let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream); let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream) .collect::>>() diff --git a/src/execution_plans/metrics_collecting_stream.rs b/src/metrics/metrics_collecting_stream.rs similarity index 98% rename from src/execution_plans/metrics_collecting_stream.rs rename to src/metrics/metrics_collecting_stream.rs index 1321f63..32a424d 100644 --- a/src/execution_plans/metrics_collecting_stream.rs +++ b/src/metrics/metrics_collecting_stream.rs @@ -27,7 +27,6 @@ impl MetricsCollectingStream where S: Stream> + Send, { - #[allow(dead_code)] pub fn new( stream: S, metrics_collection: Arc>>, @@ -67,8 +66,8 @@ where }; metrics_collection.insert(stage_key, task_metrics.metrics); } - flight_data.app_metadata.clear(); + Ok(()) } } @@ -256,7 +255,8 @@ mod tests { // Create a stream that emits an error - should be propagated through let stream_error = FlightError::ProtocolError("stream error from inner stream".to_string()); let error_stream = stream::iter(vec![Err(stream_error)]); - let mut collecting_stream = MetricsCollectingStream::new(error_stream, metrics_collection); + let mut collecting_stream = + MetricsCollectingStream::new(error_stream, metrics_collection.clone()); let result = collecting_stream.next().await.unwrap(); assert_protocol_error(result, "stream error from inner stream"); diff --git a/src/metrics/mod.rs 
b/src/metrics/mod.rs index 611e8e1..b82d8c0 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -1 +1,6 @@ +mod metrics_collecting_stream; pub(crate) mod proto; +mod task_metrics_collector; +mod task_metrics_rewriter; +pub(crate) use metrics_collecting_stream::MetricsCollectingStream; +pub(crate) use task_metrics_collector::TaskMetricsCollector; diff --git a/src/metrics/proto.rs b/src/metrics/proto.rs index 4222d53..ad95438 100644 --- a/src/metrics/proto.rs +++ b/src/metrics/proto.rs @@ -137,7 +137,6 @@ pub struct ProtoLabel { /// df_metrics_set_to_proto converts a [datafusion::physical_plan::metrics::MetricsSet] to a [MetricsSetProto]. /// Custom metrics are filtered out, but any other errors are returned. /// TODO(#140): Support custom metrics. -#[allow(dead_code)] pub fn df_metrics_set_to_proto( metrics_set: &MetricsSet, ) -> Result { @@ -164,7 +163,6 @@ pub fn df_metrics_set_to_proto( } /// metrics_set_proto_to_df converts a [MetricsSetProto] to a [datafusion::physical_plan::metrics::MetricsSet]. -#[allow(dead_code)] pub fn metrics_set_proto_to_df( metrics_set_proto: &MetricsSetProto, ) -> Result { @@ -178,12 +176,10 @@ pub fn metrics_set_proto_to_df( } /// Custom metrics are not supported in proto conversion. -#[allow(dead_code)] const CUSTOM_METRICS_NOT_SUPPORTED: &str = "custom metrics are not supported in metrics proto conversion"; /// df_metric_to_proto converts a `datafusion::physical_plan::metrics::Metric` to a `MetricProto`. It does not consume the Arc. -#[allow(dead_code)] pub fn df_metric_to_proto(metric: Arc) -> Result { let partition = metric.partition().map(|p| p as u64); let labels = metric @@ -380,30 +376,28 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::StartTimestamp(start_ts)) => match start_ts.value { - Some(value) => { - let timestamp = Timestamp::new(); + Some(MetricValueProto::StartTimestamp(start_ts)) => { + let timestamp = Timestamp::new(); + if let Some(value) = start_ts.value { timestamp.set(DateTime::from_timestamp_nanos(value)); - Ok(Arc::new(Metric::new_with_labels( - MetricValue::StartTimestamp(timestamp), - partition, - labels, - ))) } - None => internal_err!("encountered invalid start timestamp metric with no value"), - }, - Some(MetricValueProto::EndTimestamp(end_ts)) => match end_ts.value { - Some(value) => { - let timestamp = Timestamp::new(); + Ok(Arc::new(Metric::new_with_labels( + MetricValue::StartTimestamp(timestamp), + partition, + labels, + ))) + } + Some(MetricValueProto::EndTimestamp(end_ts)) => { + let timestamp = Timestamp::new(); + if let Some(value) = end_ts.value { timestamp.set(DateTime::from_timestamp_nanos(value)); - Ok(Arc::new(Metric::new_with_labels( - MetricValue::EndTimestamp(timestamp), - partition, - labels, - ))) } - None => internal_err!("encountered invalid end timestamp metric with no value"), - }, + Ok(Arc::new(Metric::new_with_labels( + MetricValue::EndTimestamp(timestamp), + partition, + labels, + ))) + } None => internal_err!("proto metric is missing the metric field"), } } @@ -853,18 +847,22 @@ mod tests { } #[test] - fn test_invalid_proto_timestamp_error() { - // Create a MetricProto with EndTimestamp that has no value (None) - let invalid_end_timestamp_proto = MetricProto { - metric: Some(MetricValueProto::EndTimestamp(EndTimestamp { value: None })), - labels: vec![], - partition: Some(0), - }; - - let result = metric_proto_to_df(invalid_end_timestamp_proto); + fn test_default_timestamp_roundtrip() { + let default_timestamp = 
Timestamp::default(); + let metric_with_default_timestamp = + Metric::new(MetricValue::EndTimestamp(default_timestamp), Some(0)); + + let proto_result = df_metric_to_proto(Arc::new(metric_with_default_timestamp)); + assert!( + proto_result.is_ok(), + "should successfully convert default timestamp to proto" + ); + + let proto_metric = proto_result.unwrap(); + let roundtrip_result = metric_proto_to_df(proto_metric); assert!( - result.is_err(), - "should return error for invalid end timestamp with no value" + roundtrip_result.is_ok(), + "should successfully roundtrip default timestamp" ); } } diff --git a/src/metrics/task_metrics_collector.rs b/src/metrics/task_metrics_collector.rs new file mode 100644 index 0000000..047e0f8 --- /dev/null +++ b/src/metrics/task_metrics_collector.rs @@ -0,0 +1,338 @@ +use crate::execution_plans::NetworkCoalesceExec; +use crate::execution_plans::NetworkShuffleExec; +use crate::metrics::proto::MetricsSetProto; +use crate::protobuf::StageKey; +use datafusion::common::HashMap; +use datafusion::common::tree_node::Transformed; +use datafusion::common::tree_node::TreeNode; +use datafusion::common::tree_node::TreeNodeRecursion; +use datafusion::common::tree_node::TreeNodeRewriter; +use datafusion::error::DataFusionError; +use datafusion::error::Result; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::internal_err; +use datafusion::physical_plan::metrics::MetricsSet; +use std::sync::Arc; + +/// TaskMetricsCollector is used to collect metrics from a task. It implements [TreeNodeRewriter]. +/// Note: TaskMetricsCollector is not a [datafusion::physical_plan::ExecutionPlanVisitor] to keep +/// parity between [TaskMetricsCollector] and [TaskMetricsRewriter]. +pub struct TaskMetricsCollector { + /// metrics contains the metrics for the current task. + task_metrics: Vec, + /// input_task_metrics contains metrics for tasks from child [StageExec]s if they were + /// collected. + input_task_metrics: HashMap>, +} + +/// MetricsCollectorResult is the result of collecting metrics from a task. +pub struct MetricsCollectorResult { + // metrics is a collection of metrics for a task ordered using a pre-order traversal of the task's plan. + pub task_metrics: Vec, + // input_task_metrics contains metrics for child tasks if they were collected. + pub input_task_metrics: HashMap>, +} + +impl TreeNodeRewriter for TaskMetricsCollector { + type Node = Arc; + + fn f_down(&mut self, plan: Self::Node) -> Result> { + // If the plan is an NetworkShuffleExec, assume it has collected metrics already + // from child tasks. + let metrics_collection = + if let Some(node) = plan.as_any().downcast_ref::() { + let NetworkShuffleExec::Ready(ready) = node else { + return internal_err!( + "unexpected NetworkShuffleExec::Pending during metrics collection" + ); + }; + Some(Arc::clone(&ready.metrics_collection)) + } else if let Some(node) = plan.as_any().downcast_ref::() { + let NetworkCoalesceExec::Ready(ready) = node else { + return internal_err!( + "unexpected NetworkCoalesceExec::Pending during metrics collection" + ); + }; + Some(Arc::clone(&ready.metrics_collection)) + } else { + None + }; + + if let Some(metrics_collection) = metrics_collection { + for mut entry in metrics_collection.iter_mut() { + let stage_key = entry.key().clone(); + let task_metrics = std::mem::take(entry.value_mut()); // Avoid copy. + match self.input_task_metrics.get(&stage_key) { + // There should never be two NetworkShuffleExec with metrics for the same stage_key. 
+ // By convention, the NetworkShuffleExec which runs the last partition in a task should be + // sent metrics (the NetworkShuffleExec tracks it for us). + Some(_) => { + return internal_err!( + "duplicate task metrics for key {} during metrics collection", + stage_key + ); + } + None => { + self.input_task_metrics + .insert(stage_key.clone(), task_metrics); + } + } + } + // Skip the subtree of the NetworkShuffleExec. + return Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)); + } + + // For plan nodes in this task, collect metrics. + match plan.metrics() { + Some(metrics) => self.task_metrics.push(metrics.clone()), + None => { + // TODO: Consider using a more efficent encoding scheme to avoid empty slots in the vec. + self.task_metrics.push(MetricsSet::new()) + } + } + Ok(Transformed::new(plan, false, TreeNodeRecursion::Continue)) + } +} + +impl TaskMetricsCollector { + pub fn new() -> Self { + Self { + task_metrics: Vec::new(), + input_task_metrics: HashMap::new(), + } + } + + /// collect metrics from an [ExecutionPlan] (usually a [StageExec].plan) and any child tasks. + /// Returns + /// - a vec representing the metrics for the current task (ordered using a pre-order traversal) + /// - a map representing the metrics for some subset of child tasks collected from NetworkShuffleExec leaves + pub fn collect( + mut self, + plan: Arc, + ) -> Result { + plan.rewrite(&mut self)?; + Ok(MetricsCollectorResult { + task_metrics: self.task_metrics, + input_task_metrics: self.input_task_metrics, + }) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use datafusion::arrow::array::{Int32Array, StringArray}; + use datafusion::arrow::record_batch::RecordBatch; + use futures::StreamExt; + + use crate::DistributedPhysicalOptimizerRule; + use crate::metrics::proto::{df_metrics_set_to_proto, metrics_set_proto_to_df}; + use crate::test_utils::in_memory_channel_resolver::InMemoryChannelResolver; + use crate::test_utils::plans::{count_plan_nodes, get_stages_and_stage_keys}; + use crate::test_utils::session_context::register_temp_parquet_table; + use crate::{DistributedExt, StageExec}; + use datafusion::execution::{SessionStateBuilder, context::SessionContext}; + use datafusion::prelude::SessionConfig; + use datafusion::{ + arrow::datatypes::{DataType, Field, Schema}, + physical_plan::display::DisplayableExecutionPlan, + }; + use std::sync::Arc; + + /// Creates a session context and registers two tables: + /// - table1 (id: int, name: string) + /// - table2 (id: int, name: string, phone: string, balance: float64) + async fn make_test_ctx() -> SessionContext { + // Create distributed session state with in-memory channel resolver + let config = SessionConfig::new().with_target_partitions(2); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_config(config) + .with_distributed_channel_resolver(InMemoryChannelResolver::new()) + .with_physical_optimizer_rule(Arc::new( + DistributedPhysicalOptimizerRule::default() + .with_network_coalesce_tasks(2) + .with_network_shuffle_tasks(2), + )) + .build(); + + let ctx = SessionContext::from(state); + + // Create test data for table1 + let schema1 = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ])); + + let batches1 = vec![ + RecordBatch::try_new( + schema1.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ) + .unwrap(), + ]; + + // Create test data for table2 with extended schema + let schema2 = 
Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + Field::new("phone", DataType::Utf8, false), + Field::new("balance", DataType::Float64, false), + ])); + + let batches2 = vec![ + RecordBatch::try_new( + schema2.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![ + "customer1", + "customer2", + "customer3", + ])), + Arc::new(StringArray::from(vec![ + "13-123-4567", + "31-456-7890", + "23-789-0123", + ])), + Arc::new(datafusion::arrow::array::Float64Array::from(vec![ + 100.5, 250.0, 50.25, + ])), + ], + ) + .unwrap(), + ]; + + // Register the test data as parquet tables + let _ = register_temp_parquet_table("table1", schema1, batches1, &ctx) + .await + .unwrap(); + + let _ = register_temp_parquet_table("table2", schema2, batches2, &ctx) + .await + .unwrap(); + + ctx + } + + /// runs a sql query and returns the coordinator StageExec + async fn plan_sql(ctx: &SessionContext, sql: &str) -> StageExec { + let df = ctx.sql(sql).await.unwrap(); + let physical_distributed = df.create_physical_plan().await.unwrap(); + + let stage_exec = match physical_distributed.as_any().downcast_ref::() { + Some(stage_exec) => stage_exec.clone(), + None => panic!( + "expected StageExec from distributed optimization, got: {}", + physical_distributed.name() + ), + }; + stage_exec + } + + async fn execute_plan(stage_exec: &StageExec, ctx: &SessionContext) { + let task_ctx = ctx.task_ctx(); + let stream = stage_exec.execute(0, task_ctx).unwrap(); + + let mut stream = stream; + while let Some(batch) = stream.next().await { + batch.unwrap(); + } + } + + /// Asserts that we can collect metrics from a distributed plan generated from the + /// SQL query. It ensures that metrics are collected for all stages and are propagated + /// through network boundaries. + async fn run_metrics_collection_e2e_test(sql: &str) { + // Plan and execute the query + let ctx = make_test_ctx().await; + let stage_exec = plan_sql(&ctx, sql).await; + execute_plan(&stage_exec, &ctx).await; + + // Assert to ensure the distributed test case is sufficiently complex. + let (stages, expected_stage_keys) = get_stages_and_stage_keys(&stage_exec); + assert!( + expected_stage_keys.len() > 1, + "expected more than 1 stage key in test. the plan was not distributed):\n{}", + DisplayableExecutionPlan::new(&stage_exec).indent(true) + ); + + // Collect metrics for all tasks from the root StageExec. + let collector = TaskMetricsCollector::new(); + let result = collector.collect(stage_exec.plan.clone()).unwrap(); + let mut actual_collected_metrics = result.input_task_metrics; + actual_collected_metrics.insert( + StageKey { + query_id: stage_exec.query_id.to_string(), + stage_id: stage_exec.num as u64, + task_number: 0, + }, + result + .task_metrics + .iter() + .map(|m| df_metrics_set_to_proto(m).unwrap()) + .collect::>(), + ); + + // Ensure that there's metrics for each node for each task for each stage. + for expected_stage_key in expected_stage_keys { + // Get the collected metrics for this task. + let actual_metrics = actual_collected_metrics.get(&expected_stage_key).unwrap(); + + // Assert that there's metrics for each node in this task. + let stage = stages.get(&(expected_stage_key.stage_id as usize)).unwrap(); + assert_eq!(actual_metrics.len(), count_plan_nodes(&stage.plan)); + + // Ensure each node has at least one metric which was collected. 
+ for metrics_set in actual_metrics.iter() { + let metrics_set = metrics_set_proto_to_df(metrics_set).unwrap(); + assert!(metrics_set.iter().count() > 0); + } + } + } + + #[tokio::test] + async fn test_metrics_collection_e2e_1() { + run_metrics_collection_e2e_test("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10").await; + } + + #[tokio::test] + async fn test_metrics_collection_e2e_2() { + run_metrics_collection_e2e_test( + "SELECT sum(balance) / 7.0 as avg_yearly + FROM table2 + WHERE name LIKE 'customer%' + AND balance < ( + SELECT 0.2 * avg(balance) + FROM table2 t2_inner + WHERE t2_inner.id = table2.id + )", + ) + .await; + } + + #[tokio::test] + async fn test_metrics_collection_e2e_3() { + run_metrics_collection_e2e_test( + "SELECT + substring(phone, 1, 2) as country_code, + count(*) as num_customers, + sum(balance) as total_balance + FROM table2 + WHERE substring(phone, 1, 2) IN ('13', '31', '23', '29', '30', '18') + AND balance > ( + SELECT avg(balance) + FROM table2 + WHERE balance > 0.00 + ) + GROUP BY substring(phone, 1, 2) + ORDER BY country_code", + ) + .await; + } +} diff --git a/src/metrics/task_metrics_rewriter.rs b/src/metrics/task_metrics_rewriter.rs new file mode 100644 index 0000000..529237e --- /dev/null +++ b/src/metrics/task_metrics_rewriter.rs @@ -0,0 +1,244 @@ +use crate::execution_plans::MetricsWrapperExec; +use crate::execution_plans::NetworkCoalesceExec; +use crate::execution_plans::NetworkShuffleExec; +use crate::metrics::proto::MetricsSetProto; +use crate::metrics::proto::metrics_set_proto_to_df; +use datafusion::common::tree_node::Transformed; +use datafusion::common::tree_node::TreeNode; +use datafusion::common::tree_node::TreeNodeRecursion; +use datafusion::common::tree_node::TreeNodeRewriter; +use datafusion::error::Result; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::internal_err; +use std::sync::Arc; + +/// TaskMetricsRewriter is used to enrich a task with metrics by re-writing the plan using [MetricsWrapperExec] nodes. +/// +/// Ex. for a plan with the form +/// AggregateExec +/// └── ProjectionExec +/// └── NetworkShuffleExec +/// +/// the task will be rewritten as +/// +/// MetricsWrapperExec (wrapped: AggregateExec) +/// └── MetricsWrapperExec (wrapped: ProjectionExec) +/// └── NetworkShuffleExec +/// (Note that the NetworkShuffleExec node is not wrapped) +pub struct TaskMetricsRewriter { + metrics: Vec, + idx: usize, +} + +impl TaskMetricsRewriter { + /// Create a new TaskMetricsRewriter. The provided metrics will be used to enrich the plan. + #[allow(dead_code)] + pub fn new(metrics: Vec) -> Self { + Self { metrics, idx: 0 } + } + + /// enrich_task_with_metrics rewrites the plan by wrapping nodes. If the length of the provided metrics set vec does not + /// match the number of nodes in the plan, an error will be returned. 
+ #[allow(dead_code)] + pub fn enrich_task_with_metrics( + mut self, + plan: Arc, + ) -> Result> { + let transformed = plan.rewrite(&mut self)?; + if self.idx != self.metrics.len() { + return internal_err!( + "too many metrics sets provided to rewrite task: {} metrics sets provided, {} nodes in plan", + self.metrics.len(), + self.idx + ); + } + Ok(transformed.data) + } +} + +impl TreeNodeRewriter for TaskMetricsRewriter { + type Node = Arc; + + fn f_down(&mut self, plan: Self::Node) -> Result> { + if plan.as_any().is::() { + return Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)); + } + if plan.as_any().is::() { + return Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)); + } + if self.idx >= self.metrics.len() { + return internal_err!( + "not enough metrics provided to rewrite task: {} metrics provided", + self.metrics.len() + ); + } + let proto_metrics = &self.metrics[self.idx]; + + let wrapped_plan_node: Arc = Arc::new(MetricsWrapperExec::new( + plan.clone(), + metrics_set_proto_to_df(proto_metrics)?, + )); + let result = Transformed::new(wrapped_plan_node, true, TreeNodeRecursion::Continue); + self.idx += 1; + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use crate::metrics::proto::MetricsSetProto; + use crate::metrics::proto::df_metrics_set_to_proto; + use crate::metrics::task_metrics_collector::TaskMetricsCollector; + use crate::metrics::task_metrics_rewriter::TaskMetricsRewriter; + use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed; + use crate::test_utils::plans::count_plan_nodes; + use crate::test_utils::session_context::register_temp_parquet_table; + use datafusion::arrow::array::{Int32Array, StringArray}; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::arrow::record_batch::RecordBatch; + use datafusion::execution::SessionStateBuilder; + use datafusion::prelude::SessionConfig; + use datafusion::prelude::SessionContext; + use std::sync::Arc; + + /// Creates a non-distributed session context and registers two tables: + /// - table1 (id: int, name: string) + /// - table2 (id: int, name: string, phone: string, balance: float64) + async fn make_test_ctx() -> SessionContext { + let config = SessionConfig::new().with_target_partitions(2); + let state = SessionStateBuilder::new() + .with_default_features() + .with_config(config) + .build(); + let ctx = SessionContext::from(state); + + // Create test data for table1 + let schema1 = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ])); + + let batches1 = vec![ + RecordBatch::try_new( + schema1.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ) + .unwrap(), + ]; + + // Create test data for table2 with extended schema + let schema2 = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + Field::new("phone", DataType::Utf8, false), + Field::new("balance", DataType::Float64, false), + ])); + + let batches2 = vec![ + RecordBatch::try_new( + schema2.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![ + "customer1", + "customer2", + "customer3", + ])), + Arc::new(StringArray::from(vec![ + "13-123-4567", + "31-456-7890", + "23-789-0123", + ])), + Arc::new(datafusion::arrow::array::Float64Array::from(vec![ + 100.5, 250.0, 50.25, + ])), + ], + ) + .unwrap(), + ]; + + // Register the test data as parquet tables + let _ = 
register_temp_parquet_table("table1", schema1, batches1, &ctx) + .await + .unwrap(); + + let _ = register_temp_parquet_table("table2", schema2, batches2, &ctx) + .await + .unwrap(); + + ctx + } + + /// Asserts that we successfully re-write the metrics of a plan generated from the provided SQL query. + /// Also asserts that the order which metrics are collected from a plan matches the order which + /// they are re-written (ie. ensures we don't assign metrics to the wrong nodes) + /// + /// Only tests single node plans since the [TaskMetricsRewriter] stops on [NetworkBoundary]. + async fn run_metrics_rewriter_test(sql: &str) { + // Generate the plan + let ctx = make_test_ctx().await; + let plan = ctx + .sql(sql) + .await + .unwrap() + .create_physical_plan() + .await + .unwrap(); + + // Generate metrics for each plan node. + let expected_metrics = (0..count_plan_nodes(&plan)) + .map(|i| make_test_metrics_set_proto_from_seed(i as u64 + 10)) + .collect::>(); + + // Rewrite the metrics. + let rewriter = TaskMetricsRewriter::new(expected_metrics.clone()); + let rewritten_plan = rewriter.enrich_task_with_metrics(plan.clone()).unwrap(); + + // Collect metrics + let actual_metrics = TaskMetricsCollector::new() + .collect(rewritten_plan) + .unwrap() + .task_metrics; + + // Assert that all the metrics are present and in the same order. + assert_eq!(actual_metrics.len(), expected_metrics.len()); + for (actual_metrics_set, expected_metrics_set) in actual_metrics + .iter() + .map(|m| df_metrics_set_to_proto(m).unwrap()) + .zip(expected_metrics) + { + assert_eq!(actual_metrics_set, expected_metrics_set); + } + } + + #[tokio::test] + async fn test_metrics_rewriter_1() { + run_metrics_rewriter_test( + "SELECT sum(balance) / 7.0 as avg_yearly from table2 group by name", + ) + .await; + } + + #[tokio::test] + async fn test_metrics_rewriter_2() { + run_metrics_rewriter_test("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10").await; + } + + #[tokio::test] + async fn test_metrics_rewriter_3() { + run_metrics_rewriter_test( + "SELECT sum(balance) / 7.0 as avg_yearly + FROM table2 + WHERE name LIKE 'customer%' + AND balance < ( + SELECT 0.2 * avg(balance) + FROM table2 t2_inner + WHERE t2_inner.id = table2.id + )", + ) + .await; + } +} diff --git a/src/test_utils/mod.rs b/src/test_utils/mod.rs index 140bd63..a5b3bec 100644 --- a/src/test_utils/mod.rs +++ b/src/test_utils/mod.rs @@ -4,5 +4,6 @@ pub mod localhost; pub mod metrics; pub mod mock_exec; pub mod parquet; +pub mod plans; pub mod session_context; pub mod tpch; diff --git a/src/test_utils/plans.rs b/src/test_utils/plans.rs new file mode 100644 index 0000000..a37622b --- /dev/null +++ b/src/test_utils/plans.rs @@ -0,0 +1,75 @@ +use std::sync::Arc; + +use datafusion::{ + common::{HashMap, HashSet}, + physical_plan::ExecutionPlan, +}; + +use crate::{ + StageExec, + execution_plans::{InputStage, NetworkCoalesceExec, NetworkShuffleExec}, + protobuf::StageKey, +}; + +/// count_plan_nodes counts the number of execution plan nodes in a plan using BFS traversal. +/// This does NOT traverse child stages, only the execution plan tree within this stage. +/// Excludes [NetworkBoundary] nodes from the count. +pub fn count_plan_nodes(plan: &Arc) -> usize { + let mut count = 0; + let mut queue = vec![plan]; + + while let Some(plan) = queue.pop() { + // Skip [NetworkBoundary] nodes from the count. 
+ if !plan.as_any().is::() && !plan.as_any().is::() { + count += 1; + } + + // Add children to the queue for BFS traversal + for child in plan.children() { + queue.push(child); + } + } + count +} + +/// Returns +/// - a map of all stages +/// - a set of all the stage keys (one per task) +pub fn get_stages_and_stage_keys( + stage: &StageExec, +) -> (HashMap, HashSet) { + let query_id = stage.query_id; + let mut i = 0; + let mut queue = vec![stage]; + let mut stage_keys = HashSet::new(); + let mut stages_map = HashMap::new(); + + while i < queue.len() { + let stage = queue[i]; + stages_map.insert(stage.num, stage); + i += 1; + + // Add each task. + for j in 0..stage.tasks.len() { + let stage_key = StageKey { + query_id: query_id.to_string(), + stage_id: stage.num as u64, + task_number: j as u64, + }; + stage_keys.insert(stage_key); + } + + // Add any child stages + queue.extend( + stage + .input_stages_iter() + .map(|input_stage| match input_stage { + InputStage::Decoded(plan) => StageExec::from_dyn(plan), + InputStage::Encoded { .. } => { + unimplemented!(); + } + }), + ); + } + (stages_map, stage_keys) +}
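
The central mechanism this patch introduces is the "trailing stream" pattern: let the data stream finish, then invoke a callback that may yield one more stream (here, a metrics-only FlightData message) to append before the response closes. The following is a minimal, self-contained sketch of that pattern, assuming only the futures and tokio crates; the with_trailer helper and the integer streams are illustrative stand-ins and are not part of this patch.

use futures::{stream, Stream, StreamExt};

/// Illustrative only: drain `inner`, then call `on_complete`; if it returns a
/// trailing stream, drain that as well and append its items to the output.
async fn with_trailer<S, T, F>(inner: S, on_complete: F) -> Vec<S::Item>
where
    S: Stream + Unpin,
    T: Stream<Item = S::Item> + Unpin,
    F: FnOnce() -> Option<T>,
{
    let mut inner = inner;
    let mut out = Vec::new();
    while let Some(item) = inner.next().await {
        out.push(item);
    }
    if let Some(mut trailer) = on_complete() {
        while let Some(item) = trailer.next().await {
            out.push(item);
        }
    }
    out
}

#[tokio::main]
async fn main() {
    // The trailing item stands in for the metrics-only FlightData message the
    // endpoint appends once the last partition of a task finishes.
    let items = with_trailer(stream::iter(vec![1, 2, 3]), || {
        Some(stream::iter(vec![99]))
    })
    .await;
    assert_eq!(items, vec![1, 2, 3, 99]);
}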