diff --git a/tests/custom_config_extension.rs b/tests/custom_config_extension.rs index 2ffa0c9..5733f77 100644 --- a/tests/custom_config_extension.rs +++ b/tests/custom_config_extension.rs @@ -26,25 +26,26 @@ mod tests { use std::fmt::Formatter; use std::sync::Arc; + async fn build_state( + ctx: DistributedSessionBuilderContext, + ) -> Result { + Ok(SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env) + .with_default_features() + .with_distributed_option_extension_from_headers::(&ctx.headers)? + .with_distributed_user_codec(CustomConfigExtensionRequiredExecCodec) + .build()) + } + #[tokio::test] async fn custom_config_extension() -> Result<(), Box> { - async fn build_state( - ctx: DistributedSessionBuilderContext, - ) -> Result { - Ok(SessionStateBuilder::new() - .with_runtime_env(ctx.runtime_env) - .with_default_features() - .with_distributed_option_extension_from_headers::(&ctx.headers)? - .with_distributed_user_codec(CustomConfigExtensionRequiredExecCodec) - .build()) - } - let (mut ctx, _guard) = start_localhost_context(3, build_state).await; ctx = SessionStateBuilder::from(ctx.state()) .with_distributed_option_extension(CustomExtension { foo: "foo".to_string(), bar: 1, baz: true, + throw_err: false, })? .build() .into(); @@ -69,11 +70,61 @@ mod tests { Ok(()) } + #[tokio::test] + // TODO: the solution to this test failure is to, rather than dumping the config extension + // fields into headers immediately when calling `with_distributed_option_extension()`, to instead + // register the ConfigExtension::PREFIX as something that we should lazily capture and send + // in the headers of every network request. In order to do that, first this PR upstream + // https://github.com/apache/datafusion/pull/18887 needs to be shipped. It will be available + // in DataFusion 52.0.0. + #[ignore] + async fn custom_config_extension_runtime_change() -> Result<(), Box> { + let (mut ctx, _guard) = start_localhost_context(3, build_state).await; + ctx = SessionStateBuilder::from(ctx.state()) + .with_distributed_option_extension(CustomExtension { + throw_err: true, + ..Default::default() + })? + .build() + .into(); + + let mut plan: Arc = Arc::new(CustomConfigExtensionRequiredExec::new()); + + for size in [1, 2, 3] { + plan = Arc::new(NetworkShuffleExec::try_new( + Arc::new(RepartitionExec::try_new( + plan, + Partitioning::Hash(vec![], 10), + )?), + size, + )?); + } + + let plan = distribute_plan(plan)?.unwrap(); + + // If the value is modified after setting it as a distributed option extension, it should + // propagate the correct headers. + ctx.state_ref() + .write() + .config_mut() + .options_mut() + .extensions + .get_mut::() + .unwrap() + .throw_err = false; + let stream = execute_stream(plan, ctx.task_ctx())?; + // It should not fail. + stream.try_collect::>().await?; + + Ok(()) + } + extensions_options! { pub struct CustomExtension { pub foo: String, default = "".to_string() pub bar: usize, default = 0 pub baz: bool, default = false + pub throw_err: bool, default = true } } @@ -135,14 +186,16 @@ mod tests { _: usize, ctx: Arc, ) -> datafusion::common::Result { - if ctx + let Some(ext) = ctx .session_config() .options() .extensions .get::() - .is_none() - { + else { return internal_err!("CustomExtension not found in context"); + }; + if ext.throw_err { + return internal_err!("CustomExtension requested an error to be thrown"); } Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(),