@@ -8,38 +8,29 @@ use datafusion::arrow::datatypes::Schema;
88use datafusion:: arrow:: datatypes:: SchemaRef ;
99use datafusion:: common:: internal_datafusion_err;
1010use datafusion:: error:: DataFusionError ;
11- use datafusion:: execution:: { FunctionRegistry , TaskContext } ;
12- use datafusion:: logical_expr:: { AggregateUDF , ScalarUDF , WindowUDF } ;
11+ use datafusion:: execution:: { FunctionRegistry , SessionStateBuilder } ;
1312use datafusion:: physical_expr:: EquivalenceProperties ;
1413use datafusion:: physical_plan:: execution_plan:: { Boundedness , EmissionType } ;
1514use datafusion:: physical_plan:: { ExecutionPlan , Partitioning , PlanProperties } ;
16- use datafusion:: prelude:: SessionContext ;
15+ use datafusion:: prelude:: { SessionConfig , SessionContext } ;
1716use datafusion_proto:: physical_plan:: from_proto:: parse_protobuf_partitioning;
1817use datafusion_proto:: physical_plan:: to_proto:: serialize_partitioning;
1918use datafusion_proto:: physical_plan:: { ComposedPhysicalExtensionCodec , PhysicalExtensionCodec } ;
2019use datafusion_proto:: protobuf;
2120use datafusion_proto:: protobuf:: proto_error;
2221use prost:: Message ;
23- use std:: fmt:: { Debug , Formatter } ;
2422use std:: sync:: Arc ;
2523use url:: Url ;
2624
2725/// DataFusion [PhysicalExtensionCodec] implementation that allows serializing and
2826/// deserializing the custom ExecutionPlans in this project
29- #[ derive( Clone , Default ) ]
30- pub struct DistributedCodec ( Arc < TaskContext > ) ;
31-
32- impl Debug for DistributedCodec {
33- fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
34- write ! ( f, "DistributedCodec" )
35- }
36- }
27+ #[ derive( Debug ) ]
28+ pub struct DistributedCodec ;
3729
3830impl DistributedCodec {
39- pub fn new_combined_with_user ( ctx : Arc < TaskContext > ) -> impl PhysicalExtensionCodec {
40- let mut codecs: Vec < Arc < dyn PhysicalExtensionCodec > > =
41- vec ! [ Arc :: new( DistributedCodec ( Arc :: clone( & ctx) ) ) ] ;
42- codecs. extend ( get_distributed_user_codecs ( ctx. session_config ( ) ) ) ;
31+ pub fn new_combined_with_user ( cfg : & SessionConfig ) -> impl PhysicalExtensionCodec + use < > {
32+ let mut codecs: Vec < Arc < dyn PhysicalExtensionCodec > > = vec ! [ Arc :: new( DistributedCodec { } ) ] ;
33+ codecs. extend ( get_distributed_user_codecs ( cfg) ) ;
4334 ComposedPhysicalExtensionCodec :: new ( codecs)
4435 }
4536}
@@ -49,7 +40,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
4940 & self ,
5041 buf : & [ u8 ] ,
5142 inputs : & [ Arc < dyn ExecutionPlan > ] ,
52- _registry : & dyn FunctionRegistry ,
43+ registry : & dyn FunctionRegistry ,
5344 ) -> datafusion:: common:: Result < Arc < dyn ExecutionPlan > > {
5445 let DistributedExecProto {
5546 node : Some ( distributed_exec_node) ,
@@ -63,7 +54,16 @@ impl PhysicalExtensionCodec for DistributedCodec {
6354 // TODO: The PhysicalExtensionCodec trait doesn't provide access to session state,
6455 // so we create a new SessionContext which loses any custom UDFs, UDAFs, and other
6556 // user configurations. This is a limitation of the current trait design.
66- let ctx = SessionContext :: new ( ) ;
57+ let state = SessionStateBuilder :: new ( )
58+ . with_scalar_functions (
59+ registry
60+ . udfs ( )
61+ . iter ( )
62+ . map ( |f| registry. udf ( f) )
63+ . collect :: < Result < Vec < _ > , _ > > ( ) ?,
64+ )
65+ . build ( ) ;
66+ let ctx = SessionContext :: from ( state) ;
6767
6868 fn parse_stage_proto (
6969 proto : Option < StageProto > ,
@@ -112,9 +112,13 @@ impl PhysicalExtensionCodec for DistributedCodec {
112112 . map ( |s| s. try_into ( ) )
113113 . ok_or ( proto_error ( "NetworkShuffleExec is missing schema" ) ) ??;
114114
115- let partitioning =
116- parse_protobuf_partitioning ( partitioning. as_ref ( ) , & ctx, & schema, self ) ?
117- . ok_or ( proto_error ( "NetworkShuffleExec is missing partitioning" ) ) ?;
115+ let partitioning = parse_protobuf_partitioning (
116+ partitioning. as_ref ( ) ,
117+ & ctx,
118+ & schema,
119+ & DistributedCodec { } ,
120+ ) ?
121+ . ok_or ( proto_error ( "NetworkShuffleExec is missing partitioning" ) ) ?;
118122
119123 Ok ( Arc :: new ( new_network_hash_shuffle_exec (
120124 partitioning,
@@ -132,9 +136,13 @@ impl PhysicalExtensionCodec for DistributedCodec {
132136 . map ( |s| s. try_into ( ) )
133137 . ok_or ( proto_error ( "NetworkCoalesceExec is missing schema" ) ) ??;
134138
135- let partitioning =
136- parse_protobuf_partitioning ( partitioning. as_ref ( ) , & ctx, & schema, self ) ?
137- . ok_or ( proto_error ( "NetworkCoalesceExec is missing partitioning" ) ) ?;
139+ let partitioning = parse_protobuf_partitioning (
140+ partitioning. as_ref ( ) ,
141+ & ctx,
142+ & schema,
143+ & DistributedCodec { } ,
144+ ) ?
145+ . ok_or ( proto_error ( "NetworkCoalesceExec is missing partitioning" ) ) ?;
138146
139147 Ok ( Arc :: new ( new_network_coalesce_tasks_exec (
140148 partitioning,
@@ -185,7 +193,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
185193 schema : Some ( node. schema ( ) . try_into ( ) ?) ,
186194 partitioning : Some ( serialize_partitioning (
187195 node. properties ( ) . output_partitioning ( ) ,
188- self ,
196+ & DistributedCodec { } ,
189197 ) ?) ,
190198 input_stage : Some ( encode_stage_proto ( node. input_stage ( ) ) ?) ,
191199 } ;
@@ -200,7 +208,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
200208 schema : Some ( node. schema ( ) . try_into ( ) ?) ,
201209 partitioning : Some ( serialize_partitioning (
202210 node. properties ( ) . output_partitioning ( ) ,
203- self ,
211+ & DistributedCodec { } ,
204212 ) ?) ,
205213 input_stage : Some ( encode_stage_proto ( node. input_stage ( ) ) ?) ,
206214 } ;
@@ -229,30 +237,6 @@ impl PhysicalExtensionCodec for DistributedCodec {
229237 Err ( proto_error ( format ! ( "Unexpected plan {}" , node. name( ) ) ) )
230238 }
231239 }
232-
233- fn try_decode_udf (
234- & self ,
235- name : & str ,
236- _buf : & [ u8 ] ,
237- ) -> datafusion:: common:: Result < Arc < ScalarUDF > > {
238- self . 0 . udf ( name)
239- }
240-
241- fn try_decode_udaf (
242- & self ,
243- name : & str ,
244- _buf : & [ u8 ] ,
245- ) -> datafusion:: common:: Result < Arc < AggregateUDF > > {
246- self . 0 . udaf ( name)
247- }
248-
249- fn try_decode_udwf (
250- & self ,
251- name : & str ,
252- _buf : & [ u8 ] ,
253- ) -> datafusion:: common:: Result < Arc < WindowUDF > > {
254- self . 0 . udwf ( name)
255- }
256240}
257241
258242/// A key that uniquely identifies a stage in a query
@@ -436,7 +420,7 @@ mod tests {
436420
437421 #[ test]
438422 fn test_roundtrip_single_flight ( ) -> datafusion:: common:: Result < ( ) > {
439- let codec = DistributedCodec :: default ( ) ;
423+ let codec = DistributedCodec ;
440424 let registry = MemoryFunctionRegistry :: new ( ) ;
441425
442426 let schema = schema_i32 ( "a" ) ;
@@ -455,7 +439,7 @@ mod tests {
455439
456440 #[ test]
457441 fn test_roundtrip_isolator_flight ( ) -> datafusion:: common:: Result < ( ) > {
458- let codec = DistributedCodec :: default ( ) ;
442+ let codec = DistributedCodec ;
459443 let registry = MemoryFunctionRegistry :: new ( ) ;
460444
461445 let schema = schema_i32 ( "b" ) ;
@@ -479,7 +463,7 @@ mod tests {
479463
480464 #[ test]
481465 fn test_roundtrip_isolator_union ( ) -> datafusion:: common:: Result < ( ) > {
482- let codec = DistributedCodec :: default ( ) ;
466+ let codec = DistributedCodec ;
483467 let registry = MemoryFunctionRegistry :: new ( ) ;
484468
485469 let schema = schema_i32 ( "c" ) ;
@@ -509,7 +493,7 @@ mod tests {
509493
510494 #[ test]
511495 fn test_roundtrip_isolator_sort_flight ( ) -> datafusion:: common:: Result < ( ) > {
512- let codec = DistributedCodec :: default ( ) ;
496+ let codec = DistributedCodec ;
513497 let registry = MemoryFunctionRegistry :: new ( ) ;
514498
515499 let schema = schema_i32 ( "d" ) ;
0 commit comments