18 changes: 16 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion Cargo.toml
@@ -40,6 +40,7 @@ path = "src/bin/distributed-datafusion.rs"
anyhow = "1"
arrow = { version = "55.1", features = ["ipc"] }
arrow-flight = { version = "55", features = ["flight-sql-experimental"] }
async-trait = "0.1.88"
async-stream = "0.3"
bytes = "1.5"
clap = { version = "4.4", features = ["derive"] }
@@ -67,7 +68,7 @@ rustls = "0.23"
test-log = "0.2"
thiserror = "1.0"

tokio = { version = "1.0", features = ["full"] }
tokio = { version = "1.46", features = ["full"] }
tokio-stream = { version = "0.1", features = ["full"] }

tonic = { version = "0.12", default-features = false, features = [
4 changes: 2 additions & 2 deletions src/bin/distributed-datafusion.rs
@@ -33,11 +33,11 @@ async fn main() -> Result<()> {

match args.mode.as_str() {
"proxy" => {
let service = DDProxyService::new(new_friendly_name()?, args.port).await?;
let service = DDProxyService::new(new_friendly_name()?, args.port, None).await?;
service.serve().await?;
}
"worker" => {
let service = DDWorkerService::new(new_friendly_name()?, args.port).await?;
let service = DDWorkerService::new(new_friendly_name()?, args.port, None).await?;
service.serve().await?;
}
_ => {
69 changes: 48 additions & 21 deletions src/codec.rs
@@ -31,7 +31,21 @@ use crate::{
};

#[derive(Debug)]
pub struct DDCodec {}
pub struct DDCodec {
sub_codec: Arc<dyn PhysicalExtensionCodec>,
}

impl DDCodec {
pub fn new(sub_codec: Arc<dyn PhysicalExtensionCodec>) -> Self {
Self { sub_codec }
}
}

impl Default for DDCodec {
fn default() -> Self {
Self::new(Arc::new(DefaultPhysicalExtensionCodec {}))
}
}

impl PhysicalExtensionCodec for DDCodec {
fn try_decode(
@@ -127,6 +141,13 @@ impl PhysicalExtensionCodec for DDCodec {
Ok(Arc::new(RecordBatchExec::new(batch)))
}
}
} else if let Ok(ext) = self.sub_codec.try_decode(buf, inputs, registry) {
// If the node is not a DDExecNode, we delegate to the sub codec
trace!(
"Delegated decoding to sub codec for node: {}",
displayable(ext.as_ref()).one_line()
);
Ok(ext)
} else {
internal_err!("cannot decode proto extension in distributed datafusion codec")
}
@@ -151,52 +172,58 @@ impl PhysicalExtensionCodec for DDCodec {
stage_id: reader.stage_id,
};

Payload::StageReaderExec(pb)
Some(Payload::StageReaderExec(pb))
} else if let Some(pi) = node.as_any().downcast_ref::<PartitionIsolatorExec>() {
let pb = PartitionIsolatorExecNode {
partition_count: pi.partition_count as u64,
};

Payload::IsolatorExec(pb)
Some(Payload::IsolatorExec(pb))
} else if let Some(max) = node.as_any().downcast_ref::<MaxRowsExec>() {
let pb = MaxRowsExecNode {
max_rows: max.max_rows as u64,
};
Payload::MaxRowsExec(pb)
Some(Payload::MaxRowsExec(pb))
} else if let Some(exec) = node.as_any().downcast_ref::<DistributedAnalyzeExec>() {
let pb = DistributedAnalyzeExecNode {
verbose: exec.verbose,
show_statistics: exec.show_statistics,
};
Payload::DistributedAnalyzeExec(pb)
Some(Payload::DistributedAnalyzeExec(pb))
} else if let Some(exec) = node.as_any().downcast_ref::<DistributedAnalyzeRootExec>() {
let pb = DistributedAnalyzeRootExecNode {
verbose: exec.verbose,
show_statistics: exec.show_statistics,
};
Payload::DistributedAnalyzeRootExec(pb)
Some(Payload::DistributedAnalyzeRootExec(pb))
} else if let Some(exec) = node.as_any().downcast_ref::<RecordBatchExec>() {
let pb = RecordBatchExecNode {
batch: batch_to_ipc(&exec.batch).map_err(|e| {
internal_datafusion_err!("Failed to encode RecordBatch: {:#?}", e)
})?,
};
Payload::RecordBatchExec(pb)
Some(Payload::RecordBatchExec(pb))
} else {
return internal_err!("Not supported node to encode to proto");
trace!(
"Node {} is not a custom DDExecNode, delegating to sub codec",
displayable(node.as_ref()).one_line()
);
None
};

let pb = DdExecNode {
payload: Some(payload),
};
pb.encode(buf)
.map_err(|e| internal_datafusion_err!("Failed to encode protobuf: {}", e))?;

trace!(
"DONE encoding node: {}",
displayable(node.as_ref()).one_line()
);
Ok(())
match payload {
Some(payload) => {
let pb = DdExecNode {
payload: Some(payload),
};
pb.encode(buf)
.map_err(|e| internal_datafusion_err!("Failed to encode protobuf: {:#?}", e))
}
None => {
// If the node is not one of our custom nodes, we delegate to the sub codec
self.sub_codec.try_encode(node, buf)
}
}
}
}

@@ -225,7 +252,7 @@ mod test {

fn verify_round_trip(exec: Arc<dyn ExecutionPlan>) {
let ctx = SessionContext::new();
let codec = DDCodec {};
let codec = DDCodec::new(Arc::new(DefaultPhysicalExtensionCodec {}));

// serialize execution plan to proto
let proto: protobuf::PhysicalPlanNode =
@@ -255,7 +282,7 @@
let schema = create_test_schema();
let part = Partitioning::UnknownPartitioning(2);
let exec = Arc::new(DDStageReaderExec::try_new(part, schema, 1).unwrap());
let codec = DDCodec {};
let codec = DDCodec::new(Arc::new(DefaultPhysicalExtensionCodec {}));
let mut buf = vec![];
codec.try_encode(exec.clone(), &mut buf).unwrap();
let ctx = SessionContext::new();
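With the sub codec in place, DDCodec can wrap a caller-supplied PhysicalExtensionCodec so that DD-specific operators and a project's own extension nodes go through a single codec. A minimal sketch of that wiring; MyExtensionCodec is a hypothetical user codec, not part of this change:

    use std::sync::Arc;

    // Hypothetical: any type implementing datafusion_proto's PhysicalExtensionCodec.
    let user_codec = Arc::new(MyExtensionCodec::new());
    let codec = DDCodec::new(user_codec);
    // DDCodec handles its own nodes (stage readers, partition isolators,
    // MaxRowsExec, ...) and falls back to the wrapped codec for everything else,
    // both when encoding (payload == None) and when decoding.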
8 changes: 8 additions & 0 deletions src/customizer.rs
@@ -0,0 +1,8 @@
use datafusion::prelude::SessionContext;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;

#[async_trait::async_trait]
pub trait Customizer: PhysicalExtensionCodec + Send + Sync {
/// Customize the context before planning a query.
async fn customize(&self, ctx: &mut SessionContext) -> Result<(), Box<dyn std::error::Error>>;
}
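
Because Customizer extends PhysicalExtensionCodec, one type provides both the plan serde for its custom operators and an async hook that runs on the context before planning. A rough sketch of an implementor that delegates serde to DataFusion's default codec; the type name, the SET statement, and the crate-local import path are illustrative assumptions, not part of this PR:

    use std::sync::Arc;

    use datafusion::error::Result as DFResult;
    use datafusion::execution::FunctionRegistry;
    use datafusion::physical_plan::ExecutionPlan;
    use datafusion::prelude::SessionContext;
    use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec};

    use crate::customizer::Customizer;

    // Hypothetical example type.
    #[derive(Debug)]
    struct MyCustomizer {
        inner: DefaultPhysicalExtensionCodec,
    }

    impl MyCustomizer {
        fn new() -> Self {
            Self { inner: DefaultPhysicalExtensionCodec {} }
        }
    }

    impl PhysicalExtensionCodec for MyCustomizer {
        fn try_decode(
            &self,
            buf: &[u8],
            inputs: &[Arc<dyn ExecutionPlan>],
            registry: &dyn FunctionRegistry,
        ) -> DFResult<Arc<dyn ExecutionPlan>> {
            // A real implementation would decode its own extension nodes here.
            self.inner.try_decode(buf, inputs, registry)
        }

        fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> DFResult<()> {
            self.inner.try_encode(node, buf)
        }
    }

    #[async_trait::async_trait]
    impl Customizer for MyCustomizer {
        async fn customize(&self, ctx: &mut SessionContext) -> Result<(), Box<dyn std::error::Error>> {
            // Illustrative only: adjust session state before the query is planned,
            // e.g. register tables/UDFs or tweak configuration.
            ctx.sql("SET datafusion.execution.batch_size = 4096").await?;
            Ok(())
        }
    }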
11 changes: 8 additions & 3 deletions src/explain.rs
@@ -11,13 +11,17 @@ use datafusion::{
physical_plan::{displayable, ExecutionPlan},
prelude::SessionContext,
};
use datafusion_proto::physical_plan::PhysicalExtensionCodec;

use crate::{result::Result, util::bytes_to_physical_plan, vocab::DDTask};

pub fn format_distributed_tasks(tasks: &[DDTask]) -> Result<String> {
pub fn format_distributed_tasks(
tasks: &[DDTask],
codec: &dyn PhysicalExtensionCodec,
) -> Result<String> {
let mut result = String::new();
for (i, task) in tasks.iter().enumerate() {
let plan = bytes_to_physical_plan(&SessionContext::new(), &task.plan_bytes)
let plan = bytes_to_physical_plan(&SessionContext::new(), &task.plan_bytes, codec)
.context("unable to decode task plan for formatted output")?;

result.push_str(&format!(
@@ -45,6 +49,7 @@ pub fn build_explain_batch(
physical_plan: &Arc<dyn ExecutionPlan>,
distributed_plan: &Arc<dyn ExecutionPlan>,
distributed_tasks: &[DDTask],
codec: &dyn PhysicalExtensionCodec,
) -> Result<RecordBatch> {
let schema = Arc::new(Schema::new(vec![
Field::new("plan_type", DataType::Utf8, false),
@@ -64,7 +69,7 @@
displayable(distributed_plan.as_ref())
.indent(true)
.to_string(),
format_distributed_tasks(distributed_tasks)?,
format_distributed_tasks(distributed_tasks, codec)?,
]);

let batch = RecordBatch::try_new(schema, vec![Arc::new(plan_types), Arc::new(plans)])?;
1 change: 1 addition & 0 deletions src/lib.rs
@@ -24,6 +24,7 @@ pub use proto::generated::protobuf;

pub mod analyze;
pub mod codec;
pub mod customizer;
pub mod explain;
pub mod flight;
pub mod friendly;
7 changes: 5 additions & 2 deletions src/planning.rs
@@ -24,6 +24,7 @@ use datafusion::{
},
prelude::{SQLOptions, SessionConfig, SessionContext},
};
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use futures::TryStreamExt;
use itertools::Itertools;
use prost::Message;
@@ -439,6 +440,7 @@ pub async fn distribute_stages(
query_id: &str,
stages: Vec<DDStage>,
worker_addrs: Vec<Host>,
codec: &dyn PhysicalExtensionCodec,
) -> Result<(Addrs, Vec<DDTask>)> {
// map of worker name to address
// FIXME: use types over tuples of strings, as we can accidentally swap them and
@@ -457,7 +459,7 @@

// all stages to workers
let (task_datas, final_addrs) =
assign_to_workers(query_id, &stages, workers.values().collect())?;
assign_to_workers(query_id, &stages, workers.values().collect(), codec)?;

// we retry this a few times to ensure that the workers are ready
// and can accept the stages
@@ -551,6 +553,7 @@ fn assign_to_workers(
query_id: &str,
stages: &[DDStage],
worker_addrs: Vec<&Host>,
codec: &dyn PhysicalExtensionCodec,
) -> Result<(Vec<DDTask>, Addrs)> {
let mut task_datas = vec![];
let mut worker_idx = 0;
@@ -570,7 +573,7 @@

for stage in stages {
for partition_group in stage.partition_groups.iter() {
let plan_bytes = physical_plan_to_bytes(stage.plan.clone())?;
let plan_bytes = physical_plan_to_bytes(stage.plan.clone(), codec)?;

let host = worker_addrs[worker_idx].clone();
worker_idx = (worker_idx + 1) % worker_addrs.len();
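With the codec threaded through planning, callers of distribute_stages pass whichever PhysicalExtensionCodec should serialize the stage plans; a DDCodec (possibly wrapping a Customizer) satisfies the &dyn parameter. Sketch only, assuming query_id, stages, and worker_addrs are already in scope:

    let codec = DDCodec::default();
    let (addrs, tasks) =
        distribute_stages(&query_id, stages, worker_addrs, &codec).await?;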
23 changes: 19 additions & 4 deletions src/proxy_service.rs
@@ -40,6 +40,7 @@ use tokio::{
use tonic::{async_trait, transport::Server, Request, Response, Status};

use crate::{
customizer::Customizer,
flight::{FlightSqlHandler, FlightSqlServ},
logging::{debug, info, trace},
planning::{add_ctx_extentions, get_ctx},
@@ -57,10 +58,13 @@ pub struct DDProxyHandler {
pub host: Host,

pub planner: QueryPlanner,

/// Optional customizer for our context and proto serde
pub customizer: Option<Arc<dyn Customizer>>,
}

impl DDProxyHandler {
pub fn new(name: String, addr: String) -> Self {
pub fn new(name: String, addr: String, customizer: Option<Arc<dyn Customizer>>) -> Self {
// call this function to bootstrap the worker discovery mechanism
get_worker_addresses().expect("Could not get worker addresses upon startup");

@@ -70,7 +74,8 @@ impl DDProxyHandler {
};
Self {
host: host.clone(),
planner: QueryPlanner::new(),
planner: QueryPlanner::new(customizer.clone()),
customizer,
}
}

@@ -118,6 +123,12 @@ impl DDProxyHandler {
add_ctx_extentions(&mut ctx, &self.host, &query_id, stage_id, addrs, vec![])
.map_err(|e| Status::internal(format!("Could not add context extensions {e:?}")))?;

if let Some(ref c) = self.customizer {
c.customize(&mut ctx)
.await
.map_err(|e| Status::internal(format!("Could not customize context {e:?}")))?;
}

// TODO: revisit this to allow for consuming a particular partition
trace!("calling execute plan");
let partition = 0;
@@ -278,7 +289,7 @@ pub struct DDProxyService {
}

impl DDProxyService {
pub async fn new(name: String, port: usize) -> Result<Self> {
pub async fn new(
name: String,
port: usize,
ctx_customizer: Option<Arc<dyn Customizer>>,
) -> Result<Self> {
debug!("Creating DDProxyService!");

let (all_done_tx, all_done_rx) = channel(1);
@@ -290,7 +305,7 @@

info!("DDProxyService bound to {addr}");

let handler = Arc::new(DDProxyHandler::new(name, addr.clone()));
let handler = Arc::new(DDProxyHandler::new(name, addr.clone(), ctx_customizer));

Ok(Self {
listener,
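Putting it together, the binary (or an embedding application) hands the proxy an optional customizer; None keeps the old behavior, as the updated src/bin/distributed-datafusion.rs shows. A hedged sketch of the non-None path, reusing the hypothetical MyCustomizer from above:

    let customizer: Option<Arc<dyn Customizer>> = Some(Arc::new(MyCustomizer::new()));
    let service = DDProxyService::new(new_friendly_name()?, args.port, customizer).await?;
    service.serve().await?;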