@@ -26,25 +26,26 @@ mod tests {
2626 use std:: fmt:: Formatter ;
2727 use std:: sync:: Arc ;
2828
29+ async fn build_state (
30+ ctx : DistributedSessionBuilderContext ,
31+ ) -> Result < SessionState , DataFusionError > {
32+ Ok ( SessionStateBuilder :: new ( )
33+ . with_runtime_env ( ctx. runtime_env )
34+ . with_default_features ( )
35+ . with_distributed_option_extension_from_headers :: < CustomExtension > ( & ctx. headers ) ?
36+ . with_distributed_user_codec ( CustomConfigExtensionRequiredExecCodec )
37+ . build ( ) )
38+ }
39+
2940 #[ tokio:: test]
3041 async fn custom_config_extension ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
31- async fn build_state (
32- ctx : DistributedSessionBuilderContext ,
33- ) -> Result < SessionState , DataFusionError > {
34- Ok ( SessionStateBuilder :: new ( )
35- . with_runtime_env ( ctx. runtime_env )
36- . with_default_features ( )
37- . with_distributed_option_extension_from_headers :: < CustomExtension > ( & ctx. headers ) ?
38- . with_distributed_user_codec ( CustomConfigExtensionRequiredExecCodec )
39- . build ( ) )
40- }
41-
4242 let ( mut ctx, _guard) = start_localhost_context ( 3 , build_state) . await ;
4343 ctx = SessionStateBuilder :: from ( ctx. state ( ) )
4444 . with_distributed_option_extension ( CustomExtension {
4545 foo : "foo" . to_string ( ) ,
4646 bar : 1 ,
4747 baz : true ,
48+ throw_err : false ,
4849 } ) ?
4950 . build ( )
5051 . into ( ) ;
@@ -69,11 +70,54 @@ mod tests {
6970 Ok ( ( ) )
7071 }
7172
73+ #[ tokio:: test]
74+ async fn custom_config_extension_runtime_change ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
75+ let ( mut ctx, _guard) = start_localhost_context ( 3 , build_state) . await ;
76+ ctx = SessionStateBuilder :: from ( ctx. state ( ) )
77+ . with_distributed_option_extension ( CustomExtension {
78+ throw_err : true ,
79+ ..Default :: default ( )
80+ } ) ?
81+ . build ( )
82+ . into ( ) ;
83+
84+ let mut plan: Arc < dyn ExecutionPlan > = Arc :: new ( CustomConfigExtensionRequiredExec :: new ( ) ) ;
85+
86+ for size in [ 1 , 2 , 3 ] {
87+ plan = Arc :: new ( NetworkShuffleExec :: try_new (
88+ Arc :: new ( RepartitionExec :: try_new (
89+ plan,
90+ Partitioning :: Hash ( vec ! [ ] , 10 ) ,
91+ ) ?) ,
92+ size,
93+ ) ?) ;
94+ }
95+
96+ let plan = distribute_plan ( plan) ?. unwrap ( ) ;
97+
98+ // If the value is modified after setting it as a distributed option extension, it should
99+ // propagate the correct headers.
100+ ctx. state_ref ( )
101+ . write ( )
102+ . config_mut ( )
103+ . options_mut ( )
104+ . extensions
105+ . get_mut :: < CustomExtension > ( )
106+ . unwrap ( )
107+ . throw_err = false ;
108+ let stream = execute_stream ( plan, ctx. task_ctx ( ) ) ?;
109+ // It should not fail.
110+ stream. try_collect :: < Vec < _ > > ( ) . await ?;
111+
112+ Ok ( ( ) )
113+ }
114+
72115 extensions_options ! {
73116 pub struct CustomExtension {
74117 pub foo: String , default = "" . to_string( )
75118 pub bar: usize , default = 0
76119 pub baz: bool , default = false
120+ pub throw_err: bool , default = true
77121 }
78122 }
79123
@@ -135,14 +179,16 @@ mod tests {
135179 _: usize ,
136180 ctx : Arc < TaskContext > ,
137181 ) -> datafusion:: common:: Result < SendableRecordBatchStream > {
138- if ctx
182+ let Some ( ext ) = ctx
139183 . session_config ( )
140184 . options ( )
141185 . extensions
142186 . get :: < CustomExtension > ( )
143- . is_none ( )
144- {
187+ else {
145188 return internal_err ! ( "CustomExtension not found in context" ) ;
189+ } ;
190+ if ext. throw_err {
191+ return internal_err ! ( "CustomExtension requested an error to be thrown" ) ;
146192 }
147193 Ok ( Box :: pin ( RecordBatchStreamAdapter :: new (
148194 self . schema ( ) ,
0 commit comments