Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 67 additions & 14 deletions tests/custom_config_extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,26 @@ mod tests {
use std::fmt::Formatter;
use std::sync::Arc;

/// Session-state factory handed to `start_localhost_context` by the tests in this module.
///
/// Starts from a fresh `SessionStateBuilder`, installs the runtime environment carried by
/// `ctx`, enables the default features, asks the builder to load a `CustomExtension` from
/// `ctx.headers`, and registers `CustomConfigExtensionRequiredExecCodec` as the distributed
/// user codec.
///
/// # Errors
///
/// Propagates the error returned by `with_distributed_option_extension_from_headers`.
async fn build_state(
    ctx: DistributedSessionBuilderContext,
) -> Result<SessionState, DataFusionError> {
    let builder = SessionStateBuilder::new()
        .with_runtime_env(ctx.runtime_env)
        .with_default_features()
        .with_distributed_option_extension_from_headers::<CustomExtension>(&ctx.headers)?
        .with_distributed_user_codec(CustomConfigExtensionRequiredExecCodec);

    Ok(builder.build())
}

#[tokio::test]
async fn custom_config_extension() -> Result<(), Box<dyn std::error::Error>> {
// Test-local session-state factory: configures the builder with the runtime environment
// from `ctx`, the default features, a `CustomExtension` loaded from `ctx.headers`, and the
// `CustomConfigExtensionRequiredExecCodec` user codec. Any error from loading the extension
// out of the headers is returned to the caller.
async fn build_state(
    ctx: DistributedSessionBuilderContext,
) -> Result<SessionState, DataFusionError> {
    let configured = SessionStateBuilder::new()
        .with_runtime_env(ctx.runtime_env)
        .with_default_features()
        .with_distributed_option_extension_from_headers::<CustomExtension>(&ctx.headers)?
        .with_distributed_user_codec(CustomConfigExtensionRequiredExecCodec);
    let state = configured.build();
    Ok(state)
}

let (mut ctx, _guard) = start_localhost_context(3, build_state).await;
ctx = SessionStateBuilder::from(ctx.state())
.with_distributed_option_extension(CustomExtension {
foo: "foo".to_string(),
bar: 1,
baz: true,
throw_err: false,
})?
.build()
.into();
Expand All @@ -69,11 +70,61 @@ mod tests {
Ok(())
}

#[tokio::test]
// TODO: the solution to this test failure is to, rather than dumping the config extension
// fields into headers immediately when calling `with_distributed_option_extension()`, to instead
// register the ConfigExtension::PREFIX as something that we should lazily capture and send
// in the headers of every network request. In order to do that, first this PR upstream
// https://github.com/apache/datafusion/pull/18887 needs to be shipped. It will be available
// in DataFusion 52.0.0.
#[ignore]
async fn custom_config_extension_runtime_change() -> Result<(), Box<dyn std::error::Error>> {
let (mut ctx, _guard) = start_localhost_context(3, build_state).await;
ctx = SessionStateBuilder::from(ctx.state())
.with_distributed_option_extension(CustomExtension {
throw_err: true,
..Default::default()
})?
.build()
.into();

let mut plan: Arc<dyn ExecutionPlan> = Arc::new(CustomConfigExtensionRequiredExec::new());

for size in [1, 2, 3] {
plan = Arc::new(NetworkShuffleExec::try_new(
Arc::new(RepartitionExec::try_new(
plan,
Partitioning::Hash(vec![], 10),
)?),
size,
)?);
}

let plan = distribute_plan(plan)?.unwrap();

// If the value is modified after setting it as a distributed option extension, it should
// propagate the correct headers.
ctx.state_ref()
.write()
.config_mut()
.options_mut()
.extensions
.get_mut::<CustomExtension>()
.unwrap()
.throw_err = false;
let stream = execute_stream(plan, ctx.task_ctx())?;
// It should not fail.
stream.try_collect::<Vec<_>>().await?;

Ok(())
}

// Custom configuration extension used by the tests in this file, declared through the
// `extensions_options!` macro. Each field is listed with the default value it takes when
// not set explicitly.
extensions_options! {
pub struct CustomExtension {
// Arbitrary string payload; defaults to the empty string.
pub foo: String, default = "".to_string()
// Arbitrary numeric payload; defaults to 0.
pub bar: usize, default = 0
// Arbitrary boolean payload; defaults to false.
pub baz: bool, default = false
// Defaults to true. The `execute` implementation further down in this file returns an
// internal error while this flag is set, so tests must set it to false explicitly.
pub throw_err: bool, default = true
}
}

Expand Down Expand Up @@ -135,14 +186,16 @@ mod tests {
_: usize,
ctx: Arc<TaskContext>,
) -> datafusion::common::Result<SendableRecordBatchStream> {
if ctx
let Some(ext) = ctx
.session_config()
.options()
.extensions
.get::<CustomExtension>()
.is_none()
{
else {
return internal_err!("CustomExtension not found in context");
};
if ext.throw_err {
return internal_err!("CustomExtension requested an error to be thrown");
}
Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
Expand Down