@@ -19,6 +19,14 @@ pub(crate) fn set_distributed_option_extension<T: ConfigExtension + Default>(
1919 let mut meta = HeaderMap :: new ( ) ;
2020
2121 for entry in t. entries ( ) {
22+ // assume that fields starting with "__" are private, and are not supposed to be sent
23+ // over the wire. This accounts for the fact that we need to send our DistributedConfig
24+ // options without setting the __private_task_estimator and __private_channel_resolver.
25+ // Ideally those two fields should not even be there on the first place, but until
26+ // https://github.com/apache/datafusion/pull/18739 we need to put them there.
27+ if entry. key . starts_with ( "__" ) {
28+ continue ;
29+ }
2230 if let Some ( value) = entry. value {
2331 meta. insert (
2432 HeaderName :: from_str ( & format ! (
@@ -44,29 +52,51 @@ pub(crate) fn set_distributed_option_extension<T: ConfigExtension + Default>(
4452 Ok ( ( ) )
4553}
4654
47- pub ( crate ) fn set_distributed_option_extension_from_headers < T : ConfigExtension + Default > (
48- cfg : & mut SessionConfig ,
55+ pub ( crate ) fn set_distributed_option_extension_from_headers < ' a , T : ConfigExtension + Default > (
56+ cfg : & ' a mut SessionConfig ,
4957 headers : & HeaderMap ,
50- ) -> Result < ( ) , DataFusionError > {
51- let mut result = T :: default ( ) ;
52- let mut found_some = false ;
58+ ) -> Result < & ' a T , DataFusionError > {
59+ enum MutOrOwned < ' a , T > {
60+ Mut ( & ' a mut T ) ,
61+ Owned ( T ) ,
62+ }
63+
64+ impl < ' a , T > MutOrOwned < ' a , T > {
65+ fn as_mut ( & mut self ) -> & mut T {
66+ match self {
67+ MutOrOwned :: Mut ( v) => v,
68+ MutOrOwned :: Owned ( v) => v,
69+ }
70+ }
71+ }
72+
73+ // If the config extension existed before, we want to modify instead of adding a new one from
74+ // scratch. If not, we'll start from scratch with a new one.
75+ let mut result = match cfg. options_mut ( ) . extensions . get_mut :: < T > ( ) {
76+ Some ( v) => MutOrOwned :: Mut ( v) ,
77+ None => MutOrOwned :: Owned ( T :: default ( ) ) ,
78+ } ;
79+
5380 for ( k, v) in headers. iter ( ) {
5481 let key = k. as_str ( ) . trim_start_matches ( FLIGHT_METADATA_CONFIG_PREFIX ) ;
5582 let prefix = format ! ( "{}." , T :: PREFIX ) ;
5683 if key. starts_with ( & prefix) {
57- found_some = true ;
58- result. set (
84+ result. as_mut ( ) . set (
5985 key. trim_start_matches ( & prefix) ,
6086 v. to_str ( )
6187 . map_err ( |err| internal_datafusion_err ! ( "Cannot parse header value: {err}" ) ) ?,
6288 ) ?;
6389 }
6490 }
65- if !found_some {
66- return Ok ( ( ) ) ;
91+
92+ // Only insert the extension if it is not already there. If this is otherwise MutOrOwned::Mut it
93+ // means that the extension was already there, and we already modified it.
94+ if let MutOrOwned :: Owned ( v) = result {
95+ cfg. options_mut ( ) . extensions . insert ( v) ;
6796 }
68- cfg. options_mut ( ) . extensions . insert ( result) ;
69- Ok ( ( ) )
97+ cfg. options ( ) . extensions . get ( ) . ok_or_else ( || {
98+ internal_datafusion_err ! ( "ProgrammingError: a config option extension was just inserted, but it was not immediately retrievable" )
99+ } )
70100}
71101
72102#[ derive( Clone , Debug , Default ) ]
@@ -190,8 +220,15 @@ mod tests {
190220 & Default :: default ( ) ,
191221 ) ?;
192222
193- let extension = config. options ( ) . extensions . get :: < CustomExtension > ( ) ;
194- assert ! ( extension. is_none( ) ) ;
223+ let extension = config
224+ . options ( )
225+ . extensions
226+ . get :: < CustomExtension > ( )
227+ . unwrap ( ) ;
228+ let default = CustomExtension :: default ( ) ;
229+ assert_eq ! ( extension. foo, default . foo) ;
230+ assert_eq ! ( extension. bar, default . bar) ;
231+ assert_eq ! ( extension. baz, default . baz) ;
195232
196233 Ok ( ( ) )
197234 }
@@ -207,8 +244,15 @@ mod tests {
207244
208245 set_distributed_option_extension_from_headers :: < CustomExtension > ( & mut config, & header_map) ?;
209246
210- let extension = config. options ( ) . extensions . get :: < CustomExtension > ( ) ;
211- assert ! ( extension. is_none( ) ) ;
247+ let extension = config
248+ . options ( )
249+ . extensions
250+ . get :: < CustomExtension > ( )
251+ . unwrap ( ) ;
252+ let default = CustomExtension :: default ( ) ;
253+ assert_eq ! ( extension. foo, default . foo) ;
254+ assert_eq ! ( extension. bar, default . bar) ;
255+ assert_eq ! ( extension. baz, default . baz) ;
212256
213257 Ok ( ( ) )
214258 }
0 commit comments