@@ -6,32 +6,40 @@ use crate::{NetworkShuffleExec, PartitionIsolatorExec};
66use bytes:: Bytes ;
77use datafusion:: arrow:: datatypes:: Schema ;
88use datafusion:: arrow:: datatypes:: SchemaRef ;
9- use datafusion:: common:: { internal_datafusion_err, not_impl_err } ;
9+ use datafusion:: common:: internal_datafusion_err;
1010use datafusion:: error:: DataFusionError ;
11- use datafusion:: execution:: FunctionRegistry ;
11+ use datafusion:: execution:: { FunctionRegistry , TaskContext } ;
1212use datafusion:: logical_expr:: { AggregateUDF , ScalarUDF , WindowUDF } ;
1313use datafusion:: physical_expr:: EquivalenceProperties ;
1414use datafusion:: physical_plan:: execution_plan:: { Boundedness , EmissionType } ;
1515use datafusion:: physical_plan:: { ExecutionPlan , Partitioning , PlanProperties } ;
16- use datafusion:: prelude:: { SessionConfig , SessionContext } ;
16+ use datafusion:: prelude:: SessionContext ;
1717use datafusion_proto:: physical_plan:: from_proto:: parse_protobuf_partitioning;
1818use datafusion_proto:: physical_plan:: to_proto:: serialize_partitioning;
1919use datafusion_proto:: physical_plan:: { ComposedPhysicalExtensionCodec , PhysicalExtensionCodec } ;
2020use datafusion_proto:: protobuf;
2121use datafusion_proto:: protobuf:: proto_error;
2222use prost:: Message ;
23+ use std:: fmt:: { Debug , Formatter } ;
2324use std:: sync:: Arc ;
2425use url:: Url ;
2526
2627/// DataFusion [PhysicalExtensionCodec] implementation that allows serializing and
2728/// deserializing the custom ExecutionPlans in this project
28- #[ derive( Debug ) ]
29- pub struct DistributedCodec ;
29+ #[ derive( Clone , Default ) ]
30+ pub struct DistributedCodec ( Arc < TaskContext > ) ;
31+
32+ impl Debug for DistributedCodec {
33+ fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
34+ write ! ( f, "DistributedCodec" )
35+ }
36+ }
3037
3138impl DistributedCodec {
32- pub fn new_combined_with_user ( cfg : & SessionConfig ) -> impl PhysicalExtensionCodec + use < > {
33- let mut codecs: Vec < Arc < dyn PhysicalExtensionCodec > > = vec ! [ Arc :: new( DistributedCodec { } ) ] ;
34- codecs. extend ( get_distributed_user_codecs ( cfg) ) ;
39+ pub fn new_combined_with_user ( ctx : Arc < TaskContext > ) -> impl PhysicalExtensionCodec {
40+ let mut codecs: Vec < Arc < dyn PhysicalExtensionCodec > > =
41+ vec ! [ Arc :: new( DistributedCodec ( Arc :: clone( & ctx) ) ) ] ;
42+ codecs. extend ( get_distributed_user_codecs ( ctx. session_config ( ) ) ) ;
3543 ComposedPhysicalExtensionCodec :: new ( codecs)
3644 }
3745}
@@ -104,13 +112,9 @@ impl PhysicalExtensionCodec for DistributedCodec {
104112 . map ( |s| s. try_into ( ) )
105113 . ok_or ( proto_error ( "NetworkShuffleExec is missing schema" ) ) ??;
106114
107- let partitioning = parse_protobuf_partitioning (
108- partitioning. as_ref ( ) ,
109- & ctx,
110- & schema,
111- & DistributedCodec { } ,
112- ) ?
113- . ok_or ( proto_error ( "NetworkShuffleExec is missing partitioning" ) ) ?;
115+ let partitioning =
116+ parse_protobuf_partitioning ( partitioning. as_ref ( ) , & ctx, & schema, self ) ?
117+ . ok_or ( proto_error ( "NetworkShuffleExec is missing partitioning" ) ) ?;
114118
115119 Ok ( Arc :: new ( new_network_hash_shuffle_exec (
116120 partitioning,
@@ -128,13 +132,9 @@ impl PhysicalExtensionCodec for DistributedCodec {
128132 . map ( |s| s. try_into ( ) )
129133 . ok_or ( proto_error ( "NetworkCoalesceExec is missing schema" ) ) ??;
130134
131- let partitioning = parse_protobuf_partitioning (
132- partitioning. as_ref ( ) ,
133- & ctx,
134- & schema,
135- & DistributedCodec { } ,
136- ) ?
137- . ok_or ( proto_error ( "NetworkCoalesceExec is missing partitioning" ) ) ?;
135+ let partitioning =
136+ parse_protobuf_partitioning ( partitioning. as_ref ( ) , & ctx, & schema, self ) ?
137+ . ok_or ( proto_error ( "NetworkCoalesceExec is missing partitioning" ) ) ?;
138138
139139 Ok ( Arc :: new ( new_network_coalesce_tasks_exec (
140140 partitioning,
@@ -185,7 +185,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
185185 schema : Some ( node. schema ( ) . try_into ( ) ?) ,
186186 partitioning : Some ( serialize_partitioning (
187187 node. properties ( ) . output_partitioning ( ) ,
188- & DistributedCodec { } ,
188+ self ,
189189 ) ?) ,
190190 input_stage : Some ( encode_stage_proto ( node. input_stage ( ) ) ?) ,
191191 } ;
@@ -200,7 +200,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
200200 schema : Some ( node. schema ( ) . try_into ( ) ?) ,
201201 partitioning : Some ( serialize_partitioning (
202202 node. properties ( ) . output_partitioning ( ) ,
203- & DistributedCodec { } ,
203+ self ,
204204 ) ?) ,
205205 input_stage : Some ( encode_stage_proto ( node. input_stage ( ) ) ?) ,
206206 } ;
@@ -230,16 +230,28 @@ impl PhysicalExtensionCodec for DistributedCodec {
230230 }
231231 }
232232
233- fn try_encode_udf ( & self , _: & ScalarUDF , _: & mut Vec < u8 > ) -> datafusion:: common:: Result < ( ) > {
234- not_impl_err ! ( "DistributedCodec does not encode UDFs" )
233+ fn try_decode_udf (
234+ & self ,
235+ name : & str ,
236+ _buf : & [ u8 ] ,
237+ ) -> datafusion:: common:: Result < Arc < ScalarUDF > > {
238+ self . 0 . udf ( name)
235239 }
236240
237- fn try_encode_udaf ( & self , _: & AggregateUDF , _: & mut Vec < u8 > ) -> datafusion:: common:: Result < ( ) > {
238- not_impl_err ! ( "DistributedCodec does not encode UDAFs" )
241+ fn try_decode_udaf (
242+ & self ,
243+ name : & str ,
244+ _buf : & [ u8 ] ,
245+ ) -> datafusion:: common:: Result < Arc < AggregateUDF > > {
246+ self . 0 . udaf ( name)
239247 }
240248
241- fn try_encode_udwf ( & self , _: & WindowUDF , _: & mut Vec < u8 > ) -> datafusion:: common:: Result < ( ) > {
242- not_impl_err ! ( "DistributedCodec does not encode UDWFs" )
249+ fn try_decode_udwf (
250+ & self ,
251+ name : & str ,
252+ _buf : & [ u8 ] ,
253+ ) -> datafusion:: common:: Result < Arc < WindowUDF > > {
254+ self . 0 . udwf ( name)
243255 }
244256}
245257
@@ -424,7 +436,7 @@ mod tests {
424436
425437 #[ test]
426438 fn test_roundtrip_single_flight ( ) -> datafusion:: common:: Result < ( ) > {
427- let codec = DistributedCodec ;
439+ let codec = DistributedCodec :: default ( ) ;
428440 let registry = MemoryFunctionRegistry :: new ( ) ;
429441
430442 let schema = schema_i32 ( "a" ) ;
@@ -443,7 +455,7 @@ mod tests {
443455
444456 #[ test]
445457 fn test_roundtrip_isolator_flight ( ) -> datafusion:: common:: Result < ( ) > {
446- let codec = DistributedCodec ;
458+ let codec = DistributedCodec :: default ( ) ;
447459 let registry = MemoryFunctionRegistry :: new ( ) ;
448460
449461 let schema = schema_i32 ( "b" ) ;
@@ -467,7 +479,7 @@ mod tests {
467479
468480 #[ test]
469481 fn test_roundtrip_isolator_union ( ) -> datafusion:: common:: Result < ( ) > {
470- let codec = DistributedCodec ;
482+ let codec = DistributedCodec :: default ( ) ;
471483 let registry = MemoryFunctionRegistry :: new ( ) ;
472484
473485 let schema = schema_i32 ( "c" ) ;
@@ -497,7 +509,7 @@ mod tests {
497509
498510 #[ test]
499511 fn test_roundtrip_isolator_sort_flight ( ) -> datafusion:: common:: Result < ( ) > {
500- let codec = DistributedCodec ;
512+ let codec = DistributedCodec :: default ( ) ;
501513 let registry = MemoryFunctionRegistry :: new ( ) ;
502514
503515 let schema = schema_i32 ( "d" ) ;
0 commit comments