diff --git a/Cargo.lock b/Cargo.lock
index a859edd8..e3f8d344 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1725,6 +1725,7 @@ dependencies = [
  "arrow",
  "arrow-flight",
  "async-stream",
+ "async-trait",
  "bytes",
  "clap",
  "datafusion",
@@ -2507,6 +2508,17 @@ version = "3.0.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02"

+[[package]]
+name = "io-uring"
+version = "0.7.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b86e202f00093dcba4275d4636b93ef9dd75d025ae560d2521b45ea28ab49013"
+dependencies = [
+ "bitflags",
+ "cfg-if",
+ "libc",
+]
+
 [[package]]
 name = "ipnet"
 version = "2.11.0"
@@ -4560,17 +4572,19 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"

 [[package]]
 name = "tokio"
-version = "1.45.1"
+version = "1.46.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779"
+checksum = "0cc3a2344dafbe23a245241fe8b09735b521110d30fcefbbd5feb1797ca35d17"
 dependencies = [
  "backtrace",
  "bytes",
+ "io-uring",
  "libc",
  "mio",
  "parking_lot",
  "pin-project-lite",
  "signal-hook-registry",
+ "slab",
  "socket2",
  "tokio-macros",
  "windows-sys 0.52.0",
diff --git a/Cargo.toml b/Cargo.toml
index b008b6f7..132d0926 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -40,6 +40,7 @@ path = "src/bin/distributed-datafusion.rs"
 anyhow = "1"
 arrow = { version = "55.1", features = ["ipc"] }
 arrow-flight = { version = "55", features = ["flight-sql-experimental"] }
+async-trait = "0.1.88"
 async-stream = "0.3"
 bytes = "1.5"
 clap = { version = "4.4", features = ["derive"] }
@@ -67,7 +68,7 @@ rustls = "0.23"
 test-log = "0.2"
 thiserror = "1.0"

-tokio = { version = "1.0", features = ["full"] }
+tokio = { version = "1.46", features = ["full"] }
 tokio-stream = { version = "0.1", features = ["full"] }

 tonic = { version = "0.12", default-features = false, features = [
diff --git a/src/bin/distributed-datafusion.rs b/src/bin/distributed-datafusion.rs
index 400f0ea9..738235b0 100644
--- a/src/bin/distributed-datafusion.rs
+++ b/src/bin/distributed-datafusion.rs
@@ -33,11 +33,11 @@ async fn main() -> Result<()> {

     match args.mode.as_str() {
         "proxy" => {
-            let service = DDProxyService::new(new_friendly_name()?, args.port).await?;
+            let service = DDProxyService::new(new_friendly_name()?, args.port, None).await?;
             service.serve().await?;
         }
         "worker" => {
-            let service = DDWorkerService::new(new_friendly_name()?, args.port).await?;
+            let service = DDWorkerService::new(new_friendly_name()?, args.port, None).await?;
             service.serve().await?;
         }
         _ => {
diff --git a/src/codec.rs b/src/codec.rs
index d5231e41..88933df8 100644
--- a/src/codec.rs
+++ b/src/codec.rs
@@ -31,7 +31,21 @@ use crate::{
 };

 #[derive(Debug)]
-pub struct DDCodec {}
+pub struct DDCodec {
+    sub_codec: Arc<dyn PhysicalExtensionCodec>,
+}
+
+impl DDCodec {
+    pub fn new(sub_codec: Arc<dyn PhysicalExtensionCodec>) -> Self {
+        Self { sub_codec }
+    }
+}
+
+impl Default for DDCodec {
+    fn default() -> Self {
+        Self::new(Arc::new(DefaultPhysicalExtensionCodec {}))
+    }
+}

 impl PhysicalExtensionCodec for DDCodec {
     fn try_decode(
@@ -127,6 +141,13 @@
                     Ok(Arc::new(RecordBatchExec::new(batch)))
                 }
             }
+        } else if let Ok(ext) = self.sub_codec.try_decode(buf, inputs, registry) {
+            // If the node is not a DDExecNode, we delegate to the sub codec
+            trace!(
+                "Delegated decoding to sub codec for node: {}",
+                displayable(ext.as_ref()).one_line()
+            );
+            Ok(ext)
         } else {
             internal_err!("cannot decode proto extension in distributed datafusion codec")
         }
@@ -151,52 +172,58 @@ impl PhysicalExtensionCodec for DDCodec {
                 stage_id: reader.stage_id,
             };
-            Payload::StageReaderExec(pb)
+            Some(Payload::StageReaderExec(pb))
         } else if let Some(pi) = node.as_any().downcast_ref::<PartitionIsolatorExec>() {
             let pb = PartitionIsolatorExecNode {
                 partition_count: pi.partition_count as u64,
             };
-            Payload::IsolatorExec(pb)
+            Some(Payload::IsolatorExec(pb))
         } else if let Some(max) = node.as_any().downcast_ref::<MaxRowsExec>() {
             let pb = MaxRowsExecNode {
                 max_rows: max.max_rows as u64,
             };
-            Payload::MaxRowsExec(pb)
+            Some(Payload::MaxRowsExec(pb))
         } else if let Some(exec) = node.as_any().downcast_ref::<DistributedAnalyzeExec>() {
             let pb = DistributedAnalyzeExecNode {
                 verbose: exec.verbose,
                 show_statistics: exec.show_statistics,
             };
-            Payload::DistributedAnalyzeExec(pb)
+            Some(Payload::DistributedAnalyzeExec(pb))
         } else if let Some(exec) = node.as_any().downcast_ref::<DistributedAnalyzeRootExec>() {
             let pb = DistributedAnalyzeRootExecNode {
                 verbose: exec.verbose,
                 show_statistics: exec.show_statistics,
             };
-            Payload::DistributedAnalyzeRootExec(pb)
+            Some(Payload::DistributedAnalyzeRootExec(pb))
         } else if let Some(exec) = node.as_any().downcast_ref::<RecordBatchExec>() {
             let pb = RecordBatchExecNode {
                 batch: batch_to_ipc(&exec.batch).map_err(|e| {
                     internal_datafusion_err!("Failed to encode RecordBatch: {:#?}", e)
                 })?,
             };
-            Payload::RecordBatchExec(pb)
+            Some(Payload::RecordBatchExec(pb))
         } else {
-            return internal_err!("Not supported node to encode to proto");
+            trace!(
+                "Node {} is not a custom DDExecNode, delegating to sub codec",
+                displayable(node.as_ref()).one_line()
+            );
+            None
         };

-        let pb = DdExecNode {
-            payload: Some(payload),
-        };
-        pb.encode(buf)
-            .map_err(|e| internal_datafusion_err!("Failed to encode protobuf: {}", e))?;
-
-        trace!(
-            "DONE encoding node: {}",
-            displayable(node.as_ref()).one_line()
-        );
-        Ok(())
+        match payload {
+            Some(payload) => {
+                let pb = DdExecNode {
+                    payload: Some(payload),
+                };
+                pb.encode(buf)
+                    .map_err(|e| internal_datafusion_err!("Failed to encode protobuf: {:#?}", e))
+            }
+            None => {
+                // If the node is not one of our custom nodes, we delegate to the sub codec
+                self.sub_codec.try_encode(node, buf)
+            }
+        }
     }
 }
@@ -225,7 +252,7 @@ mod test {

     fn verify_round_trip(exec: Arc<dyn ExecutionPlan>) {
         let ctx = SessionContext::new();
-        let codec = DDCodec {};
+        let codec = DDCodec::new(Arc::new(DefaultPhysicalExtensionCodec {}));

         // serialize execution plan to proto
         let proto: protobuf::PhysicalPlanNode =
@@ -255,7 +282,7 @@ mod test {
         let schema = create_test_schema();
         let part = Partitioning::UnknownPartitioning(2);
         let exec = Arc::new(DDStageReaderExec::try_new(part, schema, 1).unwrap());
-        let codec = DDCodec {};
+        let codec = DDCodec::new(Arc::new(DefaultPhysicalExtensionCodec {}));
         let mut buf = vec![];
         codec.try_encode(exec.clone(), &mut buf).unwrap();
         let ctx = SessionContext::new();
diff --git a/src/customizer.rs b/src/customizer.rs
new file mode 100644
index 00000000..b502e5e0
--- /dev/null
+++ b/src/customizer.rs
@@ -0,0 +1,13 @@
+use datafusion::prelude::SessionContext;
+use datafusion_proto::physical_plan::PhysicalExtensionCodec;
+
+#[async_trait::async_trait]
+pub trait Customizer: PhysicalExtensionCodec + Send + Sync {
+    /// Customize the context before planning a query.
+    /// This may include registering new file formats or introducing additional
+    /// `PhysicalPlan` operators.
+    ///
+    /// To support serialization of customized plans for distributed execution,
+    /// a `Codec` may also be required.
+    async fn customize(&self, ctx: &mut SessionContext) -> Result<(), Box<dyn std::error::Error>>;
+}
diff --git a/src/explain.rs b/src/explain.rs
index 5bf6968c..dc215ed4 100644
--- a/src/explain.rs
+++ b/src/explain.rs
@@ -11,13 +11,17 @@ use datafusion::{
     physical_plan::{displayable, ExecutionPlan},
     prelude::SessionContext,
 };
+use datafusion_proto::physical_plan::PhysicalExtensionCodec;

 use crate::{result::Result, util::bytes_to_physical_plan, vocab::DDTask};

-pub fn format_distributed_tasks(tasks: &[DDTask]) -> Result<String> {
+pub fn format_distributed_tasks(
+    tasks: &[DDTask],
+    codec: &dyn PhysicalExtensionCodec,
+) -> Result<String> {
     let mut result = String::new();
     for (i, task) in tasks.iter().enumerate() {
-        let plan = bytes_to_physical_plan(&SessionContext::new(), &task.plan_bytes)
+        let plan = bytes_to_physical_plan(&SessionContext::new(), &task.plan_bytes, codec)
             .context("unable to decode task plan for formatted output")?;

         result.push_str(&format!(
@@ -45,6 +49,7 @@ pub fn build_explain_batch(
     physical_plan: &Arc<dyn ExecutionPlan>,
     distributed_plan: &Arc<dyn ExecutionPlan>,
     distributed_tasks: &[DDTask],
+    codec: &dyn PhysicalExtensionCodec,
 ) -> Result<RecordBatch> {
     let schema = Arc::new(Schema::new(vec![
         Field::new("plan_type", DataType::Utf8, false),
@@ -64,7 +69,7 @@
         displayable(distributed_plan.as_ref())
             .indent(true)
             .to_string(),
-        format_distributed_tasks(distributed_tasks)?,
+        format_distributed_tasks(distributed_tasks, codec)?,
     ]);

     let batch = RecordBatch::try_new(schema, vec![Arc::new(plan_types), Arc::new(plans)])?;
diff --git a/src/lib.rs b/src/lib.rs
index a40b4e5b..76299b97 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -24,6 +24,7 @@ pub use proto::generated::protobuf;

 pub mod analyze;
 pub mod codec;
+pub mod customizer;
 pub mod explain;
 pub mod flight;
 pub mod friendly;
diff --git a/src/planning.rs b/src/planning.rs
index 8e04213d..461435ec 100644
--- a/src/planning.rs
+++ b/src/planning.rs
@@ -24,6 +24,7 @@ use datafusion::{
     },
     prelude::{SQLOptions, SessionConfig, SessionContext},
 };
+use datafusion_proto::physical_plan::PhysicalExtensionCodec;
 use futures::TryStreamExt;
 use itertools::Itertools;
 use prost::Message;
@@ -439,6 +440,7 @@ pub async fn distribute_stages(
     query_id: &str,
     stages: Vec<DDStage>,
     worker_addrs: Vec<Host>,
+    codec: &dyn PhysicalExtensionCodec,
 ) -> Result<(Addrs, Vec<DDTask>)> {
     // map of worker name to address
     // FIXME: use types over tuples of strings, as we can accidently swap them and
@@ -457,7 +459,7 @@
     // all stages to workers
     let (task_datas, final_addrs) =
-        assign_to_workers(query_id, &stages, workers.values().collect())?;
+        assign_to_workers(query_id, &stages, workers.values().collect(), codec)?;

     // we retry this a few times to ensure that the workers are ready
     // and can accept the stages
@@ -551,6 +553,7 @@ fn assign_to_workers(
     query_id: &str,
     stages: &[DDStage],
     worker_addrs: Vec<&Host>,
+    codec: &dyn PhysicalExtensionCodec,
 ) -> Result<(Vec<DDTask>, Addrs)> {
     let mut task_datas = vec![];
     let mut worker_idx = 0;
@@ -570,7 +573,7 @@
     for stage in stages {
         for partition_group in stage.partition_groups.iter() {
-            let plan_bytes = physical_plan_to_bytes(stage.plan.clone())?;
+            let plan_bytes = physical_plan_to_bytes(stage.plan.clone(), codec)?;

             let host = worker_addrs[worker_idx].clone();
             worker_idx = (worker_idx + 1) % worker_addrs.len();
diff --git a/src/proxy_service.rs b/src/proxy_service.rs
index 95a30860..12ac59d0 100644
--- a/src/proxy_service.rs
+++ b/src/proxy_service.rs
@@ -40,6 +40,7 @@ use tokio::{
 use tonic::{async_trait, transport::Server, Request, Response, Status};

 use crate::{
+    customizer::Customizer,
     flight::{FlightSqlHandler, FlightSqlServ},
     logging::{debug, info, trace},
     planning::{add_ctx_extentions, get_ctx},
@@ -57,10 +58,13 @@ pub struct DDProxyHandler {
     pub host: Host,

     pub planner: QueryPlanner,
+
+    /// Optional customizer for our context and proto serde
+    pub customizer: Option<Arc<dyn Customizer>>,
 }

 impl DDProxyHandler {
-    pub fn new(name: String, addr: String) -> Self {
+    pub fn new(name: String, addr: String, customizer: Option<Arc<dyn Customizer>>) -> Self {
         // call this function to bootstrap the worker discovery mechanism
         get_worker_addresses().expect("Could not get worker addresses upon startup");

@@ -70,7 +74,8 @@
         };
         Self {
             host: host.clone(),
-            planner: QueryPlanner::new(),
+            planner: QueryPlanner::new(customizer.clone()),
+            customizer,
         }
     }

@@ -118,6 +123,12 @@
         add_ctx_extentions(&mut ctx, &self.host, &query_id, stage_id, addrs, vec![])
             .map_err(|e| Status::internal(format!("Could not add context extensions {e:?}")))?;

+        if let Some(ref c) = self.customizer {
+            c.customize(&mut ctx)
+                .await
+                .map_err(|e| Status::internal(format!("Could not customize context {e:?}")))?;
+        }
+
         // TODO: revisit this to allow for consuming a partitular partition
         trace!("calling execute plan");
         let partition = 0;
@@ -278,7 +289,11 @@ pub struct DDProxyService {
 }

 impl DDProxyService {
-    pub async fn new(name: String, port: usize) -> Result<Self> {
+    pub async fn new(
+        name: String,
+        port: usize,
+        ctx_customizer: Option<Arc<dyn Customizer>>,
+    ) -> Result<Self> {
         debug!("Creating DDProxyService!");
         let (all_done_tx, all_done_rx) = channel(1);
@@ -290,7 +305,7 @@
         info!("DDProxyService bound to {addr}");

-        let handler = Arc::new(DDProxyHandler::new(name, addr.clone()));
+        let handler = Arc::new(DDProxyHandler::new(name, addr.clone(), ctx_customizer));

         Ok(Self {
             listener,
diff --git a/src/query_planner.rs b/src/query_planner.rs
index 14ecc783..46af7cad 100644
--- a/src/query_planner.rs
+++ b/src/query_planner.rs
@@ -1,6 +1,6 @@
 use std::sync::Arc;

-use anyhow::anyhow;
+use anyhow::{anyhow, Context as AnyhowContext};
 use arrow::{compute::concat_batches, datatypes::SchemaRef};
 use datafusion::{
     logical_expr::LogicalPlan,
@@ -8,10 +8,13 @@ use datafusion::{
     prelude::SessionContext,
 };

+use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec};
 use datafusion_substrait::{logical_plan::consumer::from_substrait_plan, substrait::proto::Plan};
 use tokio_stream::StreamExt;

 use crate::{
+    codec::DDCodec,
+    customizer::Customizer,
     explain::build_explain_batch,
     planning::{
         distribute_stages, execution_planning, get_ctx, logical_planning, physical_planning,
@@ -55,17 +58,28 @@ impl std::fmt::Debug for QueryPlan {
 }

 /// Query planner responsible for preparing SQL queries for distributed execution
-pub struct QueryPlanner;
+pub struct QueryPlanner {
+    customizer: Option<Arc<dyn Customizer>>,
+    codec: Arc<DDCodec>,
+}

 impl Default for QueryPlanner {
     fn default() -> Self {
-        Self::new()
+        Self::new(None)
     }
 }

 impl QueryPlanner {
-    pub fn new() -> Self {
-        Self
+    pub fn new(customizer: Option<Arc<dyn Customizer>>) -> Self {
+        let codec = Arc::new(DDCodec::new(
+            customizer
+                .clone()
+                .map(|c| c as Arc<dyn PhysicalExtensionCodec>)
+                .or(Some(Arc::new(DefaultPhysicalExtensionCodec {})))
+                .unwrap(),
+        ));
+
+        Self { customizer, codec }
     }

     /// Common planning steps shared by both query and its EXPLAIN
@@ -73,7 +87,13 @@
     /// Prepare a query by parsing the SQL, planning it, and distributing the
     /// physical plan into stages that can be executed by workers.
     pub async fn prepare(&self, sql: &str) -> Result<QueryPlan> {
-        let ctx = get_ctx().map_err(|e| anyhow!("Could not create context: {e}"))?;
+        let mut ctx = get_ctx().map_err(|e| anyhow!("Could not create context: {e}"))?;
+        if let Some(customizer) = &self.customizer {
+            customizer
+                .customize(&mut ctx)
+                .await
+                .map_err(|e| anyhow!("Customization failed: {e:#?}"))?;
+        }

         let logical_plan = logical_planning(sql, &ctx).await?;

@@ -153,6 +173,7 @@
             &query_plan.physical_plan,
             &query_plan.distributed_plan,
             &query_plan.distributed_tasks,
+            self.codec.as_ref(),
         )?;

         let physical_plan = Arc::new(RecordBatchExec::new(batch));
@@ -185,8 +206,13 @@
         // distribute the stages to workers, further dividing them up
         // into chunks of partitions (partition_groups)
-        let (final_workers, tasks) =
-            distribute_stages(&query_id, distributed_stages, worker_addrs).await?;
+        let (final_workers, tasks) = distribute_stages(
+            &query_id,
+            distributed_stages,
+            worker_addrs,
+            self.codec.as_ref(),
+        )
+        .await?;

         let qp = QueryPlan {
             query_id,
diff --git a/src/util.rs b/src/util.rs
index 20b04b13..0d9bd97b 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -38,7 +38,7 @@ use datafusion::{
     },
     prelude::SessionContext,
 };
-use datafusion_proto::physical_plan::AsExecutionPlan;
+use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
 use futures::{stream::BoxStream, Stream, StreamExt};
 use object_store::{
     aws::AmazonS3Builder, gcp::GoogleCloudStorageBuilder, http::HttpBuilder, ObjectStore,
@@ -133,13 +133,15 @@ where
     out
 }

-pub fn physical_plan_to_bytes(plan: Arc<dyn ExecutionPlan>) -> Result<Vec<u8>, DataFusionError> {
+pub fn physical_plan_to_bytes(
+    plan: Arc<dyn ExecutionPlan>,
+    codec: &dyn PhysicalExtensionCodec,
+) -> Result<Vec<u8>, DataFusionError> {
     trace!(
         "serializing plan to bytes. plan:\n{}",
plan:\n{}", display_plan_with_partition_counts(&plan) ); - let codec = DDCodec {}; - let proto = datafusion_proto::protobuf::PhysicalPlanNode::try_from_physical_plan(plan, &codec)?; + let proto = datafusion_proto::protobuf::PhysicalPlanNode::try_from_physical_plan(plan, codec)?; let bytes = proto.encode_to_vec(); Ok(bytes) @@ -148,11 +150,11 @@ pub fn physical_plan_to_bytes(plan: Arc) -> Result, D pub fn bytes_to_physical_plan( ctx: &SessionContext, plan_bytes: &[u8], + codec: &dyn PhysicalExtensionCodec, ) -> Result, DataFusionError> { let proto_plan = datafusion_proto::protobuf::PhysicalPlanNode::try_decode(plan_bytes)?; - let codec = DDCodec {}; - let plan = proto_plan.try_into_physical_plan(ctx, ctx.runtime_env().as_ref(), &codec)?; + let plan = proto_plan.try_into_physical_plan(ctx, ctx.runtime_env().as_ref(), codec)?; Ok(plan) } diff --git a/src/worker_service.rs b/src/worker_service.rs index 8500984f..af8a2085 100644 --- a/src/worker_service.rs +++ b/src/worker_service.rs @@ -33,6 +33,7 @@ use datafusion::{ physical_plan::{ExecutionPlan, ExecutionPlanProperties}, prelude::SessionContext, }; +use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; use futures::{StreamExt, TryStreamExt}; use parking_lot::{Mutex, RwLock}; use prost::Message; @@ -44,6 +45,8 @@ use tonic::{async_trait, transport::Server, Request, Response, Status}; use crate::{ analyze::DistributedAnalyzeExec, + codec::DDCodec, + customizer::Customizer, flight::{FlightHandler, FlightServ}, logging::{debug, error, info, trace}, planning::{add_ctx_extentions, get_ctx}, @@ -94,10 +97,15 @@ struct DDWorkerHandler { /// our map of query_id -> (session ctx, execution plan) stages: Arc>>, done: Arc>, + + /// Optional customizer for our context and proto serde + pub customizer: Option>, + + codec: Arc, } impl DDWorkerHandler { - pub fn new(name: String, addr: String) -> Self { + pub fn new(name: String, addr: String, customizer: Option>) -> Self { let stages: Arc>> = Arc::new(RwLock::new(HashMap::new())); let done = Arc::new(Mutex::new(false)); @@ -153,11 +161,21 @@ impl DDWorkerHandler { } }); + let codec = Arc::new(DDCodec::new( + customizer + .clone() + .map(|c| c as Arc) + .or(Some(Arc::new(DefaultPhysicalExtensionCodec {}))) + .unwrap(), + )); + Self { name, addr, stages, done, + customizer, + codec, } } @@ -185,10 +203,11 @@ impl DDWorkerHandler { ) .await?; - let plan = bytes_to_physical_plan(&ctx, plan_bytes).context(format!( - "{}, Could not decode plan for query_id {} stage {}", - self.name, query_id, stage_id - ))?; + let plan = + bytes_to_physical_plan(&ctx, plan_bytes, self.codec.as_ref()).context(format!( + "{}, Could not decode plan for query_id {} stage {}", + self.name, query_id, stage_id + ))?; let partitions = if full_partitions { partition_group.clone() @@ -626,7 +645,11 @@ pub struct DDWorkerService { } impl DDWorkerService { - pub async fn new(name: String, port: usize) -> Result { + pub async fn new( + name: String, + port: usize, + customizer: Option>, + ) -> Result { let name = format!("[{}]", name); let (all_done_tx, all_done_rx) = channel(1); @@ -638,7 +661,7 @@ impl DDWorkerService { info!("DDWorkerService bound to {addr}"); - let handler = Arc::new(DDWorkerHandler::new(name.clone(), addr.clone())); + let handler = Arc::new(DDWorkerHandler::new(name.clone(), addr.clone(), customizer)); Ok(Self { name,