diff --git a/src/plan/codec.rs b/src/plan/codec.rs index 314a73f..f4b3871 100644 --- a/src/plan/codec.rs +++ b/src/plan/codec.rs @@ -84,22 +84,30 @@ impl PhysicalExtensionCodec for DistributedCodec { buf: &mut Vec, ) -> datafusion::common::Result<()> { if let Some(node) = node.as_any().downcast_ref::() { - ArrowFlightReadExecProto { + let inner = ArrowFlightReadExecProto { schema: Some(node.schema().try_into()?), partitioning: Some(serialize_partitioning( node.properties().output_partitioning(), &DistributedCodec {}, )?), stage_num: node.stage_num as u64, - } - .encode(buf) - .map_err(|err| proto_error(format!("{err}"))) + }; + + let wrapper = DistributedExecProto { + node: Some(DistributedExecNode::ArrowFlightReadExec(inner)), + }; + + wrapper.encode(buf).map_err(|e| proto_error(format!("{e}"))) } else if let Some(node) = node.as_any().downcast_ref::() { - PartitionIsolatorExecProto { + let inner = PartitionIsolatorExecProto { partition_count: node.partition_count as u64, - } - .encode(buf) - .map_err(|err| proto_error(format!("{err}"))) + }; + + let wrapper = DistributedExecProto { + node: Some(DistributedExecNode::PartitionIsolatorExec(inner)), + }; + + wrapper.encode(buf).map_err(|e| proto_error(format!("{e}"))) } else { Err(proto_error(format!("Unexpected plan {}", node.name()))) } @@ -138,3 +146,121 @@ pub struct ArrowFlightReadExecProto { #[prost(uint64, tag = "3")] stage_num: u64, } + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::datatypes::{DataType, Field}; + use datafusion::{ + execution::registry::MemoryFunctionRegistry, + physical_expr::{expressions::col, expressions::Column, Partitioning, PhysicalSortExpr}, + physical_plan::{displayable, sorts::sort::SortExec, union::UnionExec, ExecutionPlan}, + }; + + fn schema_i32(name: &str) -> Arc { + Arc::new(Schema::new(vec![Field::new(name, DataType::Int32, false)])) + } + + fn repr(plan: &Arc) -> String { + displayable(plan.as_ref()).indent(true).to_string() + } + + #[test] + fn test_roundtrip_single_flight() -> datafusion::common::Result<()> { + let codec = DistributedCodec; + let registry = MemoryFunctionRegistry::new(); + + let schema = schema_i32("a"); + let part = Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 4); + let plan: Arc = Arc::new(ArrowFlightReadExec::new(part, schema, 0)); + + let mut buf = Vec::new(); + codec.try_encode(plan.clone(), &mut buf)?; + + let decoded = codec.try_decode(&buf, &[], ®istry)?; + assert_eq!(repr(&plan), repr(&decoded)); + + Ok(()) + } + + #[test] + fn test_roundtrip_isolator_flight() -> datafusion::common::Result<()> { + let codec = DistributedCodec; + let registry = MemoryFunctionRegistry::new(); + + let schema = schema_i32("b"); + let flight = Arc::new(ArrowFlightReadExec::new( + Partitioning::UnknownPartitioning(1), + schema, + 0, + )); + + let plan: Arc = Arc::new(PartitionIsolatorExec::new(flight.clone(), 3)); + + let mut buf = Vec::new(); + codec.try_encode(plan.clone(), &mut buf)?; + + let decoded = codec.try_decode(&buf, &[flight], ®istry)?; + assert_eq!(repr(&plan), repr(&decoded)); + + Ok(()) + } + + #[test] + fn test_roundtrip_isolator_union() -> datafusion::common::Result<()> { + let codec = DistributedCodec; + let registry = MemoryFunctionRegistry::new(); + + let schema = schema_i32("c"); + let left = Arc::new(ArrowFlightReadExec::new( + Partitioning::RoundRobinBatch(2), + schema.clone(), + 0, + )); + let right = Arc::new(ArrowFlightReadExec::new( + Partitioning::RoundRobinBatch(2), + schema.clone(), + 1, + )); + + let union = Arc::new(UnionExec::new(vec![left.clone(), right.clone()])); + let plan: Arc = Arc::new(PartitionIsolatorExec::new(union.clone(), 5)); + + let mut buf = Vec::new(); + codec.try_encode(plan.clone(), &mut buf)?; + + let decoded = codec.try_decode(&buf, &[union], ®istry)?; + assert_eq!(repr(&plan), repr(&decoded)); + + Ok(()) + } + + #[test] + fn test_roundtrip_isolator_sort_flight() -> datafusion::common::Result<()> { + let codec = DistributedCodec; + let registry = MemoryFunctionRegistry::new(); + + let schema = schema_i32("d"); + let flight = Arc::new(ArrowFlightReadExec::new( + Partitioning::UnknownPartitioning(1), + schema.clone(), + 0, + )); + + let sort_expr = PhysicalSortExpr { + expr: col("d", &schema)?, + options: Default::default(), + }; + let sort = Arc::new(SortExec::new(vec![sort_expr].into(), flight.clone())); + + let plan: Arc = Arc::new(PartitionIsolatorExec::new(sort.clone(), 2)); + + let mut buf = Vec::new(); + codec.try_encode(plan.clone(), &mut buf)?; + + let decoded = codec.try_decode(&buf, &[sort], ®istry)?; + assert_eq!(repr(&plan), repr(&decoded)); + + Ok(()) + } +}