1- use datafusion:: common:: not_impl_err ;
1+ use datafusion:: common:: internal_datafusion_err ;
22use datafusion:: error:: DataFusionError ;
3+ use datafusion:: error:: Result ;
34use datafusion:: execution:: FunctionRegistry ;
45use datafusion:: logical_expr:: { AggregateUDF , ScalarUDF } ;
56use datafusion:: physical_plan:: ExecutionPlan ;
67use datafusion_proto:: physical_plan:: PhysicalExtensionCodec ;
8+ use prost:: Message ;
79use std:: fmt:: Debug ;
810use std:: sync:: Arc ;
11+ // Code taken from https://github.com/apache/datafusion/blob/10f41887fa40d7d425c19b07857f80115460a98e/datafusion/proto/src/physical_plan/mod.rs
12+ // TODO: It's not yet on DF 49, once upgrading to DF 50 we can remove this
913
10- // Idea taken from
11- // https://github.com/apache/datafusion/blob/0eebc0c7c0ffcd1514f5c6d0f8e2b6d0c69a07f5/datafusion-examples/examples/composed_extension_codec.rs#L236-L291
14+ /// DataEncoderTuple captures the position of the encoder
15+ /// in the codec list that was used to encode the data and actual encoded data
16+ #[ derive( Clone , PartialEq , prost:: Message ) ]
17+ struct DataEncoderTuple {
18+ /// The position of encoder used to encode data
19+ /// (to be used for decoding)
20+ #[ prost( uint32, tag = 1 ) ]
21+ pub encoder_position : u32 ,
1222
13- /// A [PhysicalExtensionCodec] that holds multiple [PhysicalExtensionCodec] and tries them
14- /// sequentially until one works.
15- #[ derive( Debug , Clone , Default ) ]
16- pub ( crate ) struct ComposedPhysicalExtensionCodec {
23+ #[ prost( bytes, tag = 2 ) ]
24+ pub blob : Vec < u8 > ,
25+ }
26+
27+ /// A PhysicalExtensionCodec that tries one of multiple inner codecs
28+ /// until one works
29+ #[ derive( Debug ) ]
30+ pub struct ComposedPhysicalExtensionCodec {
1731 codecs : Vec < Arc < dyn PhysicalExtensionCodec > > ,
1832}
1933
2034impl ComposedPhysicalExtensionCodec {
21- /// Adds a new [PhysicalExtensionCodec] to the list. These codecs will be tried
22- /// sequentially until one works .
23- pub ( crate ) fn push ( & mut self , codec : impl PhysicalExtensionCodec + ' static ) {
24- self . codecs . push ( Arc :: new ( codec ) ) ;
35+ // Position in this codecs list is important as it will be used for decoding.
36+ // If new codec is added it should go to last position .
37+ pub fn new ( codecs : Vec < Arc < dyn PhysicalExtensionCodec > > ) -> Self {
38+ Self { codecs }
2539 }
2640
27- /// Adds a new [PhysicalExtensionCodec] to the list. These codecs will be tried
28- /// sequentially until one works.
29- pub ( crate ) fn push_arc ( & mut self , codec : Arc < dyn PhysicalExtensionCodec > ) {
30- self . codecs . push ( codec) ;
41+ fn decode_protobuf < R > (
42+ & self ,
43+ buf : & [ u8 ] ,
44+ decode : impl FnOnce ( & dyn PhysicalExtensionCodec , & [ u8 ] ) -> Result < R , DataFusionError > ,
45+ ) -> Result < R , DataFusionError > {
46+ let proto =
47+ DataEncoderTuple :: decode ( buf) . map_err ( |e| DataFusionError :: Internal ( e. to_string ( ) ) ) ?;
48+
49+ let pos = proto. encoder_position as usize ;
50+ let codec = self . codecs . get ( pos) . ok_or_else ( || {
51+ internal_datafusion_err ! (
52+ "Can't find required codec in position {pos} in codec list with {} elements" ,
53+ self . codecs. len( )
54+ )
55+ } ) ?;
56+
57+ decode ( codec. as_ref ( ) , & proto. blob )
3158 }
3259
33- fn try_any < T > (
60+ fn encode_protobuf (
3461 & self ,
35- mut f : impl FnMut ( & dyn PhysicalExtensionCodec ) -> Result < T , DataFusionError > ,
36- ) -> Result < T , DataFusionError > {
37- let mut errs = vec ! [ ] ;
38- for codec in & self . codecs {
39- match f ( codec. as_ref ( ) ) {
40- Ok ( node) => return Ok ( node) ,
41- Err ( err) => errs. push ( err) ,
62+ buf : & mut Vec < u8 > ,
63+ mut encode : impl FnMut ( & dyn PhysicalExtensionCodec , & mut Vec < u8 > ) -> Result < ( ) > ,
64+ ) -> Result < ( ) , DataFusionError > {
65+ let mut data = vec ! [ ] ;
66+ let mut last_err = None ;
67+ let mut encoder_position = None ;
68+
69+ // find the encoder
70+ for ( position, codec) in self . codecs . iter ( ) . enumerate ( ) {
71+ match encode ( codec. as_ref ( ) , & mut data) {
72+ Ok ( _) => {
73+ encoder_position = Some ( position as u32 ) ;
74+ break ;
75+ }
76+ Err ( err) => last_err = Some ( err) ,
4277 }
4378 }
4479
45- if errs. is_empty ( ) {
46- return not_impl_err ! ( "Empty list of composed codecs" ) ;
47- }
80+ let encoder_position = encoder_position. ok_or_else ( || {
81+ last_err. unwrap_or_else ( || {
82+ DataFusionError :: NotImplemented ( "Empty list of composed codecs" . to_owned ( ) )
83+ } )
84+ } ) ?;
4885
49- let mut msg = "None of the provided PhysicalExtensionCodec worked:" . to_string ( ) ;
50- for err in & errs {
51- msg += & format ! ( "\n {err}" ) ;
52- }
53- not_impl_err ! ( "{msg}" )
86+ // encode with encoder position
87+ let proto = DataEncoderTuple {
88+ encoder_position,
89+ blob : data,
90+ } ;
91+ proto
92+ . encode ( buf)
93+ . map_err ( |e| DataFusionError :: Internal ( e. to_string ( ) ) )
5494 }
5595}
5696
@@ -60,39 +100,27 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec {
60100 buf : & [ u8 ] ,
61101 inputs : & [ Arc < dyn ExecutionPlan > ] ,
62102 registry : & dyn FunctionRegistry ,
63- ) -> Result < Arc < dyn ExecutionPlan > , DataFusionError > {
64- self . try_any ( |codec| codec. try_decode ( buf , inputs, registry) )
103+ ) -> Result < Arc < dyn ExecutionPlan > > {
104+ self . decode_protobuf ( buf , |codec, data | codec. try_decode ( data , inputs, registry) )
65105 }
66106
67- fn try_encode (
68- & self ,
69- node : Arc < dyn ExecutionPlan > ,
70- buf : & mut Vec < u8 > ,
71- ) -> Result < ( ) , DataFusionError > {
72- self . try_any ( |codec| codec. try_encode ( node. clone ( ) , buf) )
107+ fn try_encode ( & self , node : Arc < dyn ExecutionPlan > , buf : & mut Vec < u8 > ) -> Result < ( ) > {
108+ self . encode_protobuf ( buf, |codec, data| codec. try_encode ( Arc :: clone ( & node) , data) )
73109 }
74110
75- fn try_decode_udf ( & self , name : & str , buf : & [ u8 ] ) -> Result < Arc < ScalarUDF > , DataFusionError > {
76- self . try_any ( |codec| codec. try_decode_udf ( name, buf ) )
111+ fn try_decode_udf ( & self , name : & str , buf : & [ u8 ] ) -> Result < Arc < ScalarUDF > > {
112+ self . decode_protobuf ( buf , |codec, data | codec. try_decode_udf ( name, data ) )
77113 }
78114
79- fn try_encode_udf ( & self , node : & ScalarUDF , buf : & mut Vec < u8 > ) -> Result < ( ) , DataFusionError > {
80- self . try_any ( |codec| codec. try_encode_udf ( node, buf ) )
115+ fn try_encode_udf ( & self , node : & ScalarUDF , buf : & mut Vec < u8 > ) -> Result < ( ) > {
116+ self . encode_protobuf ( buf , |codec, data | codec. try_encode_udf ( node, data ) )
81117 }
82118
83- fn try_decode_udaf (
84- & self ,
85- name : & str ,
86- buf : & [ u8 ] ,
87- ) -> Result < Arc < AggregateUDF > , DataFusionError > {
88- self . try_any ( |codec| codec. try_decode_udaf ( name, buf) )
119+ fn try_decode_udaf ( & self , name : & str , buf : & [ u8 ] ) -> Result < Arc < AggregateUDF > > {
120+ self . decode_protobuf ( buf, |codec, data| codec. try_decode_udaf ( name, data) )
89121 }
90122
91- fn try_encode_udaf (
92- & self ,
93- node : & AggregateUDF ,
94- buf : & mut Vec < u8 > ,
95- ) -> Result < ( ) , DataFusionError > {
96- self . try_any ( |codec| codec. try_encode_udaf ( node, buf) )
123+ fn try_encode_udaf ( & self , node : & AggregateUDF , buf : & mut Vec < u8 > ) -> Result < ( ) > {
124+ self . encode_protobuf ( buf, |codec, data| codec. try_encode_udaf ( node, data) )
97125 }
98126}
0 commit comments