@@ -27,9 +27,9 @@ pub trait ConfigExtensionExt {
2727 /// # use async_trait::async_trait;
2828 /// # use datafusion::common::{extensions_options, DataFusionError};
2929 /// # use datafusion::config::ConfigExtension;
30- /// # use datafusion::execution::SessionState;
30+ /// # use datafusion::execution::{ SessionState, SessionStateBuilder} ;
3131 /// # use datafusion::prelude::SessionConfig;
32- /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder };
32+ /// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext };
3333 ///
3434 /// extensions_options! {
3535 /// pub struct CustomExtension {
@@ -52,11 +52,13 @@ pub trait ConfigExtensionExt {
5252 /// struct MyCustomSessionBuilder;
5353 ///
5454 /// #[async_trait]
55- /// impl SessionBuilder for MyCustomSessionBuilder {
56- /// async fn session_state(&self, mut state: SessionState) -> Result<SessionState, DataFusionError> {
55+ /// impl DistributedSessionBuilder for MyCustomSessionBuilder {
56+ /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result<SessionState, DataFusionError> {
57+ /// let mut state = SessionStateBuilder::new().build();
58+ ///
5759 /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will
5860 /// // know how to deserialize the CustomExtension from the gRPC metadata.
59- /// state.retrieve_distributed_option_extension::<CustomExtension>()?;
61+ /// state.retrieve_distributed_option_extension::<CustomExtension>(&ctx.headers )?;
6062 /// Ok(state)
6163 /// }
6264 /// }
@@ -76,9 +78,9 @@ pub trait ConfigExtensionExt {
7678 /// # use async_trait::async_trait;
7779 /// # use datafusion::common::{extensions_options, DataFusionError};
7880 /// # use datafusion::config::ConfigExtension;
79- /// # use datafusion::execution::SessionState;
81+ /// # use datafusion::execution::{ SessionState, SessionStateBuilder} ;
8082 /// # use datafusion::prelude::SessionConfig;
81- /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder };
83+ /// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext };
8284 ///
8385 /// extensions_options! {
8486 /// pub struct CustomExtension {
@@ -101,17 +103,19 @@ pub trait ConfigExtensionExt {
101103 /// struct MyCustomSessionBuilder;
102104 ///
103105 /// #[async_trait]
104- /// impl SessionBuilder for MyCustomSessionBuilder {
105- /// async fn session_state(&self, mut state: SessionState) -> Result<SessionState, DataFusionError> {
106+ /// impl DistributedSessionBuilder for MyCustomSessionBuilder {
107+ /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result<SessionState, DataFusionError> {
108+ /// let mut state = SessionStateBuilder::new().build();
106109 /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will
107110 /// // know how to deserialize the CustomExtension from the gRPC metadata.
108- /// state.retrieve_distributed_option_extension::<CustomExtension>()?;
111+ /// state.retrieve_distributed_option_extension::<CustomExtension>(&ctx.headers )?;
109112 /// Ok(state)
110113 /// }
111114 /// }
112115 /// ```
113116 fn retrieve_distributed_option_extension < T : ConfigExtension + Default > (
114117 & mut self ,
118+ headers : & HeaderMap ,
115119 ) -> Result < ( ) , DataFusionError > ;
116120}
117121
@@ -153,20 +157,17 @@ impl ConfigExtensionExt for SessionConfig {
153157
154158 fn retrieve_distributed_option_extension < T : ConfigExtension + Default > (
155159 & mut self ,
160+ headers : & HeaderMap ,
156161 ) -> Result < ( ) , DataFusionError > {
157- let Some ( flight_metadata) = self . get_extension :: < ContextGrpcMetadata > ( ) else {
158- return Ok ( ( ) ) ;
159- } ;
160-
161162 let mut result = T :: default ( ) ;
162163 let mut found_some = false ;
163- for ( k, v) in flight_metadata . 0 . iter ( ) {
164+ for ( k, v) in headers . iter ( ) {
164165 let key = k. as_str ( ) . trim_start_matches ( FLIGHT_METADATA_CONFIG_PREFIX ) ;
165166 let prefix = format ! ( "{}." , T :: PREFIX ) ;
166167 if key. starts_with ( & prefix) {
167168 found_some = true ;
168169 result. set (
169- & key. trim_start_matches ( & prefix) ,
170+ key. trim_start_matches ( & prefix) ,
170171 v. to_str ( ) . map_err ( |err| {
171172 internal_datafusion_err ! ( "Cannot parse header value: {err}" )
172173 } ) ?,
@@ -185,7 +186,7 @@ impl ConfigExtensionExt for SessionStateBuilder {
185186 delegate ! {
186187 to self . config( ) . get_or_insert_default( ) {
187188 fn add_distributed_option_extension<T : ConfigExtension + Default >( & mut self , t: T ) -> Result <( ) , DataFusionError >;
188- fn retrieve_distributed_option_extension<T : ConfigExtension + Default >( & mut self ) -> Result <( ) , DataFusionError >;
189+ fn retrieve_distributed_option_extension<T : ConfigExtension + Default >( & mut self , h : & HeaderMap ) -> Result <( ) , DataFusionError >;
189190 }
190191 }
191192}
@@ -194,7 +195,7 @@ impl ConfigExtensionExt for SessionState {
194195 delegate ! {
195196 to self . config_mut( ) {
196197 fn add_distributed_option_extension<T : ConfigExtension + Default >( & mut self , t: T ) -> Result <( ) , DataFusionError >;
197- fn retrieve_distributed_option_extension<T : ConfigExtension + Default >( & mut self ) -> Result <( ) , DataFusionError >;
198+ fn retrieve_distributed_option_extension<T : ConfigExtension + Default >( & mut self , h : & HeaderMap ) -> Result <( ) , DataFusionError >;
198199 }
199200 }
200201}
@@ -203,7 +204,7 @@ impl ConfigExtensionExt for SessionContext {
203204 delegate ! {
204205 to self . state_ref( ) . write( ) . config_mut( ) {
205206 fn add_distributed_option_extension<T : ConfigExtension + Default >( & mut self , t: T ) -> Result <( ) , DataFusionError >;
206- fn retrieve_distributed_option_extension<T : ConfigExtension + Default >( & mut self ) -> Result <( ) , DataFusionError >;
207+ fn retrieve_distributed_option_extension<T : ConfigExtension + Default >( & mut self , h : & HeaderMap ) -> Result <( ) , DataFusionError >;
207208 }
208209 }
209210}
@@ -212,17 +213,6 @@ impl ConfigExtensionExt for SessionContext {
212213pub ( crate ) struct ContextGrpcMetadata ( pub HeaderMap ) ;
213214
214215impl ContextGrpcMetadata {
215- pub ( crate ) fn from_headers ( metadata : HeaderMap ) -> Self {
216- let mut new = HeaderMap :: new ( ) ;
217- for ( k, v) in metadata. into_iter ( ) {
218- let Some ( k) = k else { continue } ;
219- if k. as_str ( ) . starts_with ( FLIGHT_METADATA_CONFIG_PREFIX ) {
220- new. insert ( k, v) ;
221- }
222- }
223- Self ( new)
224- }
225-
226216 fn merge ( mut self , other : Self ) -> Self {
227217 for ( k, v) in other. 0 . into_iter ( ) {
228218 let Some ( k) = k else { continue } ;
@@ -246,16 +236,16 @@ mod tests {
246236 fn test_propagation ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
247237 let mut config = SessionConfig :: new ( ) ;
248238
249- let mut opt = CustomExtension :: default ( ) ;
250- opt. foo = "foo" . to_string ( ) ;
251- opt. bar = 1 ;
252- opt. baz = true ;
239+ let opt = CustomExtension {
240+ foo : "" . to_string ( ) ,
241+ bar : 0 ,
242+ baz : false ,
243+ } ;
253244
254245 config. add_distributed_option_extension ( opt) ?;
255-
246+ let metadata = config . get_extension :: < ContextGrpcMetadata > ( ) . unwrap ( ) ;
256247 let mut new_config = SessionConfig :: new ( ) ;
257- new_config. set_extension ( config. get_extension :: < ContextGrpcMetadata > ( ) . unwrap ( ) ) ;
258- new_config. retrieve_distributed_option_extension :: < CustomExtension > ( ) ?;
248+ new_config. retrieve_distributed_option_extension :: < CustomExtension > ( & metadata. 0 ) ?;
259249
260250 let opt = get_ext :: < CustomExtension > ( & config) ;
261251 let new_opt = get_ext :: < CustomExtension > ( & new_config) ;
@@ -294,12 +284,16 @@ mod tests {
294284 fn test_new_extension_overwrites_previous ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
295285 let mut config = SessionConfig :: new ( ) ;
296286
297- let mut opt1 = CustomExtension :: default ( ) ;
298- opt1. foo = "first" . to_string ( ) ;
287+ let opt1 = CustomExtension {
288+ foo : "first" . to_string ( ) ,
289+ ..Default :: default ( )
290+ } ;
299291 config. add_distributed_option_extension ( opt1) ?;
300292
301- let mut opt2 = CustomExtension :: default ( ) ;
302- opt2. bar = 42 ;
293+ let opt2 = CustomExtension {
294+ bar : 42 ,
295+ ..Default :: default ( )
296+ } ;
303297 config. add_distributed_option_extension ( opt2) ?;
304298
305299 let flight_metadata = config. get_extension :: < ContextGrpcMetadata > ( ) . unwrap ( ) ;
@@ -317,7 +311,7 @@ mod tests {
317311 fn test_propagate_no_metadata ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
318312 let mut config = SessionConfig :: new ( ) ;
319313
320- config. retrieve_distributed_option_extension :: < CustomExtension > ( ) ?;
314+ config. retrieve_distributed_option_extension :: < CustomExtension > ( & Default :: default ( ) ) ?;
321315
322316 let extension = config. options ( ) . extensions . get :: < CustomExtension > ( ) ;
323317 assert ! ( extension. is_none( ) ) ;
@@ -330,13 +324,11 @@ mod tests {
330324 let mut config = SessionConfig :: new ( ) ;
331325 let mut header_map = HeaderMap :: new ( ) ;
332326 header_map. insert (
333- HeaderName :: from_str ( "x-datafusion-distributed-other.setting" ) . unwrap ( ) ,
327+ HeaderName :: from_str ( "x-datafusion-distributed-config- other.setting" ) . unwrap ( ) ,
334328 HeaderValue :: from_str ( "value" ) . unwrap ( ) ,
335329 ) ;
336330
337- let flight_metadata = ContextGrpcMetadata :: from_headers ( header_map) ;
338- config. set_extension ( std:: sync:: Arc :: new ( flight_metadata) ) ;
339- config. retrieve_distributed_option_extension :: < CustomExtension > ( ) ?;
331+ config. retrieve_distributed_option_extension :: < CustomExtension > ( & header_map) ?;
340332
341333 let extension = config. options ( ) . extensions . get :: < CustomExtension > ( ) ;
342334 assert ! ( extension. is_none( ) ) ;
@@ -348,13 +340,17 @@ mod tests {
348340 fn test_multiple_extensions_different_prefixes ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
349341 let mut config = SessionConfig :: new ( ) ;
350342
351- let mut custom_opt = CustomExtension :: default ( ) ;
352- custom_opt. foo = "custom_value" . to_string ( ) ;
353- custom_opt. bar = 123 ;
343+ let custom_opt = CustomExtension {
344+ foo : "custom_value" . to_string ( ) ,
345+ bar : 123 ,
346+ ..Default :: default ( )
347+ } ;
354348
355- let mut another_opt = AnotherExtension :: default ( ) ;
356- another_opt. setting1 = "other" . to_string ( ) ;
357- another_opt. setting2 = 456 ;
349+ let another_opt = AnotherExtension {
350+ setting1 : "other" . to_string ( ) ,
351+ setting2 : 456 ,
352+ ..Default :: default ( )
353+ } ;
358354
359355 config. add_distributed_option_extension ( custom_opt) ?;
360356 config. add_distributed_option_extension ( another_opt) ?;
@@ -384,9 +380,8 @@ mod tests {
384380 ) ;
385381
386382 let mut new_config = SessionConfig :: new ( ) ;
387- new_config. set_extension ( flight_metadata) ;
388- new_config. retrieve_distributed_option_extension :: < CustomExtension > ( ) ?;
389- new_config. retrieve_distributed_option_extension :: < AnotherExtension > ( ) ?;
383+ new_config. retrieve_distributed_option_extension :: < CustomExtension > ( metadata) ?;
384+ new_config. retrieve_distributed_option_extension :: < AnotherExtension > ( metadata) ?;
390385
391386 let propagated_custom = get_ext :: < CustomExtension > ( & new_config) ;
392387 let propagated_another = get_ext :: < AnotherExtension > ( & new_config) ;
0 commit comments