Skip to content

Commit d4b0c35

Browse files
committed
Add failing test
1 parent f6dfaa6 commit d4b0c35

File tree

1 file changed

+60
-14
lines changed

1 file changed

+60
-14
lines changed

tests/custom_config_extension.rs

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,26 @@ mod tests {
2626
use std::fmt::Formatter;
2727
use std::sync::Arc;
2828

29+
async fn build_state(
30+
ctx: DistributedSessionBuilderContext,
31+
) -> Result<SessionState, DataFusionError> {
32+
Ok(SessionStateBuilder::new()
33+
.with_runtime_env(ctx.runtime_env)
34+
.with_default_features()
35+
.with_distributed_option_extension_from_headers::<CustomExtension>(&ctx.headers)?
36+
.with_distributed_user_codec(CustomConfigExtensionRequiredExecCodec)
37+
.build())
38+
}
39+
2940
#[tokio::test]
3041
async fn custom_config_extension() -> Result<(), Box<dyn std::error::Error>> {
31-
async fn build_state(
32-
ctx: DistributedSessionBuilderContext,
33-
) -> Result<SessionState, DataFusionError> {
34-
Ok(SessionStateBuilder::new()
35-
.with_runtime_env(ctx.runtime_env)
36-
.with_default_features()
37-
.with_distributed_option_extension_from_headers::<CustomExtension>(&ctx.headers)?
38-
.with_distributed_user_codec(CustomConfigExtensionRequiredExecCodec)
39-
.build())
40-
}
41-
4242
let (mut ctx, _guard) = start_localhost_context(3, build_state).await;
4343
ctx = SessionStateBuilder::from(ctx.state())
4444
.with_distributed_option_extension(CustomExtension {
4545
foo: "foo".to_string(),
4646
bar: 1,
4747
baz: true,
48+
throw_err: false,
4849
})?
4950
.build()
5051
.into();
@@ -69,11 +70,54 @@ mod tests {
6970
Ok(())
7071
}
7172

73+
#[tokio::test]
74+
async fn custom_config_extension_runtime_change() -> Result<(), Box<dyn std::error::Error>> {
75+
let (mut ctx, _guard) = start_localhost_context(3, build_state).await;
76+
ctx = SessionStateBuilder::from(ctx.state())
77+
.with_distributed_option_extension(CustomExtension {
78+
throw_err: true,
79+
..Default::default()
80+
})?
81+
.build()
82+
.into();
83+
84+
let mut plan: Arc<dyn ExecutionPlan> = Arc::new(CustomConfigExtensionRequiredExec::new());
85+
86+
for size in [1, 2, 3] {
87+
plan = Arc::new(NetworkShuffleExec::try_new(
88+
Arc::new(RepartitionExec::try_new(
89+
plan,
90+
Partitioning::Hash(vec![], 10),
91+
)?),
92+
size,
93+
)?);
94+
}
95+
96+
let plan = distribute_plan(plan)?.unwrap();
97+
98+
// If the value is modified after setting it as a distributed option extension, it should
99+
// propagate the correct headers.
100+
ctx.state_ref()
101+
.write()
102+
.config_mut()
103+
.options_mut()
104+
.extensions
105+
.get_mut::<CustomExtension>()
106+
.unwrap()
107+
.throw_err = false;
108+
let stream = execute_stream(plan, ctx.task_ctx())?;
109+
// It should not fail.
110+
stream.try_collect::<Vec<_>>().await?;
111+
112+
Ok(())
113+
}
114+
72115
extensions_options! {
73116
pub struct CustomExtension {
74117
pub foo: String, default = "".to_string()
75118
pub bar: usize, default = 0
76119
pub baz: bool, default = false
120+
pub throw_err: bool, default = true
77121
}
78122
}
79123

@@ -135,14 +179,16 @@ mod tests {
135179
_: usize,
136180
ctx: Arc<TaskContext>,
137181
) -> datafusion::common::Result<SendableRecordBatchStream> {
138-
if ctx
182+
let Some(ext) = ctx
139183
.session_config()
140184
.options()
141185
.extensions
142186
.get::<CustomExtension>()
143-
.is_none()
144-
{
187+
else {
145188
return internal_err!("CustomExtension not found in context");
189+
};
190+
if ext.throw_err {
191+
return internal_err!("CustomExtension requested an error to be thrown");
146192
}
147193
Ok(Box::pin(RecordBatchStreamAdapter::new(
148194
self.schema(),

0 commit comments

Comments
 (0)