@@ -27,9 +27,9 @@ pub trait ConfigExtensionExt {
2727 /// # use async_trait::async_trait;
2828 /// # use datafusion::common::{extensions_options, DataFusionError};
2929 /// # use datafusion::config::ConfigExtension;
30- /// # use datafusion::execution::SessionState;
30+ /// # use datafusion::execution::{ SessionState, SessionStateBuilder} ;
3131 /// # use datafusion::prelude::SessionConfig;
32- /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder };
32+ /// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext };
3333 ///
3434 /// extensions_options! {
3535 /// pub struct CustomExtension {
@@ -52,11 +52,13 @@ pub trait ConfigExtensionExt {
5252 /// struct MyCustomSessionBuilder;
5353 ///
5454 /// #[async_trait]
55- /// impl SessionBuilder for MyCustomSessionBuilder {
56- /// async fn session_state(&self, mut state: SessionState) -> Result<SessionState, DataFusionError> {
55+ /// impl DistributedSessionBuilder for MyCustomSessionBuilder {
56+ /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result<SessionState, DataFusionError> {
57+ /// let mut state = SessionStateBuilder::new().build();
58+ ///
5759 /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will
5860 /// // know how to deserialize the CustomExtension from the gRPC metadata.
59- /// state.retrieve_distributed_option_extension::<CustomExtension>()?;
61+ /// state.retrieve_distributed_option_extension::<CustomExtension>(&ctx.headers )?;
6062 /// Ok(state)
6163 /// }
6264 /// }
@@ -76,9 +78,9 @@ pub trait ConfigExtensionExt {
7678 /// # use async_trait::async_trait;
7779 /// # use datafusion::common::{extensions_options, DataFusionError};
7880 /// # use datafusion::config::ConfigExtension;
79- /// # use datafusion::execution::SessionState;
81+ /// # use datafusion::execution::{ SessionState, SessionStateBuilder} ;
8082 /// # use datafusion::prelude::SessionConfig;
81- /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder };
83+ /// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext };
8284 ///
8385 /// extensions_options! {
8486 /// pub struct CustomExtension {
@@ -101,17 +103,19 @@ pub trait ConfigExtensionExt {
101103 /// struct MyCustomSessionBuilder;
102104 ///
103105 /// #[async_trait]
104- /// impl SessionBuilder for MyCustomSessionBuilder {
105- /// async fn session_state(&self, mut state: SessionState) -> Result<SessionState, DataFusionError> {
106+ /// impl DistributedSessionBuilder for MyCustomSessionBuilder {
107+ /// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result<SessionState, DataFusionError> {
108+ /// let mut state = SessionStateBuilder::new().build();
106109 /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will
107110 /// // know how to deserialize the CustomExtension from the gRPC metadata.
108- /// state.retrieve_distributed_option_extension::<CustomExtension>()?;
111+ /// state.retrieve_distributed_option_extension::<CustomExtension>(&ctx.headers )?;
109112 /// Ok(state)
110113 /// }
111114 /// }
112115 /// ```
113116 fn retrieve_distributed_option_extension < T : ConfigExtension + Default > (
114117 & mut self ,
118+ headers : & HeaderMap ,
115119 ) -> Result < ( ) , DataFusionError > ;
116120}
117121
@@ -153,14 +157,11 @@ impl ConfigExtensionExt for SessionConfig {
153157
154158 fn retrieve_distributed_option_extension < T : ConfigExtension + Default > (
155159 & mut self ,
160+ headers : & HeaderMap ,
156161 ) -> Result < ( ) , DataFusionError > {
157- let Some ( flight_metadata) = self . get_extension :: < ContextGrpcMetadata > ( ) else {
158- return Ok ( ( ) ) ;
159- } ;
160-
161162 let mut result = T :: default ( ) ;
162163 let mut found_some = false ;
163- for ( k, v) in flight_metadata . 0 . iter ( ) {
164+ for ( k, v) in headers . iter ( ) {
164165 let key = k. as_str ( ) . trim_start_matches ( FLIGHT_METADATA_CONFIG_PREFIX ) ;
165166 let prefix = format ! ( "{}." , T :: PREFIX ) ;
166167 if key. starts_with ( & prefix) {
@@ -185,7 +186,7 @@ impl ConfigExtensionExt for SessionStateBuilder {
185186 delegate ! {
186187 to self . config( ) . get_or_insert_default( ) {
187188 fn add_distributed_option_extension<T : ConfigExtension + Default >( & mut self , t: T ) -> Result <( ) , DataFusionError >;
188- fn retrieve_distributed_option_extension<T : ConfigExtension + Default >( & mut self ) -> Result <( ) , DataFusionError >;
189+ fn retrieve_distributed_option_extension<T : ConfigExtension + Default >( & mut self , h : & HeaderMap ) -> Result <( ) , DataFusionError >;
189190 }
190191 }
191192}
@@ -194,7 +195,7 @@ impl ConfigExtensionExt for SessionState {
194195 delegate ! {
195196 to self . config_mut( ) {
196197 fn add_distributed_option_extension<T : ConfigExtension + Default >( & mut self , t: T ) -> Result <( ) , DataFusionError >;
197- fn retrieve_distributed_option_extension<T : ConfigExtension + Default >( & mut self ) -> Result <( ) , DataFusionError >;
198+ fn retrieve_distributed_option_extension<T : ConfigExtension + Default >( & mut self , h : & HeaderMap ) -> Result <( ) , DataFusionError >;
198199 }
199200 }
200201}
@@ -203,7 +204,7 @@ impl ConfigExtensionExt for SessionContext {
203204 delegate ! {
204205 to self . state_ref( ) . write( ) . config_mut( ) {
205206 fn add_distributed_option_extension<T : ConfigExtension + Default >( & mut self , t: T ) -> Result <( ) , DataFusionError >;
206- fn retrieve_distributed_option_extension<T : ConfigExtension + Default >( & mut self ) -> Result <( ) , DataFusionError >;
207+ fn retrieve_distributed_option_extension<T : ConfigExtension + Default >( & mut self , h : & HeaderMap ) -> Result <( ) , DataFusionError >;
207208 }
208209 }
209210}
@@ -212,17 +213,6 @@ impl ConfigExtensionExt for SessionContext {
212213pub ( crate ) struct ContextGrpcMetadata ( pub HeaderMap ) ;
213214
214215impl ContextGrpcMetadata {
215- pub ( crate ) fn from_headers ( metadata : HeaderMap ) -> Self {
216- let mut new = HeaderMap :: new ( ) ;
217- for ( k, v) in metadata. into_iter ( ) {
218- let Some ( k) = k else { continue } ;
219- if k. as_str ( ) . starts_with ( FLIGHT_METADATA_CONFIG_PREFIX ) {
220- new. insert ( k, v) ;
221- }
222- }
223- Self ( new)
224- }
225-
226216 fn merge ( mut self , other : Self ) -> Self {
227217 for ( k, v) in other. 0 . into_iter ( ) {
228218 let Some ( k) = k else { continue } ;
@@ -252,10 +242,9 @@ mod tests {
252242 opt. baz = true ;
253243
254244 config. add_distributed_option_extension ( opt) ?;
255-
245+ let metadata = config . get_extension :: < ContextGrpcMetadata > ( ) . unwrap ( ) ;
256246 let mut new_config = SessionConfig :: new ( ) ;
257- new_config. set_extension ( config. get_extension :: < ContextGrpcMetadata > ( ) . unwrap ( ) ) ;
258- new_config. retrieve_distributed_option_extension :: < CustomExtension > ( ) ?;
247+ new_config. retrieve_distributed_option_extension :: < CustomExtension > ( & metadata. 0 ) ?;
259248
260249 let opt = get_ext :: < CustomExtension > ( & config) ;
261250 let new_opt = get_ext :: < CustomExtension > ( & new_config) ;
@@ -317,7 +306,7 @@ mod tests {
317306 fn test_propagate_no_metadata ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
318307 let mut config = SessionConfig :: new ( ) ;
319308
320- config. retrieve_distributed_option_extension :: < CustomExtension > ( ) ?;
309+ config. retrieve_distributed_option_extension :: < CustomExtension > ( & Default :: default ( ) ) ?;
321310
322311 let extension = config. options ( ) . extensions . get :: < CustomExtension > ( ) ;
323312 assert ! ( extension. is_none( ) ) ;
@@ -330,13 +319,11 @@ mod tests {
330319 let mut config = SessionConfig :: new ( ) ;
331320 let mut header_map = HeaderMap :: new ( ) ;
332321 header_map. insert (
333- HeaderName :: from_str ( "x-datafusion-distributed-other.setting" ) . unwrap ( ) ,
322+ HeaderName :: from_str ( "x-datafusion-distributed-config- other.setting" ) . unwrap ( ) ,
334323 HeaderValue :: from_str ( "value" ) . unwrap ( ) ,
335324 ) ;
336325
337- let flight_metadata = ContextGrpcMetadata :: from_headers ( header_map) ;
338- config. set_extension ( std:: sync:: Arc :: new ( flight_metadata) ) ;
339- config. retrieve_distributed_option_extension :: < CustomExtension > ( ) ?;
326+ config. retrieve_distributed_option_extension :: < CustomExtension > ( & header_map) ?;
340327
341328 let extension = config. options ( ) . extensions . get :: < CustomExtension > ( ) ;
342329 assert ! ( extension. is_none( ) ) ;
@@ -384,9 +371,8 @@ mod tests {
384371 ) ;
385372
386373 let mut new_config = SessionConfig :: new ( ) ;
387- new_config. set_extension ( flight_metadata) ;
388- new_config. retrieve_distributed_option_extension :: < CustomExtension > ( ) ?;
389- new_config. retrieve_distributed_option_extension :: < AnotherExtension > ( ) ?;
374+ new_config. retrieve_distributed_option_extension :: < CustomExtension > ( & metadata) ?;
375+ new_config. retrieve_distributed_option_extension :: < AnotherExtension > ( & metadata) ?;
390376
391377 let propagated_custom = get_ext :: < CustomExtension > ( & new_config) ;
392378 let propagated_another = get_ext :: < AnotherExtension > ( & new_config) ;
0 commit comments