@@ -146,3 +146,98 @@ pub struct ArrowFlightReadExecProto {
146146 #[ prost( uint64, tag = "3" ) ]
147147 stage_num : u64 ,
148148}
149+
150+ #[ cfg( test) ]
151+ mod tests {
152+ use super :: * ;
153+ use datafusion:: arrow:: datatypes:: { DataType , Field } ;
154+ use datafusion:: {
155+ execution:: registry:: MemoryFunctionRegistry ,
156+ physical_expr:: { expressions:: col, expressions:: Column , Partitioning , PhysicalSortExpr } ,
157+ physical_plan:: { displayable, sorts:: sort:: SortExec , union:: UnionExec , ExecutionPlan } ,
158+ } ;
159+
160+ type TestCase = (
161+ & ' static str ,
162+ Arc < dyn ExecutionPlan > ,
163+ Vec < Arc < dyn ExecutionPlan > > ,
164+ ) ;
165+
166+ fn schema_i32 ( name : & str ) -> Arc < Schema > {
167+ Arc :: new ( Schema :: new ( vec ! [ Field :: new( name, DataType :: Int32 , false ) ] ) )
168+ }
169+
170+ fn repr ( plan : & Arc < dyn ExecutionPlan > ) -> String {
171+ displayable ( plan. as_ref ( ) ) . indent ( true ) . to_string ( )
172+ }
173+
174+ #[ test]
175+ fn distributed_codec_roundtrips ( ) -> datafusion:: common:: Result < ( ) > {
176+ let codec = DistributedCodec ;
177+ let registry = MemoryFunctionRegistry :: new ( ) ;
178+
179+ let mut cases: Vec < TestCase > = Vec :: new ( ) ;
180+
181+ // ArrowFlightReadExec
182+ let schema = schema_i32 ( "a" ) ;
183+ let part = Partitioning :: Hash ( vec ! [ Arc :: new( Column :: new( "a" , 0 ) ) ] , 4 ) ;
184+ let plan: Arc < dyn ExecutionPlan > = Arc :: new ( ArrowFlightReadExec :: new ( part, schema, 0 ) ) ;
185+ cases. push ( ( "single_flight" , plan, vec ! [ ] ) ) ;
186+
187+ // PartitionIsolatorExec -> ArrowFlightReadExec
188+ let schema = schema_i32 ( "b" ) ;
189+ let flight = Arc :: new ( ArrowFlightReadExec :: new (
190+ Partitioning :: UnknownPartitioning ( 1 ) ,
191+ schema,
192+ 0 ,
193+ ) ) ;
194+ let plan: Arc < dyn ExecutionPlan > = Arc :: new ( PartitionIsolatorExec :: new ( flight. clone ( ) , 3 ) ) ;
195+ cases. push ( ( "isolator_flight" , plan, vec ! [ flight] ) ) ;
196+
197+ // PartitionIsolatorExec -> UnionExec(ArrowFlightReadExec)
198+ let schema = schema_i32 ( "c" ) ;
199+ let left = Arc :: new ( ArrowFlightReadExec :: new (
200+ Partitioning :: RoundRobinBatch ( 2 ) ,
201+ schema. clone ( ) ,
202+ 0 ,
203+ ) ) ;
204+ let right = Arc :: new ( ArrowFlightReadExec :: new (
205+ Partitioning :: RoundRobinBatch ( 2 ) ,
206+ schema. clone ( ) ,
207+ 1 ,
208+ ) ) ;
209+ let union = Arc :: new ( UnionExec :: new ( vec ! [ left. clone( ) , right. clone( ) ] ) ) ;
210+ let plan: Arc < dyn ExecutionPlan > = Arc :: new ( PartitionIsolatorExec :: new ( union. clone ( ) , 5 ) ) ;
211+ cases. push ( ( "isolator_union" , plan, vec ! [ union ] ) ) ;
212+
213+ // PartitionIsolatorExec -> SortExec -> ArrowFlightReadExec
214+ let schema = schema_i32 ( "d" ) ;
215+ let flight = Arc :: new ( ArrowFlightReadExec :: new (
216+ Partitioning :: UnknownPartitioning ( 1 ) ,
217+ schema. clone ( ) ,
218+ 0 ,
219+ ) ) ;
220+ let sort_expr = PhysicalSortExpr {
221+ expr : col ( "d" , & schema) ?,
222+ options : Default :: default ( ) ,
223+ } ;
224+ let sort = Arc :: new ( SortExec :: new ( vec ! [ sort_expr] . into ( ) , flight. clone ( ) ) ) ;
225+ let plan: Arc < dyn ExecutionPlan > = Arc :: new ( PartitionIsolatorExec :: new ( sort. clone ( ) , 2 ) ) ;
226+ cases. push ( ( "isolator_sort_flight" , plan, vec ! [ sort] ) ) ;
227+
228+ // Test each case
229+ for ( name, original, inputs) in cases {
230+ let mut buf = Vec :: new ( ) ;
231+ codec. try_encode ( original. clone ( ) , & mut buf) ?;
232+
233+ let decoded = codec. try_decode ( & buf, & inputs, & registry) ?;
234+
235+ assert_eq ! (
236+ repr( & original) ,
237+ repr( & decoded) ,
238+ "mismatch after round-trip for {name}"
239+ ) ;
240+ }
241+ Ok ( ( ) )
242+ }
243+ }
0 commit comments