Skip to content

Commit 99ab70b

Browse files
authored
Add failing option extension propagation test (#246)
* Add failing test * Add comment
1 parent ce5218b commit 99ab70b

File tree

1 file changed

+67
-14
lines changed

1 file changed

+67
-14
lines changed

tests/custom_config_extension.rs

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,26 @@ mod tests {
2626
use std::fmt::Formatter;
2727
use std::sync::Arc;
2828

29+
async fn build_state(
30+
ctx: DistributedSessionBuilderContext,
31+
) -> Result<SessionState, DataFusionError> {
32+
Ok(SessionStateBuilder::new()
33+
.with_runtime_env(ctx.runtime_env)
34+
.with_default_features()
35+
.with_distributed_option_extension_from_headers::<CustomExtension>(&ctx.headers)?
36+
.with_distributed_user_codec(CustomConfigExtensionRequiredExecCodec)
37+
.build())
38+
}
39+
2940
#[tokio::test]
3041
async fn custom_config_extension() -> Result<(), Box<dyn std::error::Error>> {
31-
async fn build_state(
32-
ctx: DistributedSessionBuilderContext,
33-
) -> Result<SessionState, DataFusionError> {
34-
Ok(SessionStateBuilder::new()
35-
.with_runtime_env(ctx.runtime_env)
36-
.with_default_features()
37-
.with_distributed_option_extension_from_headers::<CustomExtension>(&ctx.headers)?
38-
.with_distributed_user_codec(CustomConfigExtensionRequiredExecCodec)
39-
.build())
40-
}
41-
4242
let (mut ctx, _guard) = start_localhost_context(3, build_state).await;
4343
ctx = SessionStateBuilder::from(ctx.state())
4444
.with_distributed_option_extension(CustomExtension {
4545
foo: "foo".to_string(),
4646
bar: 1,
4747
baz: true,
48+
throw_err: false,
4849
})?
4950
.build()
5051
.into();
@@ -69,11 +70,61 @@ mod tests {
6970
Ok(())
7071
}
7172

73+
#[tokio::test]
74+
// TODO: the solution to this test failure is to, rather than dumping the config extension
75+
// fields into headers immediately when calling `with_distributed_option_extension()`, to instead
76+
// register the ConfigExtension::PREFIX as something that we should lazily capture and send
77+
// in the headers of every network request. In order to do that, first this PR upstream
78+
// https://github.com/apache/datafusion/pull/18887 needs to be shipped. It will be available
79+
// in DataFusion 52.0.0.
80+
#[ignore]
81+
async fn custom_config_extension_runtime_change() -> Result<(), Box<dyn std::error::Error>> {
82+
let (mut ctx, _guard) = start_localhost_context(3, build_state).await;
83+
ctx = SessionStateBuilder::from(ctx.state())
84+
.with_distributed_option_extension(CustomExtension {
85+
throw_err: true,
86+
..Default::default()
87+
})?
88+
.build()
89+
.into();
90+
91+
let mut plan: Arc<dyn ExecutionPlan> = Arc::new(CustomConfigExtensionRequiredExec::new());
92+
93+
for size in [1, 2, 3] {
94+
plan = Arc::new(NetworkShuffleExec::try_new(
95+
Arc::new(RepartitionExec::try_new(
96+
plan,
97+
Partitioning::Hash(vec![], 10),
98+
)?),
99+
size,
100+
)?);
101+
}
102+
103+
let plan = distribute_plan(plan)?.unwrap();
104+
105+
// If the value is modified after setting it as a distributed option extension, it should
106+
// propagate the correct headers.
107+
ctx.state_ref()
108+
.write()
109+
.config_mut()
110+
.options_mut()
111+
.extensions
112+
.get_mut::<CustomExtension>()
113+
.unwrap()
114+
.throw_err = false;
115+
let stream = execute_stream(plan, ctx.task_ctx())?;
116+
// It should not fail.
117+
stream.try_collect::<Vec<_>>().await?;
118+
119+
Ok(())
120+
}
121+
72122
extensions_options! {
73123
pub struct CustomExtension {
74124
pub foo: String, default = "".to_string()
75125
pub bar: usize, default = 0
76126
pub baz: bool, default = false
127+
pub throw_err: bool, default = true
77128
}
78129
}
79130

@@ -135,14 +186,16 @@ mod tests {
135186
_: usize,
136187
ctx: Arc<TaskContext>,
137188
) -> datafusion::common::Result<SendableRecordBatchStream> {
138-
if ctx
189+
let Some(ext) = ctx
139190
.session_config()
140191
.options()
141192
.extensions
142193
.get::<CustomExtension>()
143-
.is_none()
144-
{
194+
else {
145195
return internal_err!("CustomExtension not found in context");
196+
};
197+
if ext.throw_err {
198+
return internal_err!("CustomExtension requested an error to be thrown");
146199
}
147200
Ok(Box::pin(RecordBatchStreamAdapter::new(
148201
self.schema(),

0 commit comments

Comments
 (0)