Skip to content

Commit ba7282a

Browse files
authored
Merge pull request #76 from datafusion-contrib/lia/fix-serialization-bug
Fix serialization error
2 parents ccf36a1 + bfb00ad commit ba7282a

File tree

1 file changed

+134
-8
lines changed

1 file changed

+134
-8
lines changed

src/plan/codec.rs

Lines changed: 134 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,22 +84,30 @@ impl PhysicalExtensionCodec for DistributedCodec {
8484
buf: &mut Vec<u8>,
8585
) -> datafusion::common::Result<()> {
8686
if let Some(node) = node.as_any().downcast_ref::<ArrowFlightReadExec>() {
87-
ArrowFlightReadExecProto {
87+
let inner = ArrowFlightReadExecProto {
8888
schema: Some(node.schema().try_into()?),
8989
partitioning: Some(serialize_partitioning(
9090
node.properties().output_partitioning(),
9191
&DistributedCodec {},
9292
)?),
9393
stage_num: node.stage_num as u64,
94-
}
95-
.encode(buf)
96-
.map_err(|err| proto_error(format!("{err}")))
94+
};
95+
96+
let wrapper = DistributedExecProto {
97+
node: Some(DistributedExecNode::ArrowFlightReadExec(inner)),
98+
};
99+
100+
wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
97101
} else if let Some(node) = node.as_any().downcast_ref::<PartitionIsolatorExec>() {
98-
PartitionIsolatorExecProto {
102+
let inner = PartitionIsolatorExecProto {
99103
partition_count: node.partition_count as u64,
100-
}
101-
.encode(buf)
102-
.map_err(|err| proto_error(format!("{err}")))
104+
};
105+
106+
let wrapper = DistributedExecProto {
107+
node: Some(DistributedExecNode::PartitionIsolatorExec(inner)),
108+
};
109+
110+
wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
103111
} else {
104112
Err(proto_error(format!("Unexpected plan {}", node.name())))
105113
}
@@ -138,3 +146,121 @@ pub struct ArrowFlightReadExecProto {
138146
#[prost(uint64, tag = "3")]
139147
stage_num: u64,
140148
}
149+
150+
#[cfg(test)]
151+
mod tests {
152+
use super::*;
153+
use datafusion::arrow::datatypes::{DataType, Field};
154+
use datafusion::{
155+
execution::registry::MemoryFunctionRegistry,
156+
physical_expr::{expressions::col, expressions::Column, Partitioning, PhysicalSortExpr},
157+
physical_plan::{displayable, sorts::sort::SortExec, union::UnionExec, ExecutionPlan},
158+
};
159+
160+
fn schema_i32(name: &str) -> Arc<Schema> {
161+
Arc::new(Schema::new(vec![Field::new(name, DataType::Int32, false)]))
162+
}
163+
164+
fn repr(plan: &Arc<dyn ExecutionPlan>) -> String {
165+
displayable(plan.as_ref()).indent(true).to_string()
166+
}
167+
168+
#[test]
169+
fn test_roundtrip_single_flight() -> datafusion::common::Result<()> {
170+
let codec = DistributedCodec;
171+
let registry = MemoryFunctionRegistry::new();
172+
173+
let schema = schema_i32("a");
174+
let part = Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 4);
175+
let plan: Arc<dyn ExecutionPlan> = Arc::new(ArrowFlightReadExec::new(part, schema, 0));
176+
177+
let mut buf = Vec::new();
178+
codec.try_encode(plan.clone(), &mut buf)?;
179+
180+
let decoded = codec.try_decode(&buf, &[], &registry)?;
181+
assert_eq!(repr(&plan), repr(&decoded));
182+
183+
Ok(())
184+
}
185+
186+
#[test]
187+
fn test_roundtrip_isolator_flight() -> datafusion::common::Result<()> {
188+
let codec = DistributedCodec;
189+
let registry = MemoryFunctionRegistry::new();
190+
191+
let schema = schema_i32("b");
192+
let flight = Arc::new(ArrowFlightReadExec::new(
193+
Partitioning::UnknownPartitioning(1),
194+
schema,
195+
0,
196+
));
197+
198+
let plan: Arc<dyn ExecutionPlan> = Arc::new(PartitionIsolatorExec::new(flight.clone(), 3));
199+
200+
let mut buf = Vec::new();
201+
codec.try_encode(plan.clone(), &mut buf)?;
202+
203+
let decoded = codec.try_decode(&buf, &[flight], &registry)?;
204+
assert_eq!(repr(&plan), repr(&decoded));
205+
206+
Ok(())
207+
}
208+
209+
#[test]
210+
fn test_roundtrip_isolator_union() -> datafusion::common::Result<()> {
211+
let codec = DistributedCodec;
212+
let registry = MemoryFunctionRegistry::new();
213+
214+
let schema = schema_i32("c");
215+
let left = Arc::new(ArrowFlightReadExec::new(
216+
Partitioning::RoundRobinBatch(2),
217+
schema.clone(),
218+
0,
219+
));
220+
let right = Arc::new(ArrowFlightReadExec::new(
221+
Partitioning::RoundRobinBatch(2),
222+
schema.clone(),
223+
1,
224+
));
225+
226+
let union = Arc::new(UnionExec::new(vec![left.clone(), right.clone()]));
227+
let plan: Arc<dyn ExecutionPlan> = Arc::new(PartitionIsolatorExec::new(union.clone(), 5));
228+
229+
let mut buf = Vec::new();
230+
codec.try_encode(plan.clone(), &mut buf)?;
231+
232+
let decoded = codec.try_decode(&buf, &[union], &registry)?;
233+
assert_eq!(repr(&plan), repr(&decoded));
234+
235+
Ok(())
236+
}
237+
238+
#[test]
239+
fn test_roundtrip_isolator_sort_flight() -> datafusion::common::Result<()> {
240+
let codec = DistributedCodec;
241+
let registry = MemoryFunctionRegistry::new();
242+
243+
let schema = schema_i32("d");
244+
let flight = Arc::new(ArrowFlightReadExec::new(
245+
Partitioning::UnknownPartitioning(1),
246+
schema.clone(),
247+
0,
248+
));
249+
250+
let sort_expr = PhysicalSortExpr {
251+
expr: col("d", &schema)?,
252+
options: Default::default(),
253+
};
254+
let sort = Arc::new(SortExec::new(vec![sort_expr].into(), flight.clone()));
255+
256+
let plan: Arc<dyn ExecutionPlan> = Arc::new(PartitionIsolatorExec::new(sort.clone(), 2));
257+
258+
let mut buf = Vec::new();
259+
codec.try_encode(plan.clone(), &mut buf)?;
260+
261+
let decoded = codec.try_decode(&buf, &[sort], &registry)?;
262+
assert_eq!(repr(&plan), repr(&decoded));
263+
264+
Ok(())
265+
}
266+
}

0 commit comments

Comments
 (0)