From 4103bae06b27473ee00f39f9cffd3c7831077e02 Mon Sep 17 00:00:00 2001 From: Jayant Shrivastava Date: Thu, 25 Sep 2025 21:16:34 -0400 Subject: [PATCH 1/3] flight_service: emit metrics from ArrowFlightEndpoint This change updates the ArrowFlightEndpoint to collect metrics and emit them. When the last partition in a task is finished, the ArrowFlightEndpoint collects metrics and emits them via the TrailingFlightDataStream. Previously, we would determine if a partition is finished when the request first hit the endpoint. Now, we do it on stream completition. This is crutial for metrics collection because we need to know that the stream is exhausted, meaning that there's no data flowing in the plan and metrics are not actively being updated. Since the ArrowFlightEndpoint now emits metrics and NetworkBoundary plan nodes collect metrics, all coordinating StageExecs will now have the full collection of metrics for all tasks. This commit adds integration style tests that assert that the coordinator is recieving the full set of metrics. Follow up work - Only collect metrics if a configuration is set in the SessionContext, removing extra overhead - Display metrics in the plan using EXPLAIN (ANALYZE) - consider using sqllogictest or similar to test the output --- src/execution_plans/metrics.rs | 470 ++++++++++++------ .../metrics_collecting_stream.rs | 5 +- src/execution_plans/mod.rs | 1 + src/execution_plans/network_coalesce.rs | 11 +- src/execution_plans/network_shuffle.rs | 11 +- src/flight_service/do_get.rs | 73 ++- src/flight_service/mod.rs | 2 +- src/flight_service/service.rs | 4 +- .../trailing_flight_data_stream.rs | 26 +- src/metrics/proto.rs | 64 +-- src/test_utils/mod.rs | 1 + src/test_utils/plans.rs | 66 +++ 12 files changed, 516 insertions(+), 218 deletions(-) create mode 100644 src/test_utils/plans.rs diff --git a/src/execution_plans/metrics.rs b/src/execution_plans/metrics.rs index 4a1ff59..55560a1 100644 --- a/src/execution_plans/metrics.rs +++ b/src/execution_plans/metrics.rs @@ -1,18 +1,30 @@ use crate::execution_plans::{NetworkCoalesceExec, NetworkShuffleExec, StageExec}; use crate::metrics::proto::{MetricsSetProto, metrics_set_proto_to_df}; +use arrow::ipc::writer::DictionaryTracker; +use arrow::record_batch::RecordBatch; +use arrow_flight::FlightData; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::ipc::writer::IpcDataGenerator; +use datafusion::arrow::ipc::writer::IpcWriteOptions; +use datafusion::execution::TaskContext; +use datafusion::physical_plan::metrics::MetricsSet; +use futures::{Stream, stream}; +use std::collections::HashMap; +use std::sync::Arc; + +use crate::metrics::proto::df_metrics_set_to_proto; use crate::protobuf::StageKey; +use crate::protobuf::{AppMetadata, FlightAppMetadata, MetricsCollection, TaskMetrics}; +use arrow_flight::error::FlightError; use datafusion::common::internal_err; use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter}; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::SendableRecordBatchStream; -use datafusion::execution::TaskContext; use datafusion::physical_plan::ExecutionPlan; -use datafusion::physical_plan::metrics::MetricsSet; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, PlanProperties}; +use prost::Message; use std::any::Any; -use std::collections::HashMap; use std::fmt::{Debug, Formatter}; -use std::sync::Arc; /// TaskMetricsCollector is used to collect metrics from a task. It implements [TreeNodeRewriter]. /// Note: TaskMetricsCollector is not a [datafusion::physical_plan::ExecutionPlanVisitor] to keep @@ -29,9 +41,9 @@ pub struct TaskMetricsCollector { #[allow(dead_code)] pub struct MetricsCollectorResult { // metrics is a collection of metrics for a task ordered using a pre-order traversal of the task's plan. - task_metrics: Vec, + pub(super) task_metrics: Vec, // child_task_metrics contains metrics for child tasks if they were collected. - child_task_metrics: HashMap>, + pub(super) child_task_metrics: HashMap>, } impl TreeNodeRewriter for TaskMetricsCollector { @@ -104,13 +116,16 @@ impl TaskMetricsCollector { } } - /// collect metrics from a StageExec plan and any child tasks. + /// collect metrics from an [ExecutionPlan] (usually a [StageExec].plan) and any child tasks. /// Returns /// - a vec representing the metrics for the current task (ordered using a pre-order traversal) /// - a map representing the metrics for some subset of child tasks collected from NetworkShuffleExec leaves #[allow(dead_code)] - pub fn collect(mut self, stage: &StageExec) -> Result { - stage.plan.clone().rewrite(&mut self)?; + pub fn collect( + mut self, + plan: Arc, + ) -> Result { + plan.rewrite(&mut self)?; Ok(MetricsCollectorResult { task_metrics: self.task_metrics, child_task_metrics: self.child_task_metrics, @@ -270,20 +285,92 @@ impl ExecutionPlan for MetricsWrapperExec { } } +// Collects metrics from the provided stage and encodes it into a stream of flight data using +// the schema of the stage. +pub fn collect_and_create_metrics_flight_data( + stage_key: StageKey, + stage: Arc, +) -> Result> + Send + 'static, FlightError> { + // Get the metrics for the task executed on this worker. Separately, collect metrics for child tasks. + let mut result = TaskMetricsCollector::new() + .collect(stage.plan.clone()) + .map_err(|err| FlightError::ProtocolError(err.to_string()))?; + + // Add the metrics for this task into the collection of task metrics. + // Skip any metrics that can't be converted to proto (unsupported types) + let proto_task_metrics = result + .task_metrics + .iter() + .map(|metrics| { + df_metrics_set_to_proto(metrics) + .map_err(|err| FlightError::ProtocolError(err.to_string())) + }) + .collect::, FlightError>>()?; + result + .child_task_metrics + .insert(stage_key.clone(), proto_task_metrics.clone()); + + // Serialize the metrics for all tasks. + let mut task_metrics_set = vec![]; + for (stage_key, metrics) in result.child_task_metrics.into_iter() { + task_metrics_set.push(TaskMetrics { + stage_key: Some(stage_key), + metrics, + }); + } + + let flight_app_metadata = FlightAppMetadata { + content: Some(AppMetadata::MetricsCollection(MetricsCollection { + tasks: task_metrics_set, + })), + }; + + let metrics_flight_data = + empty_flight_data_with_app_metadata(flight_app_metadata, stage.plan.schema())?; + Ok(Box::pin(stream::once( + async move { Ok(metrics_flight_data) }, + ))) +} + +/// Creates a FlightData with the given app_metadata and empty RecordBatch using the provided schema. +/// We don't use [arrow_flight::encode::FlightDataEncoder] (and by extension, the [arrow_flight::encode::FlightDataEncoderBuilder]) +/// since they skip messages with empty RecordBatch data. +pub fn empty_flight_data_with_app_metadata( + metadata: FlightAppMetadata, + schema: SchemaRef, +) -> Result { + let mut buf = vec![]; + metadata + .encode(&mut buf) + .map_err(|err| FlightError::ProtocolError(err.to_string()))?; + + let empty_batch = RecordBatch::new_empty(schema); + let options = IpcWriteOptions::default(); + let data_gen = IpcDataGenerator::default(); + let mut dictionary_tracker = DictionaryTracker::new(true); + let (_, encoded_data) = data_gen + .encoded_batch(&empty_batch, &mut dictionary_tracker, &options) + .map_err(|e| { + FlightError::ProtocolError(format!("Failed to create empty batch FlightData: {e}")) + })?; + Ok(FlightData::from(encoded_data).with_app_metadata(buf)) +} + #[cfg(test)] mod tests { use super::*; use datafusion::arrow::array::{Int32Array, StringArray}; use datafusion::arrow::record_batch::RecordBatch; + use futures::StreamExt; use crate::DistributedExt; use crate::DistributedPhysicalOptimizerRule; use crate::test_utils::in_memory_channel_resolver::InMemoryChannelResolver; use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed; + use crate::test_utils::plans::{count_plan_nodes, get_stages_and_stage_keys}; use crate::test_utils::session_context::register_temp_parquet_table; use datafusion::execution::{SessionStateBuilder, context::SessionContext}; - use datafusion::physical_plan::metrics::MetricValue; use datafusion::prelude::SessionConfig; use datafusion::{ arrow::datatypes::{DataType, Field, Schema}, @@ -291,198 +378,283 @@ mod tests { }; use std::sync::Arc; - /// Creates a stage with the following structure: - /// - /// SortPreservingMergeExec - /// SortExec - /// ProjectionExec - /// AggregateExec - /// CoalesceBatchesExec - /// NetworkShuffleExec - /// - /// ... (for the purposes of these tests, we don't care about child stages). - async fn make_test_stage_exec_with_5_nodes() -> (StageExec, SessionContext) { + /// Creates a single node session context + async fn make_test_ctx_single_node() -> SessionContext { + make_test_ctx_helper(false).await + } + + /// Creates a distributed session context with in-memory distributed engine + async fn make_test_ctx() -> SessionContext { + make_test_ctx_helper(true).await + } + + /// Creates a session context and registers two tables: + /// - table1 (id: int, name: string) + /// - table2 (id: int, name: string, phone: string, balance: float64) + async fn make_test_ctx_helper(distributed: bool) -> SessionContext { // Create distributed session state with in-memory channel resolver let config = SessionConfig::new().with_target_partitions(2); - let state = SessionStateBuilder::new() + let mut builder = SessionStateBuilder::new() .with_default_features() - .with_config(config) - .with_distributed_channel_resolver(InMemoryChannelResolver::new()) - .with_physical_optimizer_rule(Arc::new( - DistributedPhysicalOptimizerRule::default() - .with_network_coalesce_tasks(2) - .with_network_shuffle_tasks(2), - )) - .build(); + .with_config(config); + + if distributed { + builder = builder + .with_distributed_channel_resolver(InMemoryChannelResolver::new()) + .with_physical_optimizer_rule(Arc::new( + DistributedPhysicalOptimizerRule::default() + .with_network_coalesce_tasks(2) + .with_network_shuffle_tasks(2), + )) + } + let state = builder.build(); let ctx = SessionContext::from(state); - // Create test data - let schema = Arc::new(Schema::new(vec![ + // Create test data for table1 + let schema1 = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), Field::new("name", DataType::Utf8, false), ])); - let batches = vec![ + let batches1 = vec![ RecordBatch::try_new( - schema.clone(), + schema1.clone(), vec![ Arc::new(Int32Array::from(vec![1, 2, 3])), Arc::new(StringArray::from(vec!["a", "b", "c"])), ], ) .unwrap(), + ]; + + // Create test data for table2 with extended schema + let schema2 = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + Field::new("phone", DataType::Utf8, false), + Field::new("balance", DataType::Float64, false), + ])); + + let batches2 = vec![ RecordBatch::try_new( - schema.clone(), + schema2.clone(), vec![ - Arc::new(Int32Array::from(vec![4, 5, 6])), - Arc::new(StringArray::from(vec!["d", "e", "f"])), + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![ + "customer1", + "customer2", + "customer3", + ])), + Arc::new(StringArray::from(vec![ + "13-123-4567", + "31-456-7890", + "23-789-0123", + ])), + Arc::new(datafusion::arrow::array::Float64Array::from(vec![ + 100.5, 250.0, 50.25, + ])), ], ) .unwrap(), ]; - // Register the test data as a parquet table - let _ = register_temp_parquet_table("test_table", schema.clone(), batches, &ctx) + // Register the test data as parquet tables + let _ = register_temp_parquet_table("table1", schema1, batches1, &ctx) .await .unwrap(); - let df = ctx - .sql("SELECT id, COUNT(*) as count FROM test_table WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10") + let _ = register_temp_parquet_table("table2", schema2, batches2, &ctx) .await .unwrap(); + + ctx + } + + /// runs a sql query and returns the coordinator StageExec + async fn plan_sql(ctx: &SessionContext, sql: &str) -> StageExec { + let df = ctx.sql(sql).await.unwrap(); let physical_distributed = df.create_physical_plan().await.unwrap(); let stage_exec = match physical_distributed.as_any().downcast_ref::() { Some(stage_exec) => stage_exec.clone(), None => panic!( - "Expected StageExec from distributed optimization, got: {}", + "expected StageExec from distributed optimization, got: {}", physical_distributed.name() ), }; + stage_exec + } - (stage_exec, ctx) + async fn execute_plan(stage_exec: &StageExec, ctx: &SessionContext) { + let task_ctx = ctx.task_ctx(); + let stream = stage_exec.execute(0, task_ctx).unwrap(); + + let mut stream = stream; + while let Some(batch) = stream.next().await { + batch.unwrap(); + } } - #[tokio::test] - #[ignore] - async fn test_metrics_rewriter() { - let (test_stage, _ctx) = make_test_stage_exec_with_5_nodes().await; - let test_metrics_sets = (0..5) // 5 nodes excluding NetworkShuffleExec - .map(|i| make_test_metrics_set_proto_from_seed(i + 10)) - .collect::>(); + /// Asserts that we can collect metrics from a distributed plan generated from the + /// SQL query. It ensures that metrics are collected for all stages and are propagated + /// through network boundaries. + async fn run_metrics_collection_e2e_test(sql: &str) { + // Plan and execute the query + let ctx = make_test_ctx().await; + let stage_exec = plan_sql(&ctx, sql).await; + execute_plan(&stage_exec, &ctx).await; + + // Assert to ensure the distributed test case is sufficiently complex. + let (stages, expected_stage_keys) = get_stages_and_stage_keys(&stage_exec); + assert!( + expected_stage_keys.len() > 1, + "expected more than 1 stage key in test. the plan was not distributed):\n{}", + DisplayableExecutionPlan::new(&stage_exec).indent(true) + ); + + // Collect metrics for all tasks from the root StageExec. + let collector = TaskMetricsCollector::new(); + let result = collector.collect(stage_exec.plan.clone()).unwrap(); + let mut actual_collected_metrics = result.child_task_metrics; + actual_collected_metrics.insert( + StageKey { + query_id: stage_exec.query_id.to_string(), + stage_id: stage_exec.num as u64, + task_number: 0, + }, + result + .task_metrics + .iter() + .map(|m| df_metrics_set_to_proto(m).unwrap()) + .collect::>(), + ); + + // Ensure that there's metrics for each node for each task for each stage. + for expected_stage_key in expected_stage_keys { + // Get the collected metrics for this task. + let actual_metrics = actual_collected_metrics.get(&expected_stage_key).unwrap(); + + // Assert that there's metrics for each node in this task. + let stage = stages.get(&(expected_stage_key.stage_id as usize)).unwrap(); + assert_eq!(actual_metrics.len(), count_plan_nodes(&stage.plan)); + + // Ensure each node has at least one metric which was collected. + for metrics_set in actual_metrics.iter() { + let metrics_set = metrics_set_proto_to_df(metrics_set).unwrap(); + assert!(metrics_set.iter().count() > 0); + } + } + } - let rewriter = TaskMetricsRewriter::new(test_metrics_sets.clone()); - let plan_with_metrics = rewriter - .enrich_task_with_metrics(test_stage.plan.clone()) + /// Asserts that we successfully re-write the metrics of a plan generated from the provided SQL query. + /// Also asserts that the order which metrics are collected from a plan matches the order which + /// they are re-written (ie. ensures we don't assign metrics to the wrong nodes) + /// + /// Only tests single node plans since the [TaskMetricsRewriter] stops on [NetworkBoundary]. + async fn run_metrics_rewriter_test(sql: &str) { + // Generate the plan + let ctx = make_test_ctx_single_node().await; + let plan = ctx + .sql(sql) + .await + .unwrap() + .create_physical_plan() + .await .unwrap(); - let plan_str = - DisplayableExecutionPlan::with_full_metrics(plan_with_metrics.as_ref()).indent(true); - // Expected distributed plan output with metrics - let expected = [ - r"SortPreservingMergeExec: [id@0 ASC NULLS LAST], fetch=10, metrics=[output_rows=10, elapsed_compute=10ns, start_timestamp=2025-09-18 13:00:10 UTC, end_timestamp=2025-09-18 13:00:11 UTC]", - r" SortExec: TopK(fetch=10), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true], metrics=[output_rows=11, elapsed_compute=11ns, start_timestamp=2025-09-18 13:00:11 UTC, end_timestamp=2025-09-18 13:00:12 UTC]", - r" ProjectionExec: expr=[id@0 as id, count(Int64(1))@1 as count], metrics=[output_rows=12, elapsed_compute=12ns, start_timestamp=2025-09-18 13:00:12 UTC, end_timestamp=2025-09-18 13:00:13 UTC]", - r" AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[count(Int64(1))], metrics=[output_rows=13, elapsed_compute=13ns, start_timestamp=2025-09-18 13:00:13 UTC, end_timestamp=2025-09-18 13:00:14 UTC]", - r" CoalesceBatchesExec: target_batch_size=8192, metrics=[output_rows=14, elapsed_compute=14ns, start_timestamp=2025-09-18 13:00:14 UTC, end_timestamp=2025-09-18 13:00:15 UTC]", - r" NetworkShuffleExec, metrics=[]", - "" // trailing newline - ].join("\n"); - assert_eq!(expected, plan_str.to_string()); - } + // Generate metrics for each plan node. + let expected_metrics = (0..count_plan_nodes(&plan)) + .map(|i| make_test_metrics_set_proto_from_seed(i as u64 + 10)) + .collect::>(); - #[tokio::test] - #[ignore] - async fn test_metrics_rewriter_correct_number_of_metrics() { - let test_metrics_set = make_test_metrics_set_proto_from_seed(10); - let (executable_plan, _ctx) = make_test_stage_exec_with_5_nodes().await; - let task_plan = executable_plan - .as_any() - .downcast_ref::() + // Rewrite the metrics. + let rewriter = TaskMetricsRewriter::new(expected_metrics.clone()); + let rewritten_plan = rewriter.enrich_task_with_metrics(plan.clone()).unwrap(); + + // Collect metrics + let actual_metrics = TaskMetricsCollector::new() + .collect(rewritten_plan) .unwrap() - .plan - .clone(); - - // Too few metrics sets. - let rewriter = TaskMetricsRewriter::new(vec![test_metrics_set.clone()]); - let result = rewriter.enrich_task_with_metrics(task_plan.clone()); - assert!(result.is_err()); - - // Too many metrics sets. - let rewriter = TaskMetricsRewriter::new(vec![ - test_metrics_set.clone(), - test_metrics_set.clone(), - test_metrics_set.clone(), - test_metrics_set.clone(), - ]); - let result = rewriter.enrich_task_with_metrics(task_plan.clone()); - assert!(result.is_err()); + .task_metrics; + + // Assert that all the metrics are present and in the same order. + assert_eq!(actual_metrics.len(), expected_metrics.len()); + for (actual_metrics_set, expected_metrics_set) in actual_metrics + .iter() + .map(|m| df_metrics_set_to_proto(m).unwrap()) + .zip(expected_metrics) + { + assert_eq!(actual_metrics_set, expected_metrics_set); + } } #[tokio::test] - #[ignore] - async fn test_metrics_collection() { - let (stage_exec, ctx) = make_test_stage_exec_with_5_nodes().await; + async fn test_metrics_rewriter_1() { + run_metrics_rewriter_test( + "SELECT sum(balance) / 7.0 as avg_yearly from table2 group by name", + ) + .await; + } - // Execute the plan to completion. - let task_ctx = ctx.task_ctx(); - let stream = stage_exec.execute(0, task_ctx).unwrap(); + #[tokio::test] + async fn test_metrics_rewriter_2() { + run_metrics_rewriter_test("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10").await; + } - use futures::StreamExt; - let mut stream = stream; - while let Some(_batch) = stream.next().await {} + #[tokio::test] + async fn test_metrics_rewriter_3() { + run_metrics_rewriter_test( + "SELECT sum(balance) / 7.0 as avg_yearly + FROM table2 + WHERE name LIKE 'customer%' + AND balance < ( + SELECT 0.2 * avg(balance) + FROM table2 t2_inner + WHERE t2_inner.id = table2.id + )", + ) + .await; + } - let collector = TaskMetricsCollector::new(); - let result = collector.collect(&stage_exec).unwrap(); - - // With the distributed optimizer, we get a much more complex plan structure - // The exact number of metrics sets depends on the plan optimization, so be flexible - assert_eq!(result.task_metrics.len(), 5); - - let expected_metrics_count = [4, 10, 8, 16, 8]; - for (node_idx, metrics_set) in result.task_metrics.iter().enumerate() { - let metrics_count = metrics_set.iter().count(); - assert_eq!(metrics_count, expected_metrics_count[node_idx]); - - // Each node should have basic metrics: ElapsedCompute, OutputRows, StartTimestamp, EndTimestamp. - let mut has_start_timestamp = false; - let mut has_end_timestamp = false; - let mut has_elapsed_compute = false; - let mut has_output_rows = false; - - for metric in metrics_set.iter() { - let metric_name = metric.value().name(); - let metric_value = metric.value(); - - match metric_value { - MetricValue::StartTimestamp(_) if metric_name == "start_timestamp" => { - has_start_timestamp = true; - } - MetricValue::EndTimestamp(_) if metric_name == "end_timestamp" => { - has_end_timestamp = true; - } - MetricValue::ElapsedCompute(_) if metric_name == "elapsed_compute" => { - has_elapsed_compute = true; - } - MetricValue::OutputRows(_) if metric_name == "output_rows" => { - has_output_rows = true; - } - _ => { - // Other metrics are fine, we just validate the core ones - } - } - } + #[tokio::test] + async fn test_metrics_collection_e2e_1() { + run_metrics_collection_e2e_test("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10").await; + } - // Each node should have the four basic metrics - assert!(has_start_timestamp); - assert!(has_end_timestamp); - assert!(has_elapsed_compute); - assert!(has_output_rows); - } + #[tokio::test] + async fn test_metrics_collection_e2e_2() { + run_metrics_collection_e2e_test( + "SELECT sum(balance) / 7.0 as avg_yearly + FROM table2 + WHERE name LIKE 'customer%' + AND balance < ( + SELECT 0.2 * avg(balance) + FROM table2 t2_inner + WHERE t2_inner.id = table2.id + )", + ) + .await; + } - // TODO: once we propagate metrics from child stages, we can assert this. - assert_eq!(0, result.child_task_metrics.len()); + #[tokio::test] + async fn test_metrics_collection_e2e_3() { + run_metrics_collection_e2e_test( + "SELECT + substring(phone, 1, 2) as country_code, + count(*) as num_customers, + sum(balance) as total_balance + FROM table2 + WHERE substring(phone, 1, 2) IN ('13', '31', '23', '29', '30', '18') + AND balance > ( + SELECT avg(balance) + FROM table2 + WHERE balance > 0.00 + ) + GROUP BY substring(phone, 1, 2) + ORDER BY country_code", + ) + .await; } } diff --git a/src/execution_plans/metrics_collecting_stream.rs b/src/execution_plans/metrics_collecting_stream.rs index 1321f63..d8379b6 100644 --- a/src/execution_plans/metrics_collecting_stream.rs +++ b/src/execution_plans/metrics_collecting_stream.rs @@ -67,8 +67,8 @@ where }; metrics_collection.insert(stage_key, task_metrics.metrics); } - flight_data.app_metadata.clear(); + Ok(()) } } @@ -256,7 +256,8 @@ mod tests { // Create a stream that emits an error - should be propagated through let stream_error = FlightError::ProtocolError("stream error from inner stream".to_string()); let error_stream = stream::iter(vec![Err(stream_error)]); - let mut collecting_stream = MetricsCollectingStream::new(error_stream, metrics_collection); + let mut collecting_stream = + MetricsCollectingStream::new(error_stream, metrics_collection.clone()); let result = collecting_stream.next().await.unwrap(); assert_protocol_error(result, "stream error from inner stream"); diff --git a/src/execution_plans/mod.rs b/src/execution_plans/mod.rs index 37e13e3..48fe68b 100644 --- a/src/execution_plans/mod.rs +++ b/src/execution_plans/mod.rs @@ -5,6 +5,7 @@ mod network_shuffle; mod partition_isolator; mod stage; +pub use metrics::collect_and_create_metrics_flight_data; pub use network_coalesce::{NetworkCoalesceExec, NetworkCoalesceReady}; pub use network_shuffle::{NetworkShuffleExec, NetworkShuffleReadyExec}; pub use partition_isolator::PartitionIsolatorExec; diff --git a/src/execution_plans/network_coalesce.rs b/src/execution_plans/network_coalesce.rs index be9a5fe..2d88113 100644 --- a/src/execution_plans/network_coalesce.rs +++ b/src/execution_plans/network_coalesce.rs @@ -3,6 +3,7 @@ use crate::channel_resolver_ext::get_distributed_channel_resolver; use crate::common::scale_partitioning_props; use crate::config_extension_ext::ContextGrpcMetadata; use crate::distributed_physical_optimizer_rule::{NetworkBoundary, limit_tasks_err}; +use crate::execution_plans::metrics_collecting_stream::MetricsCollectingStream; use crate::execution_plans::{DistributedTaskContext, StageExec}; use crate::flight_service::DoGet; use crate::metrics::proto::MetricsSetProto; @@ -297,6 +298,7 @@ impl ExecutionPlan for NetworkCoalesceExec { return internal_err!("NetworkCoalesceExec: task is unassigned, cannot proceed"); }; + let metrics_collection_capture = self_ready.metrics_collection.clone(); let stream = async move { let channel = channel_resolver.get_channel_for_url(&url).await?; let stream = FlightServiceClient::new(channel) @@ -306,8 +308,13 @@ impl ExecutionPlan for NetworkCoalesceExec { .into_inner() .map_err(|err| FlightError::Tonic(Box::new(err))); - Ok(FlightRecordBatchStream::new_from_flight_data(stream) - .map_err(map_flight_to_datafusion_error)) + let metrics_collecting_stream = + MetricsCollectingStream::new(stream, metrics_collection_capture); + + Ok( + FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream) + .map_err(map_flight_to_datafusion_error), + ) } .try_flatten_stream(); diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs index eed04b7..9e531bf 100644 --- a/src/execution_plans/network_shuffle.rs +++ b/src/execution_plans/network_shuffle.rs @@ -3,6 +3,7 @@ use crate::channel_resolver_ext::get_distributed_channel_resolver; use crate::common::scale_partitioning; use crate::config_extension_ext::ContextGrpcMetadata; use crate::distributed_physical_optimizer_rule::NetworkBoundary; +use crate::execution_plans::metrics_collecting_stream::MetricsCollectingStream; use crate::execution_plans::{DistributedTaskContext, StageExec}; use crate::flight_service::DoGet; use crate::metrics::proto::MetricsSetProto; @@ -330,6 +331,7 @@ impl ExecutionPlan for NetworkShuffleExec { }, ); + let metrics_collection_capture = self_ready.metrics_collection.clone(); async move { let url = task.url.ok_or(internal_datafusion_err!( "NetworkShuffleExec: task is unassigned, cannot proceed" @@ -343,8 +345,13 @@ impl ExecutionPlan for NetworkShuffleExec { .into_inner() .map_err(|err| FlightError::Tonic(Box::new(err))); - Ok(FlightRecordBatchStream::new_from_flight_data(stream) - .map_err(map_flight_to_datafusion_error)) + let metrics_collecting_stream = + MetricsCollectingStream::new(stream, metrics_collection_capture); + + Ok( + FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream) + .map_err(map_flight_to_datafusion_error), + ) } .try_flatten_stream() .boxed() diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index b9701ff..47add9b 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -1,7 +1,10 @@ use crate::config_extension_ext::ContextGrpcMetadata; -use crate::execution_plans::{DistributedTaskContext, StageExec}; +use crate::execution_plans::{ + DistributedTaskContext, StageExec, collect_and_create_metrics_flight_data, +}; use crate::flight_service::service::ArrowFlightEndpoint; use crate::flight_service::session_builder::DistributedSessionBuilderContext; +use crate::flight_service::trailing_flight_data_stream::TrailingFlightDataStream; use crate::protobuf::{ DistributedCodec, StageExecProto, StageKey, datafusion_error_to_tonic_status, stage_from_proto, }; @@ -94,12 +97,6 @@ impl ArrowFlightEndpoint { }) .await?; let stage = Arc::clone(&stage_data.stage); - let num_partitions_remaining = Arc::clone(&stage_data.num_partitions_remaining); - - // If all the partitions are done, remove the stage from the cache. - if num_partitions_remaining.fetch_sub(1, Ordering::SeqCst) <= 1 { - self.task_data_entries.remove(key); - } // Find out which partition group we are executing let cfg = session_state.config_mut(); @@ -126,7 +123,16 @@ impl ArrowFlightEndpoint { .execute(doget.target_partition as usize, session_state.task_ctx()) .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?; - Ok(record_batch_stream_to_response(stream)) + let task_data_capture = self.task_data_entries.clone(); + Ok(flight_stream_from_record_batch_stream( + key.clone(), + stage, + stage_data.clone(), + move || { + task_data_capture.remove(key.clone()); + }, + stream, + )) } } @@ -134,7 +140,13 @@ fn missing(field: &'static str) -> impl FnOnce() -> Status { move || Status::invalid_argument(format!("Missing field '{field}'")) } -fn record_batch_stream_to_response( +// Creates a tonic response from a stream of record batches. Handles +// - RecordBatch to flight conversion partition tracking, stage eviction, and trailing metrics. +fn flight_stream_from_record_batch_stream( + stage_key: StageKey, + stage: Arc, + stage_data: TaskData, + evict_stage: impl FnOnce() + Send + 'static, stream: SendableRecordBatchStream, ) -> Response<::DoGetStream> { let flight_data_stream = @@ -144,7 +156,31 @@ fn record_batch_stream_to_response( FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err))) })); - Response::new(Box::pin(flight_data_stream.map_err(|err| match err { + let trailing_metrics_stream = TrailingFlightDataStream::new( + move || { + if stage_data + .num_partitions_remaining + .fetch_sub(1, Ordering::SeqCst) + == 1 + { + evict_stage(); + + let metrics_stream = collect_and_create_metrics_flight_data(stage_key, stage) + .map_err(|err| { + Status::internal(format!( + "error collecting metrics in arrow flight endpoint: {err}" + )) + })?; + + return Ok(Some(metrics_stream)); + } + + Ok(None) + }, + flight_data_stream, + ); + + Response::new(Box::pin(trailing_metrics_stream.map_err(|err| match err { FlightError::Tonic(status) => *status, _ => Status::internal(format!("Error during flight stream: {err}")), }))) @@ -215,9 +251,9 @@ mod tests { let stage_proto = proto_from_stage(&stage, &DefaultPhysicalExtensionCodec {}).unwrap(); let stage_proto_for_closure = stage_proto.clone(); let endpoint_ref = &endpoint; + let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| { let stage_proto = stage_proto_for_closure.clone(); - // Create DoGet message let doget = DoGet { stage_proto: Some(stage_proto), target_task_index: task_number, @@ -225,14 +261,17 @@ mod tests { stage_key: Some(stage_key), }; - // Create Flight ticket let ticket = Ticket { ticket: Bytes::from(doget.encode_to_vec()), }; - // Call the actual get() method let request = Request::new(ticket); - endpoint_ref.get(request).await + let response = endpoint_ref.get(request).await?; + let mut stream = response.into_inner(); + + // Consume the stream. + while let Some(_flight_data) = stream.try_next().await? {} + Ok::<(), Status>(()) }; // For each task, call do_get() for each partition except the last. @@ -248,7 +287,7 @@ mod tests { // Run the last partition of task 0. Any partition number works. Verify that the task state // is evicted because all partitions have been processed. - let result = do_get(1, 0, task_keys[0].clone()).await; + let result = do_get(2, 0, task_keys[0].clone()).await; assert!(result.is_ok()); let stored_stage_keys = endpoint.task_data_entries.keys().collect::>(); assert_eq!(stored_stage_keys.len(), 2); @@ -256,14 +295,14 @@ mod tests { assert!(stored_stage_keys.contains(&task_keys[2])); // Run the last partition of task 1. - let result = do_get(1, 1, task_keys[1].clone()).await; + let result = do_get(2, 1, task_keys[1].clone()).await; assert!(result.is_ok()); let stored_stage_keys = endpoint.task_data_entries.keys().collect::>(); assert_eq!(stored_stage_keys.len(), 1); assert!(stored_stage_keys.contains(&task_keys[2])); // Run the last partition of the last task. - let result = do_get(1, 2, task_keys[2].clone()).await; + let result = do_get(2, 2, task_keys[2].clone()).await; assert!(result.is_ok()); let stored_stage_keys = endpoint.task_data_entries.keys().collect::>(); assert_eq!(stored_stage_keys.len(), 0); diff --git a/src/flight_service/mod.rs b/src/flight_service/mod.rs index fdc99c4..db3bd91 100644 --- a/src/flight_service/mod.rs +++ b/src/flight_service/mod.rs @@ -1,7 +1,7 @@ mod do_get; mod service; mod session_builder; -mod trailing_flight_data_stream; +pub(super) mod trailing_flight_data_stream; pub(crate) use do_get::DoGet; pub use service::ArrowFlightEndpoint; diff --git a/src/flight_service/service.rs b/src/flight_service/service.rs index 675cd51..df01be7 100644 --- a/src/flight_service/service.rs +++ b/src/flight_service/service.rs @@ -17,7 +17,7 @@ use tonic::{Request, Response, Status, Streaming}; pub struct ArrowFlightEndpoint { pub(super) runtime: Arc, - pub(super) task_data_entries: TTLMap>>, + pub(super) task_data_entries: Arc>>>, pub(super) session_builder: Arc, } @@ -28,7 +28,7 @@ impl ArrowFlightEndpoint { let ttl_map = TTLMap::try_new(TTLMapConfig::default())?; Ok(Self { runtime: Arc::new(RuntimeEnv::default()), - task_data_entries: ttl_map, + task_data_entries: Arc::new(ttl_map), session_builder: Arc::new(session_builder), }) } diff --git a/src/flight_service/trailing_flight_data_stream.rs b/src/flight_service/trailing_flight_data_stream.rs index db85903..ebdf53e 100644 --- a/src/flight_service/trailing_flight_data_stream.rs +++ b/src/flight_service/trailing_flight_data_stream.rs @@ -8,22 +8,24 @@ use tokio::pin; /// TrailingFlightDataStream - wraps a FlightData stream. It calls the `on_complete` closure when the stream is finished. /// If the closure returns a new stream, it will be appended to the original stream and consumed. #[pin_project] -pub struct TrailingFlightDataStream +pub struct TrailingFlightDataStream where S: Stream> + Send, - F: FnOnce() -> Result, FlightError>, + T: Stream> + Send, + F: FnOnce() -> Result, FlightError>, { #[pin] inner: S, on_complete: Option, #[pin] - trailing_stream: Option, + trailing_stream: Option, } -impl TrailingFlightDataStream +impl TrailingFlightDataStream where S: Stream> + Send, - F: FnOnce() -> Result, FlightError>, + T: Stream> + Send, + F: FnOnce() -> Result, FlightError>, { // TODO: remove #[allow(dead_code)] @@ -36,10 +38,11 @@ where } } -impl Stream for TrailingFlightDataStream +impl Stream for TrailingFlightDataStream where S: Stream> + Send, - F: FnOnce() -> Result, FlightError>, + T: Stream> + Send, + F: FnOnce() -> Result, FlightError>, { type Item = Result; @@ -74,7 +77,7 @@ mod tests { use arrow::record_batch::RecordBatch; use arrow_flight::FlightData; use arrow_flight::decode::FlightRecordBatchStream; - use arrow_flight::encode::FlightDataEncoderBuilder; + use arrow_flight::encode::{FlightDataEncoder, FlightDataEncoderBuilder}; use futures::stream::{self, StreamExt}; use std::sync::Arc; @@ -186,7 +189,7 @@ mod tests { )))), ]; let inner_stream = stream::iter(data); - let on_complete = || Ok(None); + let on_complete = || Ok(None::); let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream); let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream) .collect::>>() @@ -202,8 +205,7 @@ mod tests { let name_array = StringArray::from(vec!["item1"]); let value_array = Int32Array::from(vec![1]); let inner_stream = create_flight_data_stream(name_array, value_array); - - let on_complete = || -> Result, FlightError> { + let on_complete = || -> Result, FlightError> { Err(FlightError::ExternalError(Box::new(std::io::Error::new( std::io::ErrorKind::Other, "callback error", @@ -225,7 +227,7 @@ mod tests { StringArray::from(vec!["item1"] as Vec<&str>), Int32Array::from(vec![1] as Vec), ); - let on_complete = || Ok(None); + let on_complete = || Ok(None::); let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream); let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream) .collect::>>() diff --git a/src/metrics/proto.rs b/src/metrics/proto.rs index 4222d53..0206bc3 100644 --- a/src/metrics/proto.rs +++ b/src/metrics/proto.rs @@ -380,30 +380,28 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::StartTimestamp(start_ts)) => match start_ts.value { - Some(value) => { - let timestamp = Timestamp::new(); + Some(MetricValueProto::StartTimestamp(start_ts)) => { + let timestamp = Timestamp::new(); + if let Some(value) = start_ts.value { timestamp.set(DateTime::from_timestamp_nanos(value)); - Ok(Arc::new(Metric::new_with_labels( - MetricValue::StartTimestamp(timestamp), - partition, - labels, - ))) } - None => internal_err!("encountered invalid start timestamp metric with no value"), - }, - Some(MetricValueProto::EndTimestamp(end_ts)) => match end_ts.value { - Some(value) => { - let timestamp = Timestamp::new(); + Ok(Arc::new(Metric::new_with_labels( + MetricValue::StartTimestamp(timestamp), + partition, + labels, + ))) + } + Some(MetricValueProto::EndTimestamp(end_ts)) => { + let timestamp = Timestamp::new(); + if let Some(value) = end_ts.value { timestamp.set(DateTime::from_timestamp_nanos(value)); - Ok(Arc::new(Metric::new_with_labels( - MetricValue::EndTimestamp(timestamp), - partition, - labels, - ))) } - None => internal_err!("encountered invalid end timestamp metric with no value"), - }, + Ok(Arc::new(Metric::new_with_labels( + MetricValue::EndTimestamp(timestamp), + partition, + labels, + ))) + } None => internal_err!("proto metric is missing the metric field"), } } @@ -853,18 +851,22 @@ mod tests { } #[test] - fn test_invalid_proto_timestamp_error() { - // Create a MetricProto with EndTimestamp that has no value (None) - let invalid_end_timestamp_proto = MetricProto { - metric: Some(MetricValueProto::EndTimestamp(EndTimestamp { value: None })), - labels: vec![], - partition: Some(0), - }; - - let result = metric_proto_to_df(invalid_end_timestamp_proto); + fn test_default_timestamp_roundtrip() { + let default_timestamp = Timestamp::default(); + let metric_with_default_timestamp = + Metric::new(MetricValue::EndTimestamp(default_timestamp), Some(0)); + + let proto_result = df_metric_to_proto(Arc::new(metric_with_default_timestamp)); + assert!( + proto_result.is_ok(), + "should successfully convert default timestamp to proto" + ); + + let proto_metric = proto_result.unwrap(); + let roundtrip_result = metric_proto_to_df(proto_metric); assert!( - result.is_err(), - "should return error for invalid end timestamp with no value" + roundtrip_result.is_ok(), + "should successfully roundtrip default timestamp" ); } } diff --git a/src/test_utils/mod.rs b/src/test_utils/mod.rs index 140bd63..a5b3bec 100644 --- a/src/test_utils/mod.rs +++ b/src/test_utils/mod.rs @@ -4,5 +4,6 @@ pub mod localhost; pub mod metrics; pub mod mock_exec; pub mod parquet; +pub mod plans; pub mod session_context; pub mod tpch; diff --git a/src/test_utils/plans.rs b/src/test_utils/plans.rs new file mode 100644 index 0000000..5bb7bce --- /dev/null +++ b/src/test_utils/plans.rs @@ -0,0 +1,66 @@ +use std::sync::Arc; + +use datafusion::{ + common::{HashMap, HashSet}, + physical_plan::ExecutionPlan, +}; + +use crate::{ + StageExec, + execution_plans::{NetworkCoalesceExec, NetworkShuffleExec}, + protobuf::StageKey, +}; + +/// count_plan_nodes counts the number of execution plan nodes in a plan using BFS traversal. +/// This does NOT traverse child stages, only the execution plan tree within this stage. +/// Excludes [NetworkBoundary] nodes from the count. +pub fn count_plan_nodes(plan: &Arc) -> usize { + let mut count = 0; + let mut queue = vec![plan]; + + while let Some(plan) = queue.pop() { + // Skip [NetworkBoundary] nodes from the count. + if !plan.as_any().is::() && !plan.as_any().is::() { + count += 1; + } + + // Add children to the queue for BFS traversal + for child in plan.children() { + queue.push(child); + } + } + count +} + +/// Returns +/// - a map of all stages +/// - a set of all the stage keys (one per task) +pub fn get_stages_and_stage_keys( + stage: &StageExec, +) -> (HashMap, HashSet) { + let query_id = stage.query_id; + let mut i = 0; + let mut queue = vec![stage]; + let mut stage_keys = HashSet::new(); + let mut stages_map = HashMap::new(); + + while i < queue.len() { + let stage = queue[i]; + stages_map.insert(stage.num, stage); + i += 1; + + // Add each task. + for j in 0..stage.tasks.len() { + let stage_key = StageKey { + query_id: query_id.to_string(), + stage_id: stage.num as u64, + task_number: j as u64, + }; + stage_keys.insert(stage_key); + } + + // Add any child stages + queue.extend(stage.child_stages_iter()); + } + (stages_map, stage_keys) +} From cbfa6da4bc25b817d1b713ff89f6c2c71127f0fb Mon Sep 17 00:00:00 2001 From: Jayant Shrivastava Date: Mon, 29 Sep 2025 13:17:46 -0400 Subject: [PATCH 2/3] distributed_physical_optimizer_rule: move metrics_collection to NetworkBoundary trait This is a small refactor which moves metrics_collection to the NetworkBoundary trait so all network boundaries must obey the metrics collecting protocol. --- src/distributed_physical_optimizer_rule.rs | 17 +++++++++++++ src/execution_plans/metrics.rs | 29 +++++++++++----------- src/execution_plans/network_coalesce.rs | 23 ++++++++--------- src/execution_plans/network_shuffle.rs | 23 ++++++++--------- src/protobuf/distributed_codec.rs | 4 +-- 5 files changed, 55 insertions(+), 41 deletions(-) diff --git a/src/distributed_physical_optimizer_rule.rs b/src/distributed_physical_optimizer_rule.rs index 17935bf..2f89f6f 100644 --- a/src/distributed_physical_optimizer_rule.rs +++ b/src/distributed_physical_optimizer_rule.rs @@ -1,5 +1,8 @@ use super::{NetworkShuffleExec, PartitionIsolatorExec, StageExec}; use crate::execution_plans::NetworkCoalesceExec; +use crate::metrics::proto::MetricsSetProto; +use crate::protobuf::StageKey; +use dashmap::DashMap; use datafusion::common::plan_err; use datafusion::common::tree_node::TreeNodeRecursion; use datafusion::datasource::source::DataSourceExec; @@ -318,6 +321,20 @@ pub trait NetworkBoundary: ExecutionPlan { } Ok(Arc::clone(children.first().unwrap())) } + + /// metrics_collection is used to collect metrics from child tasks. It is empty when a + /// [NetworkBoundary] is instantiated (deserialized, created via new() etc...). + /// Metrics are populated by executing() the [NetworkBoundary]. It's expected that the + /// collection is complete after the [NetworkBoundary] has been executed. It is undefined + /// what this returns during execution. + /// + /// An instance may receive metrics for 0 to N child tasks, where N is the number of tasks + /// in the stage it is reading from. This is because, by convention, the ArrowFlightEndpoint + /// sends metrics for a task to the last [NetworkBoundary] to read from it, which may or may + /// not be this instance. + fn metrics_collection(&self) -> Option>>> { + None + } } /// Error thrown during distributed planning that prompts the planner to change something and diff --git a/src/execution_plans/metrics.rs b/src/execution_plans/metrics.rs index 55560a1..529c85b 100644 --- a/src/execution_plans/metrics.rs +++ b/src/execution_plans/metrics.rs @@ -1,3 +1,4 @@ +use crate::distributed_physical_optimizer_rule::NetworkBoundary; use crate::execution_plans::{NetworkCoalesceExec, NetworkShuffleExec, StageExec}; use crate::metrics::proto::{MetricsSetProto, metrics_set_proto_to_df}; use arrow::ipc::writer::DictionaryTracker; @@ -50,26 +51,24 @@ impl TreeNodeRewriter for TaskMetricsCollector { type Node = Arc; fn f_down(&mut self, plan: Self::Node) -> Result> { - // If the plan is an NetworkShuffleExec, assume it has collected metrics already + // If the plan is a NetwordBoundary, assume it has collected metrics already // from child tasks. let metrics_collection = if let Some(node) = plan.as_any().downcast_ref::() { - let NetworkShuffleExec::Ready(ready) = node else { - return internal_err!( - "unexpected NetworkShuffleExec::Pending during metrics collection" - ); - }; - Some(Arc::clone(&ready.metrics_collection)) + node.metrics_collection() + .map(Some) + .ok_or(DataFusionError::Internal( + "could not collect metrics from NetworkShuffleExec".to_string(), + )) } else if let Some(node) = plan.as_any().downcast_ref::() { - let NetworkCoalesceExec::Ready(ready) = node else { - return internal_err!( - "unexpected NetworkCoalesceExec::Pending during metrics collection" - ); - }; - Some(Arc::clone(&ready.metrics_collection)) + node.metrics_collection() + .map(Some) + .ok_or(DataFusionError::Internal( + "could not collect metrics from NetworkCoalesceExec".to_string(), + )) } else { - None - }; + Ok(None) + }?; if let Some(metrics_collection) = metrics_collection { for mut entry in metrics_collection.iter_mut() { diff --git a/src/execution_plans/network_coalesce.rs b/src/execution_plans/network_coalesce.rs index 2d88113..489d184 100644 --- a/src/execution_plans/network_coalesce.rs +++ b/src/execution_plans/network_coalesce.rs @@ -90,15 +90,7 @@ pub struct NetworkCoalesceReady { pub(crate) properties: PlanProperties, pub(crate) stage_num: usize, pub(crate) input_tasks: usize, - /// metrics_collection is used to collect metrics from child tasks. It is empty when an - /// is instantiated (deserialized, created via [NetworkCoalesceExec::new_ready] etc...). - /// Metrics are populated in this map via [NetworkCoalesceExec::execute]. - /// - /// An instance may receive metrics for 0 to N child tasks, where N is the number of tasks in - /// the stage it is reading from. This is because, by convention, the ArrowFlightEndpoint - /// sends metrics for a task to the last NetworkCoalesceExec to read from it, which may or may - /// not be this instance. - pub(crate) metrics_collection: Arc>>, + pub(crate) child_task_metrics: Arc>>, } impl NetworkCoalesceExec { @@ -164,7 +156,7 @@ impl NetworkBoundary for NetworkCoalesceExec { }), stage_num, input_tasks: pending.input_tasks, - metrics_collection: Default::default(), + child_task_metrics: Default::default(), }; Ok(Arc::new(Self::Ready(ready))) @@ -185,10 +177,17 @@ impl NetworkBoundary for NetworkCoalesceExec { }), stage_num: ready.stage_num, input_tasks, - metrics_collection: Arc::clone(&ready.metrics_collection), + child_task_metrics: Arc::clone(&ready.child_task_metrics), }), }) } + + fn metrics_collection(&self) -> Option>>> { + match self { + NetworkCoalesceExec::Pending(_) => None, + NetworkCoalesceExec::Ready(v) => Some(v.child_task_metrics.clone()), + } + } } impl DisplayAs for NetworkCoalesceExec { @@ -298,7 +297,7 @@ impl ExecutionPlan for NetworkCoalesceExec { return internal_err!("NetworkCoalesceExec: task is unassigned, cannot proceed"); }; - let metrics_collection_capture = self_ready.metrics_collection.clone(); + let metrics_collection_capture = self_ready.child_task_metrics.clone(); let stream = async move { let channel = channel_resolver.get_channel_for_url(&url).await?; let stream = FlightServiceClient::new(channel) diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs index 9e531bf..71f82d4 100644 --- a/src/execution_plans/network_shuffle.rs +++ b/src/execution_plans/network_shuffle.rs @@ -136,15 +136,7 @@ pub struct NetworkShuffleReadyExec { /// the properties we advertise for this execution plan pub(crate) properties: PlanProperties, pub(crate) stage_num: usize, - /// metrics_collection is used to collect metrics from child tasks. It is empty when an - /// is instantiated (deserialized, created via [NetworkShuffleExec::new_ready] etc...). - /// Metrics are populated in this map via [NetworkShuffleExec::execute]. - /// - /// An instance may receive metrics for 0 to N child tasks, where N is the number of tasks in - /// the stage it is reading from. This is because, by convention, the ArrowFlightEndpoint - /// sends metrics for a task to the last NetworkShuffleExec to read from it, which may or may - /// not be this instance. - pub(crate) metrics_collection: Arc>>, + pub(crate) child_task_metrics: Arc>>, } impl NetworkShuffleExec { @@ -209,7 +201,7 @@ impl NetworkBoundary for NetworkShuffleExec { NetworkShuffleExec::Ready(prev) => NetworkShuffleExec::Ready(NetworkShuffleReadyExec { properties: prev.properties.clone(), stage_num: prev.stage_num, - metrics_collection: Arc::clone(&prev.metrics_collection), + child_task_metrics: Arc::clone(&prev.child_task_metrics), }), }) } @@ -226,11 +218,18 @@ impl NetworkBoundary for NetworkShuffleExec { let ready = NetworkShuffleReadyExec { properties: pending.repartition_exec.properties().clone(), stage_num, - metrics_collection: Default::default(), + child_task_metrics: Default::default(), }; Ok(Arc::new(Self::Ready(ready))) } + + fn metrics_collection(&self) -> Option>>> { + match self { + NetworkShuffleExec::Pending(_) => None, + NetworkShuffleExec::Ready(v) => Some(v.child_task_metrics.clone()), + } + } } impl DisplayAs for NetworkShuffleExec { @@ -331,7 +330,7 @@ impl ExecutionPlan for NetworkShuffleExec { }, ); - let metrics_collection_capture = self_ready.metrics_collection.clone(); + let metrics_collection_capture = self_ready.child_task_metrics.clone(); async move { let url = task.url.ok_or(internal_datafusion_err!( "NetworkShuffleExec: task is unassigned, cannot proceed" diff --git a/src/protobuf/distributed_codec.rs b/src/protobuf/distributed_codec.rs index 67c8fbd..ea9449b 100644 --- a/src/protobuf/distributed_codec.rs +++ b/src/protobuf/distributed_codec.rs @@ -234,7 +234,7 @@ fn new_network_hash_shuffle_exec( Boundedness::Bounded, ), stage_num, - metrics_collection: Default::default(), + child_task_metrics: Default::default(), }) } @@ -268,7 +268,7 @@ fn new_network_coalesce_tasks_exec( ), stage_num, input_tasks, - metrics_collection: Default::default(), + child_task_metrics: Default::default(), }) } From e0890cf90b061675f6b2806dbc6472f31751b177 Mon Sep 17 00:00:00 2001 From: Jayant Shrivastava Date: Mon, 29 Sep 2025 14:17:12 -0400 Subject: [PATCH 3/3] stateful execution plan test --- tests/stateful_execution_plan.rs | 278 +++++++++++++++++++++++++++++++ 1 file changed, 278 insertions(+) create mode 100644 tests/stateful_execution_plan.rs diff --git a/tests/stateful_execution_plan.rs b/tests/stateful_execution_plan.rs new file mode 100644 index 0000000..5372146 --- /dev/null +++ b/tests/stateful_execution_plan.rs @@ -0,0 +1,278 @@ +#[cfg(all(feature = "integration", test))] +mod tests { + use datafusion::arrow::array::Int64Array; + use datafusion::arrow::compute::SortOptions; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::arrow::record_batch::RecordBatch; + use datafusion::arrow::util::pretty::pretty_format_batches; + use datafusion::common::runtime::SpawnedTask; + use datafusion::error::DataFusionError; + use datafusion::execution::{ + FunctionRegistry, SendableRecordBatchStream, SessionState, SessionStateBuilder, TaskContext, + }; + use datafusion::logical_expr::Operator; + use datafusion::physical_expr::expressions::{BinaryExpr, col, lit}; + use datafusion::physical_expr::{ + EquivalenceProperties, LexOrdering, Partitioning, PhysicalSortExpr, + }; + use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; + use datafusion::physical_plan::filter::FilterExec; + use datafusion::physical_plan::repartition::RepartitionExec; + use datafusion::physical_plan::sorts::sort::SortExec; + use datafusion::physical_plan::stream::RecordBatchStreamAdapter; + use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, displayable, execute_stream, + }; + use datafusion_distributed::test_utils::localhost::start_localhost_context; + use datafusion_distributed::{ + DistributedExt, DistributedSessionBuilderContext, PartitionIsolatorExec, assert_snapshot, + }; + use datafusion_distributed::{DistributedPhysicalOptimizerRule, NetworkShuffleExec}; + use datafusion_proto::physical_plan::PhysicalExtensionCodec; + use datafusion_proto::protobuf::proto_error; + use futures::TryStreamExt; + use prost::Message; + use std::any::Any; + use std::fmt::Formatter; + use std::sync::{Arc, RwLock}; + use std::time::Duration; + use tokio::sync::mpsc; + use tokio_stream::StreamExt; + use tokio_stream::wrappers::ReceiverStream; + + #[tokio::test] + async fn stateful_execution_plan() -> Result<(), Box> { + async fn build_state( + ctx: DistributedSessionBuilderContext, + ) -> Result { + Ok(SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env) + .with_default_features() + .with_distributed_user_codec(Int64ListExecCodec) + .build()) + } + + let (ctx, _guard) = start_localhost_context(3, build_state).await; + + let distributed_plan = build_plan()?; + let distributed_plan = DistributedPhysicalOptimizerRule::distribute_plan(distributed_plan)?; + + assert_snapshot!(displayable(&distributed_plan).indent(true).to_string(), @r" + ┌───── Stage 3 Tasks: t0:[p0] + │ SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false] + │ RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=10 + │ NetworkShuffleExec read_from=Stage 2, output_partitions=10, n_tasks=1, input_tasks=10 + └────────────────────────────────────────────────── + ┌───── Stage 2 Tasks: t0:[p0,p1,p2,p3,p4,p5,p6,p7,p8,p9] t1:[p10,p11,p12,p13,p14,p15,p16,p17,p18,p19] t2:[p20,p21,p22,p23,p24,p25,p26,p27,p28,p29] t3:[p30,p31,p32,p33,p34,p35,p36,p37,p38,p39] t4:[p40,p41,p42,p43,p44,p45,p46,p47,p48,p49] t5:[p50,p51,p52,p53,p54,p55,p56,p57,p58,p59] t6:[p60,p61,p62,p63,p64,p65,p66,p67,p68,p69] t7:[p70,p71,p72,p73,p74,p75,p76,p77,p78,p79] t8:[p80,p81,p82,p83,p84,p85,p86,p87,p88,p89] t9:[p90,p91,p92,p93,p94,p95,p96,p97,p98,p99] + │ RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + │ SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false] + │ NetworkShuffleExec read_from=Stage 1, output_partitions=1, n_tasks=10, input_tasks=1 + └────────────────────────────────────────────────── + ┌───── Stage 1 Tasks: t0:[p0,p1,p2,p3,p4,p5,p6,p7,p8,p9] + │ RepartitionExec: partitioning=Hash([numbers@0], 10), input_partitions=1 + │ FilterExec: numbers@0 > 1 + │ StatefulInt64ListExec: length=6 + └────────────────────────────────────────────────── + "); + + let stream = execute_stream(Arc::new(distributed_plan), ctx.task_ctx())?; + let batches_distributed = stream.try_collect::>().await?; + + assert_snapshot!(pretty_format_batches(&batches_distributed).unwrap(), @r" + +---------+ + | numbers | + +---------+ + | 6 | + | 5 | + | 4 | + | 3 | + | 2 | + +---------+ + "); + Ok(()) + } + + fn build_plan() -> Result, DataFusionError> { + let mut plan: Arc = + Arc::new(StatefulInt64ListExec::new(vec![1, 2, 3, 4, 5, 6])); + + plan = Arc::new(PartitionIsolatorExec::new_pending(plan)); + + plan = Arc::new(FilterExec::try_new( + Arc::new(BinaryExpr::new( + col("numbers", &plan.schema())?, + Operator::Gt, + lit(1i64), + )), + plan, + )?); + + plan = Arc::new(NetworkShuffleExec::try_new( + Arc::clone(&plan), + Partitioning::Hash(vec![col("numbers", &plan.schema())?], 1), + 10, + )?); + + plan = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("numbers", &plan.schema())?, + SortOptions::new(true, false), + )]) + .unwrap(), + plan, + )); + + plan = Arc::new(NetworkShuffleExec::try_new( + plan, + Partitioning::RoundRobinBatch(10), + 10, + )?); + + plan = Arc::new(RepartitionExec::try_new( + plan, + Partitioning::RoundRobinBatch(1), + )?); + + plan = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("numbers", &plan.schema())?, + SortOptions::new(true, false), + )]) + .unwrap(), + plan, + )); + + Ok(plan) + } + + #[derive(Debug)] + pub struct StatefulInt64ListExec { + plan_properties: PlanProperties, + numbers: Vec, + task: RwLock>>, + tx: RwLock>>, + rx: RwLock>>, + } + + impl StatefulInt64ListExec { + fn new(numbers: Vec) -> Self { + let schema = Schema::new(vec![Field::new("numbers", DataType::Int64, false)]); + let (tx, rx) = mpsc::channel(10); + Self { + numbers, + plan_properties: PlanProperties::new( + EquivalenceProperties::new(Arc::new(schema)), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ), + task: RwLock::new(None), + tx: RwLock::new(Some(tx)), + rx: RwLock::new(Some(rx)), + } + } + } + + impl DisplayAs for StatefulInt64ListExec { + fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "StatefulInt64ListExec: length={:?}", self.numbers.len()) + } + } + + impl ExecutionPlan for StatefulInt64ListExec { + fn name(&self) -> &str { + "StatefulInt64ListExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.plan_properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> datafusion::common::Result> { + Ok(self) + } + + fn execute( + &self, + _: usize, + _: Arc, + ) -> datafusion::common::Result { + if let Some(tx) = self.tx.write().unwrap().take() { + let numbers = self.numbers.clone(); + self.task + .write() + .unwrap() + .replace(SpawnedTask::spawn(async move { + for n in numbers { + tx.send(n).await.unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + } + })); + } + + let rx = self.rx.write().unwrap().take().unwrap(); + let schema = self.schema(); + + let stream = ReceiverStream::new(rx).map(move |v| { + RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(vec![v]))]) + .map_err(DataFusionError::from) + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema().clone(), + stream, + ))) + } + } + + #[derive(Debug)] + struct Int64ListExecCodec; + + #[derive(Clone, PartialEq, ::prost::Message)] + struct Int64ListExecProto { + #[prost(message, repeated, tag = "1")] + numbers: Vec, + } + + impl PhysicalExtensionCodec for Int64ListExecCodec { + fn try_decode( + &self, + buf: &[u8], + _: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> datafusion::common::Result> { + let node = + Int64ListExecProto::decode(buf).map_err(|err| proto_error(format!("{err}")))?; + Ok(Arc::new(StatefulInt64ListExec::new(node.numbers.clone()))) + } + + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + ) -> datafusion::common::Result<()> { + let Some(plan) = node.as_any().downcast_ref::() else { + return Err(proto_error(format!( + "Expected plan to be of type Int64ListExec, but was {}", + node.name() + ))); + }; + Int64ListExecProto { + numbers: plan.numbers.clone(), + } + .encode(buf) + .map_err(|err| proto_error(format!("{err}"))) + } + } +} \ No newline at end of file