@@ -84,22 +84,30 @@ impl PhysicalExtensionCodec for DistributedCodec {
8484 buf : & mut Vec < u8 > ,
8585 ) -> datafusion:: common:: Result < ( ) > {
8686 if let Some ( node) = node. as_any ( ) . downcast_ref :: < ArrowFlightReadExec > ( ) {
87- ArrowFlightReadExecProto {
87+ let inner = ArrowFlightReadExecProto {
8888 schema : Some ( node. schema ( ) . try_into ( ) ?) ,
8989 partitioning : Some ( serialize_partitioning (
9090 node. properties ( ) . output_partitioning ( ) ,
9191 & DistributedCodec { } ,
9292 ) ?) ,
9393 stage_num : node. stage_num as u64 ,
94- }
95- . encode ( buf)
96- . map_err ( |err| proto_error ( format ! ( "{err}" ) ) )
94+ } ;
95+
96+ let wrapper = DistributedExecProto {
97+ node : Some ( DistributedExecNode :: ArrowFlightReadExec ( inner) ) ,
98+ } ;
99+
100+ wrapper. encode ( buf) . map_err ( |e| proto_error ( format ! ( "{e}" ) ) )
97101 } else if let Some ( node) = node. as_any ( ) . downcast_ref :: < PartitionIsolatorExec > ( ) {
98- PartitionIsolatorExecProto {
102+ let inner = PartitionIsolatorExecProto {
99103 partition_count : node. partition_count as u64 ,
100- }
101- . encode ( buf)
102- . map_err ( |err| proto_error ( format ! ( "{err}" ) ) )
104+ } ;
105+
106+ let wrapper = DistributedExecProto {
107+ node : Some ( DistributedExecNode :: PartitionIsolatorExec ( inner) ) ,
108+ } ;
109+
110+ wrapper. encode ( buf) . map_err ( |e| proto_error ( format ! ( "{e}" ) ) )
103111 } else {
104112 Err ( proto_error ( format ! ( "Unexpected plan {}" , node. name( ) ) ) )
105113 }
@@ -138,3 +146,121 @@ pub struct ArrowFlightReadExecProto {
138146 #[ prost( uint64, tag = "3" ) ]
139147 stage_num : u64 ,
140148}
149+
150+ #[ cfg( test) ]
151+ mod tests {
152+ use super :: * ;
153+ use datafusion:: arrow:: datatypes:: { DataType , Field } ;
154+ use datafusion:: {
155+ execution:: registry:: MemoryFunctionRegistry ,
156+ physical_expr:: { expressions:: col, expressions:: Column , Partitioning , PhysicalSortExpr } ,
157+ physical_plan:: { displayable, sorts:: sort:: SortExec , union:: UnionExec , ExecutionPlan } ,
158+ } ;
159+
160+ fn schema_i32 ( name : & str ) -> Arc < Schema > {
161+ Arc :: new ( Schema :: new ( vec ! [ Field :: new( name, DataType :: Int32 , false ) ] ) )
162+ }
163+
164+ fn repr ( plan : & Arc < dyn ExecutionPlan > ) -> String {
165+ displayable ( plan. as_ref ( ) ) . indent ( true ) . to_string ( )
166+ }
167+
168+ #[ test]
169+ fn test_roundtrip_single_flight ( ) -> datafusion:: common:: Result < ( ) > {
170+ let codec = DistributedCodec ;
171+ let registry = MemoryFunctionRegistry :: new ( ) ;
172+
173+ let schema = schema_i32 ( "a" ) ;
174+ let part = Partitioning :: Hash ( vec ! [ Arc :: new( Column :: new( "a" , 0 ) ) ] , 4 ) ;
175+ let plan: Arc < dyn ExecutionPlan > = Arc :: new ( ArrowFlightReadExec :: new ( part, schema, 0 ) ) ;
176+
177+ let mut buf = Vec :: new ( ) ;
178+ codec. try_encode ( plan. clone ( ) , & mut buf) ?;
179+
180+ let decoded = codec. try_decode ( & buf, & [ ] , & registry) ?;
181+ assert_eq ! ( repr( & plan) , repr( & decoded) ) ;
182+
183+ Ok ( ( ) )
184+ }
185+
186+ #[ test]
187+ fn test_roundtrip_isolator_flight ( ) -> datafusion:: common:: Result < ( ) > {
188+ let codec = DistributedCodec ;
189+ let registry = MemoryFunctionRegistry :: new ( ) ;
190+
191+ let schema = schema_i32 ( "b" ) ;
192+ let flight = Arc :: new ( ArrowFlightReadExec :: new (
193+ Partitioning :: UnknownPartitioning ( 1 ) ,
194+ schema,
195+ 0 ,
196+ ) ) ;
197+
198+ let plan: Arc < dyn ExecutionPlan > = Arc :: new ( PartitionIsolatorExec :: new ( flight. clone ( ) , 3 ) ) ;
199+
200+ let mut buf = Vec :: new ( ) ;
201+ codec. try_encode ( plan. clone ( ) , & mut buf) ?;
202+
203+ let decoded = codec. try_decode ( & buf, & [ flight] , & registry) ?;
204+ assert_eq ! ( repr( & plan) , repr( & decoded) ) ;
205+
206+ Ok ( ( ) )
207+ }
208+
209+ #[ test]
210+ fn test_roundtrip_isolator_union ( ) -> datafusion:: common:: Result < ( ) > {
211+ let codec = DistributedCodec ;
212+ let registry = MemoryFunctionRegistry :: new ( ) ;
213+
214+ let schema = schema_i32 ( "c" ) ;
215+ let left = Arc :: new ( ArrowFlightReadExec :: new (
216+ Partitioning :: RoundRobinBatch ( 2 ) ,
217+ schema. clone ( ) ,
218+ 0 ,
219+ ) ) ;
220+ let right = Arc :: new ( ArrowFlightReadExec :: new (
221+ Partitioning :: RoundRobinBatch ( 2 ) ,
222+ schema. clone ( ) ,
223+ 1 ,
224+ ) ) ;
225+
226+ let union = Arc :: new ( UnionExec :: new ( vec ! [ left. clone( ) , right. clone( ) ] ) ) ;
227+ let plan: Arc < dyn ExecutionPlan > = Arc :: new ( PartitionIsolatorExec :: new ( union. clone ( ) , 5 ) ) ;
228+
229+ let mut buf = Vec :: new ( ) ;
230+ codec. try_encode ( plan. clone ( ) , & mut buf) ?;
231+
232+ let decoded = codec. try_decode ( & buf, & [ union] , & registry) ?;
233+ assert_eq ! ( repr( & plan) , repr( & decoded) ) ;
234+
235+ Ok ( ( ) )
236+ }
237+
238+ #[ test]
239+ fn test_roundtrip_isolator_sort_flight ( ) -> datafusion:: common:: Result < ( ) > {
240+ let codec = DistributedCodec ;
241+ let registry = MemoryFunctionRegistry :: new ( ) ;
242+
243+ let schema = schema_i32 ( "d" ) ;
244+ let flight = Arc :: new ( ArrowFlightReadExec :: new (
245+ Partitioning :: UnknownPartitioning ( 1 ) ,
246+ schema. clone ( ) ,
247+ 0 ,
248+ ) ) ;
249+
250+ let sort_expr = PhysicalSortExpr {
251+ expr : col ( "d" , & schema) ?,
252+ options : Default :: default ( ) ,
253+ } ;
254+ let sort = Arc :: new ( SortExec :: new ( vec ! [ sort_expr] . into ( ) , flight. clone ( ) ) ) ;
255+
256+ let plan: Arc < dyn ExecutionPlan > = Arc :: new ( PartitionIsolatorExec :: new ( sort. clone ( ) , 2 ) ) ;
257+
258+ let mut buf = Vec :: new ( ) ;
259+ codec. try_encode ( plan. clone ( ) , & mut buf) ?;
260+
261+ let decoded = codec. try_decode ( & buf, & [ sort] , & registry) ?;
262+ assert_eq ! ( repr( & plan) , repr( & decoded) ) ;
263+
264+ Ok ( ( ) )
265+ }
266+ }
0 commit comments