Skip to content

Commit 73af635

Browse files
committed
Add roundtrip tests
1 parent 875a829 commit 73af635

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed

src/plan/codec.rs

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,98 @@ pub struct ArrowFlightReadExecProto {
146146
#[prost(uint64, tag = "3")]
147147
stage_num: u64,
148148
}
149+
150+
#[cfg(test)]
151+
mod tests {
152+
use super::*;
153+
use datafusion::arrow::datatypes::{DataType, Field};
154+
use datafusion::{
155+
execution::registry::MemoryFunctionRegistry,
156+
physical_expr::{expressions::col, expressions::Column, Partitioning, PhysicalSortExpr},
157+
physical_plan::{displayable, sorts::sort::SortExec, union::UnionExec, ExecutionPlan},
158+
};
159+
160+
type TestCase = (
161+
&'static str,
162+
Arc<dyn ExecutionPlan>,
163+
Vec<Arc<dyn ExecutionPlan>>,
164+
);
165+
166+
fn schema_i32(name: &str) -> Arc<Schema> {
167+
Arc::new(Schema::new(vec![Field::new(name, DataType::Int32, false)]))
168+
}
169+
170+
fn repr(plan: &Arc<dyn ExecutionPlan>) -> String {
171+
displayable(plan.as_ref()).indent(true).to_string()
172+
}
173+
174+
#[test]
175+
fn distributed_codec_roundtrips() -> datafusion::common::Result<()> {
176+
let codec = DistributedCodec;
177+
let registry = MemoryFunctionRegistry::new();
178+
179+
let mut cases: Vec<TestCase> = Vec::new();
180+
181+
// ArrowFlightReadExec
182+
let schema = schema_i32("a");
183+
let part = Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 4);
184+
let plan: Arc<dyn ExecutionPlan> = Arc::new(ArrowFlightReadExec::new(part, schema, 0));
185+
cases.push(("single_flight", plan, vec![]));
186+
187+
// PartitionIsolatorExec -> ArrowFlightReadExec
188+
let schema = schema_i32("b");
189+
let flight = Arc::new(ArrowFlightReadExec::new(
190+
Partitioning::UnknownPartitioning(1),
191+
schema,
192+
0,
193+
));
194+
let plan: Arc<dyn ExecutionPlan> = Arc::new(PartitionIsolatorExec::new(flight.clone(), 3));
195+
cases.push(("isolator_flight", plan, vec![flight]));
196+
197+
// PartitionIsolatorExec -> UnionExec(ArrowFlightReadExec)
198+
let schema = schema_i32("c");
199+
let left = Arc::new(ArrowFlightReadExec::new(
200+
Partitioning::RoundRobinBatch(2),
201+
schema.clone(),
202+
0,
203+
));
204+
let right = Arc::new(ArrowFlightReadExec::new(
205+
Partitioning::RoundRobinBatch(2),
206+
schema.clone(),
207+
1,
208+
));
209+
let union = Arc::new(UnionExec::new(vec![left.clone(), right.clone()]));
210+
let plan: Arc<dyn ExecutionPlan> = Arc::new(PartitionIsolatorExec::new(union.clone(), 5));
211+
cases.push(("isolator_union", plan, vec![union]));
212+
213+
// PartitionIsolatorExec -> SortExec -> ArrowFlightReadExec
214+
let schema = schema_i32("d");
215+
let flight = Arc::new(ArrowFlightReadExec::new(
216+
Partitioning::UnknownPartitioning(1),
217+
schema.clone(),
218+
0,
219+
));
220+
let sort_expr = PhysicalSortExpr {
221+
expr: col("d", &schema)?,
222+
options: Default::default(),
223+
};
224+
let sort = Arc::new(SortExec::new(vec![sort_expr].into(), flight.clone()));
225+
let plan: Arc<dyn ExecutionPlan> = Arc::new(PartitionIsolatorExec::new(sort.clone(), 2));
226+
cases.push(("isolator_sort_flight", plan, vec![sort]));
227+
228+
// Test each case
229+
for (name, original, inputs) in cases {
230+
let mut buf = Vec::new();
231+
codec.try_encode(original.clone(), &mut buf)?;
232+
233+
let decoded = codec.try_decode(&buf, &inputs, &registry)?;
234+
235+
assert_eq!(
236+
repr(&original),
237+
repr(&decoded),
238+
"mismatch after round-trip for {name}"
239+
);
240+
}
241+
Ok(())
242+
}
243+
}

0 commit comments

Comments
 (0)