diff --git a/src/execution_plans/arrow_flight_read.rs b/src/execution_plans/arrow_flight_read.rs
index bee5819..09e4895 100644
--- a/src/execution_plans/arrow_flight_read.rs
+++ b/src/execution_plans/arrow_flight_read.rs
@@ -2,13 +2,15 @@ use crate::channel_resolver_ext::get_distributed_channel_resolver;
 use crate::config_extension_ext::ContextGrpcMetadata;
 use crate::errors::{map_flight_to_datafusion_error, map_status_to_datafusion_error};
 use crate::execution_plans::StageExec;
-use crate::flight_service::{DoGet, StageKey};
-use crate::protobuf::{proto_from_stage, DistributedCodec};
+use crate::flight_service::DoGet;
+use crate::metrics::proto::MetricsSetProto;
+use crate::protobuf::{proto_from_stage, DistributedCodec, StageKey};
 use crate::ChannelResolver;
 use arrow_flight::decode::FlightRecordBatchStream;
 use arrow_flight::error::FlightError;
 use arrow_flight::flight_service_client::FlightServiceClient;
 use arrow_flight::Ticket;
+use dashmap::DashMap;
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err};
 use datafusion::error::DataFusionError;
@@ -56,6 +58,13 @@ pub struct ArrowFlightReadReadyExec {
     /// the properties we advertise for this execution plan
     properties: PlanProperties,
     pub(crate) stage_num: usize,
+    /// metrics_collection is used to collect metrics from child tasks. It is empty when an ArrowFlightReadReadyExec is instantiated
+    /// (deserialized, created via [ArrowFlightReadExec::new_ready] etc). Metrics are populated in this map via [ArrowFlightReadExec::execute].
+    ///
+    /// An instance may receive metrics for 0 to N child tasks, where N is the number of tasks in the stage it is reading from.
+    /// This is because, by convention, the ArrowFlightEndpoint sends metrics for a task to the last ArrowFlightReadExec to read from it, which
+    /// may or may not be this instance.
+    pub(super) metrics_collection: Arc<DashMap<StageKey, Vec<MetricsSetProto>>>,
 }
 
 impl ArrowFlightReadExec {
@@ -85,6 +94,7 @@ impl ArrowFlightReadExec {
         Self::Ready(ArrowFlightReadReadyExec {
             properties,
             stage_num,
+            metrics_collection: Arc::new(DashMap::new()),
         })
     }
diff --git a/src/execution_plans/metrics.rs b/src/execution_plans/metrics.rs
new file mode 100644
index 0000000..8cfbbc5
--- /dev/null
+++ b/src/execution_plans/metrics.rs
@@ -0,0 +1,509 @@
+use datafusion::execution::TaskContext;
+use datafusion::physical_plan::metrics::MetricsSet;
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use crate::execution_plans::{ArrowFlightReadExec, StageExec};
+use crate::metrics::proto::{metrics_set_proto_to_df, MetricsSetProto};
+use crate::protobuf::StageKey;
+use datafusion::common::internal_err;
+use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter};
+use datafusion::error::{DataFusionError, Result};
+use datafusion::execution::SendableRecordBatchStream;
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion::physical_plan::{DisplayAs, DisplayFormatType, PlanProperties};
+use std::any::Any;
+use std::fmt::{Debug, Formatter};
+
+/// TaskMetricsCollector is used to collect metrics from a task. It implements [TreeNodeRewriter].
+/// Note: TaskMetricsCollector is not a [datafusion::physical_plan::ExecutionPlanVisitor] to keep
+/// parity between [TaskMetricsCollector] and [TaskMetricsRewriter].
+pub struct TaskMetricsCollector {
+    /// task_metrics contains the metrics for the current task.
+    task_metrics: Vec<MetricsSet>,
+    /// child_task_metrics contains metrics for tasks from child [StageExec]s if they were
+    /// collected.
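+    /// For example, if this task read from a child stage that ran as two tasks, this map
+    /// may end up holding entries for one, both, or neither of those tasks, depending on
+    /// which reader happened to be the last to consume each of them (illustrative scenario only).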
+    child_task_metrics: HashMap<StageKey, Vec<MetricsSetProto>>,
+}
+
+/// MetricsCollectorResult is the result of collecting metrics from a task.
+#[allow(dead_code)]
+pub struct MetricsCollectorResult {
+    // task_metrics is a collection of metrics for a task ordered using a pre-order traversal of the task's plan.
+    task_metrics: Vec<MetricsSet>,
+    // child_task_metrics contains metrics for child tasks if they were collected.
+    child_task_metrics: HashMap<StageKey, Vec<MetricsSetProto>>,
+}
+
+impl TreeNodeRewriter for TaskMetricsCollector {
+    type Node = Arc<dyn ExecutionPlan>;
+
+    fn f_down(&mut self, plan: Self::Node) -> Result<Transformed<Self::Node>> {
+        // If the plan is an ArrowFlightReadExec, assume it has collected metrics already
+        // from child tasks.
+        if let Some(read_exec) = plan.as_any().downcast_ref::<ArrowFlightReadExec>() {
+            match read_exec {
+                ArrowFlightReadExec::Pending { .. } => {
+                    return internal_err!(
+                        "unexpected ArrowFlightReadExec::pending during metrics collection"
+                    );
+                }
+                ArrowFlightReadExec::Ready(read_exec) => {
+                    for mut entry in read_exec.metrics_collection.iter_mut() {
+                        let stage_key = entry.key().clone();
+                        let task_metrics = std::mem::take(entry.value_mut()); // Avoid copy.
+                        match self.child_task_metrics.get(&stage_key) {
+                            // There should never be two ArrowFlightReadExec with metrics for the same stage_key.
+                            // By convention, the ArrowFlightReadExec which runs the last partition in a task should be
+                            // sent metrics (the ArrowFlightEndpoint tracks it for us).
+                            Some(_) => {
+                                return internal_err!(
+                                    "duplicate task metrics for key {} during metrics collection",
+                                    stage_key
+                                );
+                            }
+                            None => {
+                                self.child_task_metrics
+                                    .insert(stage_key.clone(), task_metrics);
+                            }
+                        }
+                    }
+                }
+            }
+            // Skip the subtree of the ArrowFlightReadExec.
+            return Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump));
+        }
+
+        // For plan nodes in this task, collect metrics.
+        match plan.metrics() {
+            Some(metrics) => self.task_metrics.push(metrics.clone()),
+            None => {
+                // TODO: Consider using a more efficient encoding scheme to avoid empty slots in the vec.
+                self.task_metrics.push(MetricsSet::new())
+            }
+        }
+        Ok(Transformed::new(plan, false, TreeNodeRecursion::Continue))
+    }
+}
+
+impl TaskMetricsCollector {
+    #[allow(dead_code)]
+    pub fn new() -> Self {
+        Self {
+            task_metrics: Vec::new(),
+            child_task_metrics: HashMap::new(),
+        }
+    }
+
+    /// collect metrics from a StageExec plan and any child tasks.
+    /// Returns
+    /// - a vec representing the metrics for the current task (ordered using a pre-order traversal)
+    /// - a map representing the metrics for some subset of child tasks collected from ArrowFlightReadExec leaves
+    #[allow(dead_code)]
+    pub fn collect(mut self, stage: &StageExec) -> Result<MetricsCollectorResult> {
+        stage.plan.clone().rewrite(&mut self)?;
+        Ok(MetricsCollectorResult {
+            task_metrics: self.task_metrics,
+            child_task_metrics: self.child_task_metrics,
+        })
+    }
+}
+
+/// TaskMetricsRewriter is used to enrich a task with metrics by re-writing the plan using [MetricsWrapperExec] nodes.
+///
+/// Ex. for a plan with the form
+/// AggregateExec
+/// └── ProjectionExec
+///     └── ArrowFlightReadExec
+///
+/// the task will be rewritten as
+///
+/// MetricsWrapperExec (wrapped: AggregateExec)
+/// └── MetricsWrapperExec (wrapped: ProjectionExec)
+///     └── ArrowFlightReadExec
+///
+/// (Note that the ArrowFlightReadExec node is not wrapped)
+pub struct TaskMetricsRewriter {
+    metrics: Vec<MetricsSetProto>,
+    idx: usize,
+}
+
+impl TaskMetricsRewriter {
+    /// Create a new TaskMetricsRewriter. The provided metrics will be used to enrich the plan.
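+    /// The metrics must be ordered to match a pre-order traversal of the task's plan,
+    /// excluding any ArrowFlightReadExec subtrees. A sketch of intended usage (names taken
+    /// from this module; marked `ignore` since it is illustrative only):
+    /// ```ignore
+    /// let rewriter = TaskMetricsRewriter::new(metrics_for_task);
+    /// let annotated_plan = rewriter.enrich_task_with_metrics(task_plan)?;
+    /// ```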
+    #[allow(dead_code)]
+    pub fn new(metrics: Vec<MetricsSetProto>) -> Self {
+        Self { metrics, idx: 0 }
+    }
+
+    /// enrich_task_with_metrics rewrites the plan by wrapping nodes. If the length of the provided metrics set vec does not
+    /// match the number of nodes in the plan, an error will be returned.
+    #[allow(dead_code)]
+    pub fn enrich_task_with_metrics(
+        mut self,
+        plan: Arc<dyn ExecutionPlan>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let transformed = plan.rewrite(&mut self)?;
+        if self.idx != self.metrics.len() {
+            return internal_err!("too many metrics sets provided to rewrite task: {} metrics sets provided, {} nodes in plan", self.metrics.len(), self.idx);
+        }
+        Ok(transformed.data)
+    }
+}
+
+impl TreeNodeRewriter for TaskMetricsRewriter {
+    type Node = Arc<dyn ExecutionPlan>;
+
+    fn f_down(&mut self, plan: Self::Node) -> Result<Transformed<Self::Node>> {
+        if plan
+            .as_any()
+            .downcast_ref::<ArrowFlightReadExec>()
+            .is_some()
+        {
+            // Do not recurse into ArrowFlightReadExec.
+            return Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump));
+        }
+        if self.idx >= self.metrics.len() {
+            return internal_err!(
+                "not enough metrics provided to rewrite task: {} metrics provided",
+                self.metrics.len()
+            );
+        }
+        let proto_metrics = &self.metrics[self.idx];
+
+        let wrapped_plan_node: Arc<dyn ExecutionPlan> = Arc::new(MetricsWrapperExec::new(
+            plan.clone(),
+            metrics_set_proto_to_df(proto_metrics)?,
+        ));
+        let result = Transformed::new(wrapped_plan_node, true, TreeNodeRecursion::Continue);
+        self.idx += 1;
+        Ok(result)
+    }
+}
+
+/// A transparent wrapper that delegates all execution to its child but returns custom metrics. This node is invisible during display.
+/// The structure of a plan tree is closely tied to the [TaskMetricsRewriter].
+struct MetricsWrapperExec {
+    inner: Arc<dyn ExecutionPlan>,
+    /// metrics for this plan node.
+    metrics: MetricsSet,
+    /// children is initially None. When used by the [TaskMetricsRewriter], the children will be updated
+    /// to point at other wrapped nodes.
+    children: Option<Vec<Arc<dyn ExecutionPlan>>>,
+}
+
+impl MetricsWrapperExec {
+    pub fn new(inner: Arc<dyn ExecutionPlan>, metrics: MetricsSet) -> Self {
+        Self {
+            inner,
+            metrics,
+            children: None,
+        }
+    }
+}
+
+/// MetricsWrapperExec is invisible during display.
+impl DisplayAs for MetricsWrapperExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        self.inner.fmt_as(t, f)
+    }
+}
+
+/// MetricsWrapperExec is visible when debugging.
+impl Debug for MetricsWrapperExec {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "MetricsWrapperExec ({:?})", self.inner)
+    }
+}
+
+impl ExecutionPlan for MetricsWrapperExec {
+    fn name(&self) -> &str {
+        "MetricsWrapperExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        unimplemented!("MetricsWrapperExec does not implement properties")
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        match &self.children {
+            Some(children) => children.iter().collect(),
+            None => self.inner.children(),
+        }
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::new(MetricsWrapperExec {
+            inner: self.inner.clone(),
+            metrics: self.metrics.clone(),
+            children: Some(children),
+        }))
+    }
+
+    fn execute(
+        &self,
+        _partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        unimplemented!("MetricsWrapperExec does not implement execute")
+    }
+
+    // metrics returns the wrapped metrics.
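+    // Unlike ordinary nodes, which accumulate metrics while running, this node simply
+    // returns the metrics it was handed at construction time by the TaskMetricsRewriter.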
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::metrics::proto::{
+        ElapsedCompute, EndTimestamp, MetricProto, MetricValueProto, OutputRows, StartTimestamp,
+    };
+
+    use super::*;
+    use datafusion::arrow::array::{Int32Array, StringArray};
+    use datafusion::arrow::record_batch::RecordBatch;
+
+    use crate::test_utils::in_memory_channel_resolver::InMemoryChannelResolver;
+    use crate::test_utils::session_context::register_temp_parquet_table;
+    use crate::DistributedExt;
+    use crate::DistributedPhysicalOptimizerRule;
+    use datafusion::execution::{context::SessionContext, SessionStateBuilder};
+    use datafusion::physical_plan::metrics::MetricValue;
+    use datafusion::prelude::SessionConfig;
+    use datafusion::{
+        arrow::datatypes::{DataType, Field, Schema},
+        physical_plan::display::DisplayableExecutionPlan,
+    };
+    use std::sync::Arc;
+
+    /// Creates a stage with the following structure:
+    ///
+    /// SortPreservingMergeExec
+    ///   SortExec
+    ///     ProjectionExec
+    ///       AggregateExec
+    ///         CoalesceBatchesExec
+    ///           ArrowFlightReadExec
+    ///
+    /// ... (for the purposes of these tests, we don't care about child stages).
+    async fn make_test_stage_exec_with_5_nodes() -> (StageExec, SessionContext) {
+        // Create distributed session state with in-memory channel resolver
+        let config = SessionConfig::new().with_target_partitions(2);
+
+        let state = SessionStateBuilder::new()
+            .with_default_features()
+            .with_config(config)
+            .with_distributed_channel_resolver(InMemoryChannelResolver::new())
+            .with_physical_optimizer_rule(Arc::new(
+                DistributedPhysicalOptimizerRule::default().with_maximum_partitions_per_task(1),
+            ))
+            .build();
+
+        let ctx = SessionContext::from(state);
+
+        // Create test data
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, false),
+        ]));
+
+        let batches = vec![
+            RecordBatch::try_new(
+                schema.clone(),
+                vec![
+                    Arc::new(Int32Array::from(vec![1, 2, 3])),
+                    Arc::new(StringArray::from(vec!["a", "b", "c"])),
+                ],
+            )
+            .unwrap(),
+            RecordBatch::try_new(
+                schema.clone(),
+                vec![
+                    Arc::new(Int32Array::from(vec![4, 5, 6])),
+                    Arc::new(StringArray::from(vec!["d", "e", "f"])),
+                ],
+            )
+            .unwrap(),
+        ];
+
+        // Register the test data as a parquet table
+        let _ = register_temp_parquet_table("test_table", schema.clone(), batches, &ctx)
+            .await
+            .unwrap();
+
+        let df = ctx
+            .sql("SELECT id, COUNT(*) as count FROM test_table WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10")
+            .await
+            .unwrap();
+        let physical_distributed = df.create_physical_plan().await.unwrap();
+
+        let stage_exec = match physical_distributed.as_any().downcast_ref::<StageExec>() {
+            Some(stage_exec) => stage_exec.clone(),
+            None => panic!(
+                "Expected StageExec from distributed optimization, got: {}",
+                physical_distributed.name()
+            ),
+        };
+
+        (stage_exec, ctx)
+    }
+
+    /// creates a "distinct" set of metrics from the provided seed
+    fn make_distinct_metrics_set(seed: u64) -> MetricsSetProto {
+        const TEST_TIMESTAMP: i64 = 1758200400000000000; // 2025-09-18 13:00:00 UTC
+        MetricsSetProto {
+            metrics: vec![
+                MetricProto {
+                    metric: Some(MetricValueProto::OutputRows(OutputRows { value: seed })),
+                    labels: vec![],
+                    partition: None,
+                },
+                MetricProto {
+                    metric: Some(MetricValueProto::ElapsedCompute(ElapsedCompute {
+                        value: seed,
+                    })),
+                    labels: vec![],
+                    partition: None,
+                },
+                MetricProto {
+                    metric: Some(MetricValueProto::StartTimestamp(StartTimestamp {
+                        value: Some(TEST_TIMESTAMP + (seed as i64 * 1_000_000_000)),
+                    })),
+                    labels: vec![],
+                    partition: None,
+                },
+                MetricProto {
+                    metric: Some(MetricValueProto::EndTimestamp(EndTimestamp {
+                        value: Some(TEST_TIMESTAMP + ((seed as i64 + 1) * 1_000_000_000)),
+                    })),
+                    labels: vec![],
+                    partition: None,
+                },
+            ],
+        }
+    }
+
+    #[tokio::test]
+    async fn test_metrics_rewriter() {
+        let (test_stage, _ctx) = make_test_stage_exec_with_5_nodes().await;
+        let test_metrics_sets = (0..5) // 5 nodes excluding ArrowFlightReadExec
+            .map(|i| make_distinct_metrics_set(i + 10))
+            .collect::<Vec<_>>();
+
+        let rewriter = TaskMetricsRewriter::new(test_metrics_sets.clone());
+        let plan_with_metrics = rewriter
+            .enrich_task_with_metrics(test_stage.plan.clone())
+            .unwrap();
+
+        let plan_str =
+            DisplayableExecutionPlan::with_full_metrics(plan_with_metrics.as_ref()).indent(true);
+        // Expected distributed plan output with metrics
+        let expected = [
+            r"SortPreservingMergeExec: [id@0 ASC NULLS LAST], fetch=10, metrics=[output_rows=10, elapsed_compute=10ns, start_timestamp=2025-09-18 13:00:10 UTC, end_timestamp=2025-09-18 13:00:11 UTC]",
+            r"  SortExec: TopK(fetch=10), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true], metrics=[output_rows=11, elapsed_compute=11ns, start_timestamp=2025-09-18 13:00:11 UTC, end_timestamp=2025-09-18 13:00:12 UTC]",
+            r"    ProjectionExec: expr=[id@0 as id, count(Int64(1))@1 as count], metrics=[output_rows=12, elapsed_compute=12ns, start_timestamp=2025-09-18 13:00:12 UTC, end_timestamp=2025-09-18 13:00:13 UTC]",
+            r"      AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[count(Int64(1))], metrics=[output_rows=13, elapsed_compute=13ns, start_timestamp=2025-09-18 13:00:13 UTC, end_timestamp=2025-09-18 13:00:14 UTC]",
+            r"        CoalesceBatchesExec: target_batch_size=8192, metrics=[output_rows=14, elapsed_compute=14ns, start_timestamp=2025-09-18 13:00:14 UTC, end_timestamp=2025-09-18 13:00:15 UTC]",
+            r"          ArrowFlightReadExec, metrics=[]",
+            "" // trailing newline
+        ].join("\n");
+        assert_eq!(expected, plan_str.to_string());
+    }
+
+    #[tokio::test]
+    async fn test_metrics_rewriter_correct_number_of_metrics() {
+        let test_metrics_set = make_distinct_metrics_set(10);
+        let (executable_plan, _ctx) = make_test_stage_exec_with_5_nodes().await;
+        let task_plan = executable_plan
+            .as_any()
+            .downcast_ref::<StageExec>()
+            .unwrap()
+            .plan
+            .clone();
+
+        // Too few metrics sets.
+        let rewriter = TaskMetricsRewriter::new(vec![test_metrics_set.clone()]);
+        let result = rewriter.enrich_task_with_metrics(task_plan.clone());
+        assert!(result.is_err());
+
+        // Still too few metrics sets (4 provided for 5 wrappable nodes).
+        let rewriter = TaskMetricsRewriter::new(vec![
+            test_metrics_set.clone(),
+            test_metrics_set.clone(),
+            test_metrics_set.clone(),
+            test_metrics_set.clone(),
+        ]);
+        let result = rewriter.enrich_task_with_metrics(task_plan.clone());
+        assert!(result.is_err());
+    }
+
+    #[tokio::test]
+    async fn test_metrics_collection() {
+        let (stage_exec, ctx) = make_test_stage_exec_with_5_nodes().await;
+
+        // Execute the plan to completion.
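+        // Draining the stream matters: DataFusion populates metrics during execution,
+        // so collecting before the stream is exhausted would observe partial counts.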
+        let task_ctx = ctx.task_ctx();
+        let stream = stage_exec.execute(0, task_ctx).unwrap();
+
+        use futures::StreamExt;
+        let mut stream = stream;
+        while let Some(_batch) = stream.next().await {}
+
+        let collector = TaskMetricsCollector::new();
+        let result = collector.collect(&stage_exec).unwrap();
+
+        // The collector should produce one metrics set per node in this task's plan
+        // (excluding the ArrowFlightReadExec subtree).
+        assert_eq!(result.task_metrics.len(), 5);
+
+        let expected_metrics_count = [4, 10, 8, 16, 8];
+        for (node_idx, metrics_set) in result.task_metrics.iter().enumerate() {
+            let metrics_count = metrics_set.iter().count();
+            assert_eq!(metrics_count, expected_metrics_count[node_idx]);
+
+            // Each node should have basic metrics: ElapsedCompute, OutputRows, StartTimestamp, EndTimestamp.
+            let mut has_start_timestamp = false;
+            let mut has_end_timestamp = false;
+            let mut has_elapsed_compute = false;
+            let mut has_output_rows = false;
+
+            for metric in metrics_set.iter() {
+                let metric_name = metric.value().name();
+                let metric_value = metric.value();
+
+                match metric_value {
+                    MetricValue::StartTimestamp(_) if metric_name == "start_timestamp" => {
+                        has_start_timestamp = true;
+                    }
+                    MetricValue::EndTimestamp(_) if metric_name == "end_timestamp" => {
+                        has_end_timestamp = true;
+                    }
+                    MetricValue::ElapsedCompute(_) if metric_name == "elapsed_compute" => {
+                        has_elapsed_compute = true;
+                    }
+                    MetricValue::OutputRows(_) if metric_name == "output_rows" => {
+                        has_output_rows = true;
+                    }
+                    _ => {
+                        // Other metrics are fine, we just validate the core ones
+                    }
+                }
+            }
+
+            // Each node should have the four basic metrics
+            assert!(has_start_timestamp);
+            assert!(has_end_timestamp);
+            assert!(has_elapsed_compute);
+            assert!(has_output_rows);
+        }
+
+        // TODO: once we propagate metrics from child stages, we can assert this.
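+        // Child-task metrics are not yet propagated back to the reader (see TODO above),
+        // so the collector is expected to find an empty map here.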
+        assert_eq!(0, result.child_task_metrics.len());
+    }
+}
diff --git a/src/execution_plans/mod.rs b/src/execution_plans/mod.rs
index c8ff514..6b75d02 100644
--- a/src/execution_plans/mod.rs
+++ b/src/execution_plans/mod.rs
@@ -1,4 +1,5 @@
 mod arrow_flight_read;
+mod metrics;
 mod partition_isolator;
 mod stage;
diff --git a/src/execution_plans/stage.rs b/src/execution_plans/stage.rs
index d99f3a9..a4e2f42 100644
--- a/src/execution_plans/stage.rs
+++ b/src/execution_plans/stage.rs
@@ -530,7 +530,7 @@ pub fn display_plan(
     plan: &Arc<dyn ExecutionPlan>,
     partition_group: &[usize],
     stage_num: usize,
-    distributed: bool,
+    _distributed: bool,
 ) -> Result<String> {
     // draw all plans
     // we need to label the nodes including depth to uniquely identify them within this task
diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs
index e979667..da33255 100644
--- a/src/flight_service/do_get.rs
+++ b/src/flight_service/do_get.rs
@@ -1,10 +1,9 @@
-use super::service::StageKey;
 use crate::config_extension_ext::ContextGrpcMetadata;
 use crate::errors::datafusion_error_to_tonic_status;
 use crate::execution_plans::{PartitionGroup, StageExec};
 use crate::flight_service::service::ArrowFlightEndpoint;
 use crate::flight_service::session_builder::DistributedSessionBuilderContext;
-use crate::protobuf::{stage_from_proto, DistributedCodec, StageExecProto};
+use crate::protobuf::{stage_from_proto, DistributedCodec, StageExecProto, StageKey};
 use arrow_flight::encode::FlightDataEncoderBuilder;
 use arrow_flight::error::FlightError;
 use arrow_flight::flight_service_server::FlightService;
diff --git a/src/flight_service/mod.rs b/src/flight_service/mod.rs
index 8277e32..4be0f12 100644
--- a/src/flight_service/mod.rs
+++ b/src/flight_service/mod.rs
@@ -4,7 +4,7 @@ mod session_builder;
 
 pub(crate) use do_get::DoGet;
 
-pub use service::{ArrowFlightEndpoint, StageKey};
+pub use service::ArrowFlightEndpoint;
 pub use session_builder::{
     DefaultSessionBuilder, DistributedSessionBuilder, DistributedSessionBuilderContext,
     MappedDistributedSessionBuilder, MappedDistributedSessionBuilderExt,
diff --git a/src/flight_service/service.rs b/src/flight_service/service.rs
index c50ff1d..ed23f0e 100644
--- a/src/flight_service/service.rs
+++ b/src/flight_service/service.rs
@@ -1,6 +1,7 @@
 use crate::common::ttl_map::{TTLMap, TTLMapConfig};
 use crate::flight_service::do_get::TaskData;
 use crate::flight_service::DistributedSessionBuilder;
+use crate::protobuf::StageKey;
 use arrow_flight::flight_service_server::FlightService;
 use arrow_flight::{
     Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
@@ -14,20 +15,6 @@ use std::sync::Arc;
 use tokio::sync::OnceCell;
 use tonic::{Request, Response, Status, Streaming};
 
-/// A key that uniquely identifies a stage in a query
-#[derive(Clone, Hash, Eq, PartialEq, ::prost::Message)]
-pub struct StageKey {
-    /// Our query id
-    #[prost(string, tag = "1")]
-    pub query_id: String,
-    /// Our stage id
-    #[prost(uint64, tag = "2")]
-    pub stage_id: u64,
-    /// The task number within the stage
-    #[prost(uint64, tag = "3")]
-    pub task_number: u64,
-}
-
 pub struct ArrowFlightEndpoint {
     pub(super) runtime: Arc<RuntimeEnv>,
     pub(super) task_data_entries: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
diff --git a/src/protobuf/mod.rs b/src/protobuf/mod.rs
index 0891846..d01d2dc 100644
--- a/src/protobuf/mod.rs
+++ b/src/protobuf/mod.rs
@@ -3,5 +3,5 @@ mod stage_proto;
 mod user_codec;
 
 pub(crate) use distributed_codec::DistributedCodec;
-pub(crate) use stage_proto::{proto_from_stage, stage_from_proto, StageExecProto};
+pub(crate) use stage_proto::{proto_from_stage, stage_from_proto, StageExecProto, StageKey};
 pub(crate) use user_codec::{get_distributed_user_codec, set_distributed_user_codec};
diff --git a/src/protobuf/stage_proto.rs b/src/protobuf/stage_proto.rs
index 332e6f7..e7a3733 100644
--- a/src/protobuf/stage_proto.rs
+++ b/src/protobuf/stage_proto.rs
@@ -9,9 +9,34 @@ use datafusion_proto::{
     physical_plan::{AsExecutionPlan, PhysicalExtensionCodec},
     protobuf::PhysicalPlanNode,
 };
+use std::fmt::Display;
 use std::sync::Arc;
 use url::Url;
 
+/// A key that uniquely identifies a stage in a query
+#[derive(Clone, Hash, Eq, PartialEq, ::prost::Message)]
+pub struct StageKey {
+    /// Our query id
+    #[prost(string, tag = "1")]
+    pub query_id: String,
+    /// Our stage id
+    #[prost(uint64, tag = "2")]
+    pub stage_id: u64,
+    /// The task number within the stage
+    #[prost(uint64, tag = "3")]
+    pub task_number: u64,
+}
+
+impl Display for StageKey {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "StageKey_QueryID_{}_StageID_{}_TaskNumber_{}",
+            self.query_id, self.stage_id, self.task_number
+        )
+    }
+}
+
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct StageExecProto {
     /// Our query id
diff --git a/src/test_utils/in_memory_channel_resolver.rs b/src/test_utils/in_memory_channel_resolver.rs
new file mode 100644
index 0000000..c387e23
--- /dev/null
+++ b/src/test_utils/in_memory_channel_resolver.rs
@@ -0,0 +1,82 @@
+use crate::{
+    ArrowFlightEndpoint, BoxCloneSyncChannel, ChannelResolver, DistributedExt,
+    DistributedSessionBuilderContext,
+};
+use arrow_flight::flight_service_server::FlightServiceServer;
+use async_trait::async_trait;
+use datafusion::common::DataFusionError;
+use datafusion::execution::SessionStateBuilder;
+use hyper_util::rt::TokioIo;
+use tonic::transport::{Endpoint, Server};
+
+const DUMMY_URL: &str = "http://localhost:50051";
+
+/// [ChannelResolver] implementation that returns gRPC clients backed by an in-memory
+/// tokio duplex rather than a TCP connection.
+#[derive(Clone)]
+pub struct InMemoryChannelResolver {
+    channel: BoxCloneSyncChannel,
+}
+
+impl Default for InMemoryChannelResolver {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl InMemoryChannelResolver {
+    pub fn new() -> Self {
+        let (client, server) = tokio::io::duplex(1024 * 1024);
+
+        let mut client = Some(client);
+        let channel = Endpoint::try_from(DUMMY_URL)
+            .expect("Invalid dummy URL for building an endpoint. This should never happen")
+            .connect_with_connector_lazy(tower::service_fn(move |_| {
+                let client = client
+                    .take()
+                    .expect("Client taken twice. This should never happen");
+                async move { Ok::<_, std::io::Error>(TokioIo::new(client)) }
+            }));
+
+        let this = Self {
+            channel: BoxCloneSyncChannel::new(channel),
+        };
+        let this_clone = this.clone();
+
+        let endpoint =
+            ArrowFlightEndpoint::try_new(move |ctx: DistributedSessionBuilderContext| {
+                let this = this.clone();
+                async move {
+                    let builder = SessionStateBuilder::new()
+                        .with_default_features()
+                        .with_distributed_channel_resolver(this)
+                        .with_runtime_env(ctx.runtime_env.clone());
+                    Ok(builder.build())
+                }
+            })
+            .unwrap();
+
+        tokio::spawn(async move {
+            Server::builder()
+                .add_service(FlightServiceServer::new(endpoint))
+                .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
+                .await
+        });
+
+        this_clone
+    }
+}
+
+#[async_trait]
+impl ChannelResolver for InMemoryChannelResolver {
+    fn get_urls(&self) -> Result<Vec<url::Url>, DataFusionError> {
+        Ok(vec![url::Url::parse(DUMMY_URL).unwrap()])
+    }
+
+    async fn get_channel_for_url(
+        &self,
+        _: &url::Url,
+    ) -> Result<BoxCloneSyncChannel, DataFusionError> {
+        Ok(self.channel.clone())
+    }
+}
diff --git a/src/test_utils/mod.rs b/src/test_utils/mod.rs
index 2ef71e3..4dd283c 100644
--- a/src/test_utils/mod.rs
+++ b/src/test_utils/mod.rs
@@ -1,5 +1,7 @@
+pub mod in_memory_channel_resolver;
 pub mod insta;
 pub mod localhost;
 pub mod mock_exec;
 pub mod parquet;
+pub mod session_context;
 pub mod tpch;
diff --git a/src/test_utils/session_context.rs b/src/test_utils/session_context.rs
new file mode 100644
index 0000000..276010d
--- /dev/null
+++ b/src/test_utils/session_context.rs
@@ -0,0 +1,98 @@
+use arrow::record_batch::RecordBatch;
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::error::Result;
+use datafusion::execution::context::SessionContext;
+use datafusion::prelude::ParquetReadOptions;
+use parquet::arrow::ArrowWriter;
+use std::path::PathBuf;
+use uuid::Uuid;
+
+/// Creates a temporary Parquet file from RecordBatches and registers it with the SessionContext
+/// under the provided table name. Returns the path to the temporary file.
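+/// The file is written to the OS temp directory and is not removed automatically; callers
+/// should delete it when finished (the test below does so via `remove_file`).
+///
+/// A sketch of intended usage (illustrative only; assumes an existing `ctx` and `schema`):
+/// ```ignore
+/// let path = register_temp_parquet_table("t", schema.clone(), batches, &ctx).await?;
+/// let df = ctx.sql("SELECT * FROM t").await?;
+/// ```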
+///
+/// TODO: consider expanding this to support partitioned data
+pub async fn register_temp_parquet_table(
+    table_name: &str,
+    schema: SchemaRef,
+    batches: Vec<RecordBatch>,
+    ctx: &SessionContext,
+) -> Result<PathBuf> {
+    if batches.is_empty() {
+        return Err(datafusion::error::DataFusionError::Execution(
+            "cannot create parquet file from empty batch list".to_string(),
+        ));
+    }
+    for batch in &batches {
+        if batch.schema() != schema {
+            return Err(datafusion::error::DataFusionError::Execution(
+                "all batches must have the same schema".to_string(),
+            ));
+        }
+    }
+
+    let temp_dir = std::env::temp_dir();
+    let file_id = Uuid::new_v4();
+    let temp_file_path = temp_dir.join(format!("{}_{}.parquet", table_name, file_id));
+
+    let file = std::fs::File::create(&temp_file_path)?;
+    let schema = batches[0].schema();
+    let mut writer = ArrowWriter::try_new(file, schema, None)?;
+
+    for batch in batches {
+        writer.write(&batch)?;
+    }
+    writer.close()?;
+
+    ctx.register_parquet(
+        table_name,
+        temp_file_path.to_string_lossy().as_ref(),
+        ParquetReadOptions::default(),
+    )
+    .await?;
+
+    Ok(temp_file_path)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Int32Array, StringArray};
+    use arrow::datatypes::{DataType, Field, Schema};
+    use tokio::fs::remove_file;
+
+    use std::sync::Arc;
+
+    #[tokio::test]
+    async fn test_register_temp_parquet_table() {
+        let ctx = SessionContext::new();
+
+        // Create test data
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, false),
+        ]));
+
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2, 3])),
+                Arc::new(StringArray::from(vec!["a", "b", "c"])),
+            ],
+        )
+        .unwrap();
+
+        // Register temp table
+        let temp_file =
+            register_temp_parquet_table("test_table", schema.clone(), vec![batch], &ctx)
+                .await
+                .unwrap();
+
+        let df = ctx.sql("SELECT * FROM test_table").await.unwrap();
+        let results = df.collect().await.unwrap();
+
+        assert_eq!(results.len(), 1);
+        assert_eq!(results[0].num_rows(), 3);
+
+        let _ = remove_file(temp_file).await;
+    }
+}
diff --git a/tests/distributed_aggregation.rs b/tests/distributed_aggregation.rs
index a967ea1..ec0ec6a 100644
--- a/tests/distributed_aggregation.rs
+++ b/tests/distributed_aggregation.rs
@@ -31,8 +31,6 @@ mod tests {
         .indent(true)
         .to_string();
 
-        println!("physical plan:\n{}", physical_str);
-
         assert_snapshot!(physical_str, @r"
     ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday]