@@ -8,11 +8,11 @@ use datafusion::arrow::datatypes::Schema;
88use datafusion:: arrow:: datatypes:: SchemaRef ;
99use datafusion:: common:: internal_datafusion_err;
1010use datafusion:: error:: DataFusionError ;
11- use datafusion:: execution:: { FunctionRegistry , SessionStateBuilder } ;
11+ use datafusion:: execution:: TaskContext ;
1212use datafusion:: physical_expr:: EquivalenceProperties ;
1313use datafusion:: physical_plan:: execution_plan:: { Boundedness , EmissionType } ;
1414use datafusion:: physical_plan:: { ExecutionPlan , Partitioning , PlanProperties } ;
15- use datafusion:: prelude:: { SessionConfig , SessionContext } ;
15+ use datafusion:: prelude:: SessionConfig ;
1616use datafusion_proto:: physical_plan:: from_proto:: parse_protobuf_partitioning;
1717use datafusion_proto:: physical_plan:: to_proto:: serialize_partitioning;
1818use datafusion_proto:: physical_plan:: { ComposedPhysicalExtensionCodec , PhysicalExtensionCodec } ;
@@ -40,7 +40,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
4040 & self ,
4141 buf : & [ u8 ] ,
4242 inputs : & [ Arc < dyn ExecutionPlan > ] ,
43- registry : & dyn FunctionRegistry ,
43+ ctx : & TaskContext ,
4444 ) -> datafusion:: common:: Result < Arc < dyn ExecutionPlan > > {
4545 let DistributedExecProto {
4646 node : Some ( distributed_exec_node) ,
@@ -51,20 +51,6 @@ impl PhysicalExtensionCodec for DistributedCodec {
5151 ) ) ;
5252 } ;
5353
54- // TODO: The PhysicalExtensionCodec trait doesn't provide access to session state,
55- // so we create a new SessionContext which loses any custom UDFs, UDAFs, and other
56- // user configurations. This is a limitation of the current trait design.
57- let state = SessionStateBuilder :: new ( )
58- . with_scalar_functions (
59- registry
60- . udfs ( )
61- . iter ( )
62- . map ( |f| registry. udf ( f) )
63- . collect :: < Result < Vec < _ > , _ > > ( ) ?,
64- )
65- . build ( ) ;
66- let ctx = SessionContext :: from ( state) ;
67-
6854 fn parse_stage_proto (
6955 proto : Option < StageProto > ,
7056 inputs : & [ Arc < dyn ExecutionPlan > ] ,
@@ -114,7 +100,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
114100
115101 let partitioning = parse_protobuf_partitioning (
116102 partitioning. as_ref ( ) ,
117- & ctx,
103+ ctx,
118104 & schema,
119105 & DistributedCodec { } ,
120106 ) ?
@@ -138,7 +124,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
138124
139125 let partitioning = parse_protobuf_partitioning (
140126 partitioning. as_ref ( ) ,
141- & ctx,
127+ ctx,
142128 & schema,
143129 & DistributedCodec { } ,
144130 ) ?
@@ -403,11 +389,12 @@ mod tests {
403389 use datafusion:: physical_expr:: LexOrdering ;
404390 use datafusion:: physical_plan:: empty:: EmptyExec ;
405391 use datafusion:: {
406- execution:: registry:: MemoryFunctionRegistry ,
407392 physical_expr:: { Partitioning , PhysicalSortExpr , expressions:: Column , expressions:: col} ,
408393 physical_plan:: { ExecutionPlan , displayable, sorts:: sort:: SortExec , union:: UnionExec } ,
409394 } ;
410395
396+ use datafusion:: prelude:: SessionContext ;
397+
411398 fn empty_exec ( ) -> Arc < dyn ExecutionPlan > {
412399 Arc :: new ( EmptyExec :: new ( SchemaRef :: new ( Schema :: empty ( ) ) ) )
413400 }
@@ -429,10 +416,14 @@ mod tests {
429416 displayable ( plan. as_ref ( ) ) . indent ( true ) . to_string ( )
430417 }
431418
419+ fn create_context ( ) -> Arc < TaskContext > {
420+ SessionContext :: new ( ) . task_ctx ( )
421+ }
422+
432423 #[ test]
433424 fn test_roundtrip_single_flight ( ) -> datafusion:: common:: Result < ( ) > {
434425 let codec = DistributedCodec ;
435- let registry = MemoryFunctionRegistry :: new ( ) ;
426+ let ctx = create_context ( ) ;
436427
437428 let schema = schema_i32 ( "a" ) ;
438429 let part = Partitioning :: Hash ( vec ! [ Arc :: new( Column :: new( "a" , 0 ) ) ] , 4 ) ;
@@ -442,7 +433,7 @@ mod tests {
442433 let mut buf = Vec :: new ( ) ;
443434 codec. try_encode ( plan. clone ( ) , & mut buf) ?;
444435
445- let decoded = codec. try_decode ( & buf, & [ empty_exec ( ) ] , & registry ) ?;
436+ let decoded = codec. try_decode ( & buf, & [ empty_exec ( ) ] , & ctx ) ?;
446437 assert_eq ! ( repr( & plan) , repr( & decoded) ) ;
447438
448439 Ok ( ( ) )
@@ -451,7 +442,7 @@ mod tests {
451442 #[ test]
452443 fn test_roundtrip_isolator_flight ( ) -> datafusion:: common:: Result < ( ) > {
453444 let codec = DistributedCodec ;
454- let registry = MemoryFunctionRegistry :: new ( ) ;
445+ let ctx = create_context ( ) ;
455446
456447 let schema = schema_i32 ( "b" ) ;
457448 let flight = Arc :: new ( new_network_hash_shuffle_exec (
@@ -466,7 +457,7 @@ mod tests {
466457 let mut buf = Vec :: new ( ) ;
467458 codec. try_encode ( plan. clone ( ) , & mut buf) ?;
468459
469- let decoded = codec. try_decode ( & buf, & [ flight] , & registry ) ?;
460+ let decoded = codec. try_decode ( & buf, & [ flight] , & ctx ) ?;
470461 assert_eq ! ( repr( & plan) , repr( & decoded) ) ;
471462
472463 Ok ( ( ) )
@@ -475,7 +466,7 @@ mod tests {
475466 #[ test]
476467 fn test_roundtrip_isolator_union ( ) -> datafusion:: common:: Result < ( ) > {
477468 let codec = DistributedCodec ;
478- let registry = MemoryFunctionRegistry :: new ( ) ;
469+ let ctx = create_context ( ) ;
479470
480471 let schema = schema_i32 ( "c" ) ;
481472 let left = Arc :: new ( new_network_hash_shuffle_exec (
@@ -489,14 +480,14 @@ mod tests {
489480 dummy_stage ( ) ,
490481 ) ) ;
491482
492- let union = Arc :: new ( UnionExec :: new ( vec ! [ left. clone( ) , right. clone( ) ] ) ) ;
483+ let union = UnionExec :: try_new ( vec ! [ left. clone( ) , right. clone( ) ] ) ? ;
493484 let plan: Arc < dyn ExecutionPlan > =
494485 Arc :: new ( PartitionIsolatorExec :: new_ready ( union. clone ( ) , 1 ) ?) ;
495486
496487 let mut buf = Vec :: new ( ) ;
497488 codec. try_encode ( plan. clone ( ) , & mut buf) ?;
498489
499- let decoded = codec. try_decode ( & buf, & [ union] , & registry ) ?;
490+ let decoded = codec. try_decode ( & buf, & [ union] , & ctx ) ?;
500491 assert_eq ! ( repr( & plan) , repr( & decoded) ) ;
501492
502493 Ok ( ( ) )
@@ -505,7 +496,7 @@ mod tests {
505496 #[ test]
506497 fn test_roundtrip_isolator_sort_flight ( ) -> datafusion:: common:: Result < ( ) > {
507498 let codec = DistributedCodec ;
508- let registry = MemoryFunctionRegistry :: new ( ) ;
499+ let ctx = create_context ( ) ;
509500
510501 let schema = schema_i32 ( "d" ) ;
511502 let flight = Arc :: new ( new_network_hash_shuffle_exec (
@@ -529,7 +520,7 @@ mod tests {
529520 let mut buf = Vec :: new ( ) ;
530521 codec. try_encode ( plan. clone ( ) , & mut buf) ?;
531522
532- let decoded = codec. try_decode ( & buf, & [ sort] , & registry ) ?;
523+ let decoded = codec. try_decode ( & buf, & [ sort] , & ctx ) ?;
533524 assert_eq ! ( repr( & plan) , repr( & decoded) ) ;
534525
535526 Ok ( ( ) )
@@ -538,7 +529,7 @@ mod tests {
538529 #[ test]
539530 fn test_roundtrip_single_flight_coalesce ( ) -> datafusion:: common:: Result < ( ) > {
540531 let codec = DistributedCodec ;
541- let registry = MemoryFunctionRegistry :: new ( ) ;
532+ let ctx = create_context ( ) ;
542533
543534 let schema = schema_i32 ( "e" ) ;
544535 let plan: Arc < dyn ExecutionPlan > = Arc :: new ( new_network_coalesce_tasks_exec (
@@ -550,7 +541,7 @@ mod tests {
550541 let mut buf = Vec :: new ( ) ;
551542 codec. try_encode ( plan. clone ( ) , & mut buf) ?;
552543
553- let decoded = codec. try_decode ( & buf, & [ empty_exec ( ) ] , & registry ) ?;
544+ let decoded = codec. try_decode ( & buf, & [ empty_exec ( ) ] , & ctx ) ?;
554545 assert_eq ! ( repr( & plan) , repr( & decoded) ) ;
555546
556547 Ok ( ( ) )
@@ -559,7 +550,7 @@ mod tests {
559550 #[ test]
560551 fn test_roundtrip_isolator_flight_coalesce ( ) -> datafusion:: common:: Result < ( ) > {
561552 let codec = DistributedCodec ;
562- let registry = MemoryFunctionRegistry :: new ( ) ;
553+ let ctx = create_context ( ) ;
563554
564555 let schema = schema_i32 ( "f" ) ;
565556 let flight = Arc :: new ( new_network_coalesce_tasks_exec (
@@ -574,7 +565,7 @@ mod tests {
574565 let mut buf = Vec :: new ( ) ;
575566 codec. try_encode ( plan. clone ( ) , & mut buf) ?;
576567
577- let decoded = codec. try_decode ( & buf, & [ flight] , & registry ) ?;
568+ let decoded = codec. try_decode ( & buf, & [ flight] , & ctx ) ?;
578569 assert_eq ! ( repr( & plan) , repr( & decoded) ) ;
579570
580571 Ok ( ( ) )
@@ -583,7 +574,7 @@ mod tests {
583574 #[ test]
584575 fn test_roundtrip_isolator_union_coalesce ( ) -> datafusion:: common:: Result < ( ) > {
585576 let codec = DistributedCodec ;
586- let registry = MemoryFunctionRegistry :: new ( ) ;
577+ let ctx = create_context ( ) ;
587578
588579 let schema = schema_i32 ( "g" ) ;
589580 let left = Arc :: new ( new_network_coalesce_tasks_exec (
@@ -597,14 +588,14 @@ mod tests {
597588 dummy_stage ( ) ,
598589 ) ) ;
599590
600- let union = Arc :: new ( UnionExec :: new ( vec ! [ left. clone( ) , right. clone( ) ] ) ) ;
591+ let union = UnionExec :: try_new ( vec ! [ left. clone( ) , right. clone( ) ] ) ? ;
601592 let plan: Arc < dyn ExecutionPlan > =
602593 Arc :: new ( PartitionIsolatorExec :: new_ready ( union. clone ( ) , 3 ) ?) ;
603594
604595 let mut buf = Vec :: new ( ) ;
605596 codec. try_encode ( plan. clone ( ) , & mut buf) ?;
606597
607- let decoded = codec. try_decode ( & buf, & [ union] , & registry ) ?;
598+ let decoded = codec. try_decode ( & buf, & [ union] , & ctx ) ?;
608599 assert_eq ! ( repr( & plan) , repr( & decoded) ) ;
609600
610601 Ok ( ( ) )
0 commit comments