Skip to content

Commit f8bb759

Browse files
authored
Improve SessionBuilder ergonomics and fix clippy errors (#103)
* Add `ConfigExtensionExt`, allowing the propagation of arbitrary [ConfigExtension]s across network boundaries
* Rename propagate_distributed_option_extension to retrieve_distributed_option_extension
* Change x-datafusion-distributed- to x-datafusion-distributed-config-
* Check for full format!("{}.", T::PREFIX) in gRPC keys
* Remove double clone
* Fix tests
* Fix tests
* Rework SessionBuilder
* Fix clippy errors
1 parent 113cc1b commit f8bb759

File tree

16 files changed

+245
-300
lines changed

16 files changed

+245
-300
lines changed

benchmarks/src/tpch/run.rs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,16 @@ use datafusion::datasource::listing::{
3636
};
3737
use datafusion::datasource::{MemTable, TableProvider};
3838
use datafusion::error::{DataFusionError, Result};
39-
use datafusion::execution::SessionStateBuilder;
39+
use datafusion::execution::{SessionState, SessionStateBuilder};
4040
use datafusion::physical_plan::display::DisplayableExecutionPlan;
4141
use datafusion::physical_plan::{collect, displayable};
4242
use datafusion::prelude::*;
4343

4444
use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult};
4545
use datafusion_distributed::test_utils::localhost::start_localhost_context;
46-
use datafusion_distributed::{DistributedPhysicalOptimizerRule, SessionBuilder};
46+
use datafusion_distributed::{
47+
DistributedPhysicalOptimizerRule, DistributedSessionBuilder, DistributedSessionBuilderContext,
48+
};
4749
use log::info;
4850
use structopt::StructOpt;
4951

@@ -110,11 +112,13 @@ pub struct RunOpt {
110112
}
111113

112114
#[async_trait]
113-
impl SessionBuilder for RunOpt {
114-
fn session_state_builder(
115+
impl DistributedSessionBuilder for RunOpt {
116+
async fn build_session_state(
115117
&self,
116-
mut builder: SessionStateBuilder,
117-
) -> Result<SessionStateBuilder, DataFusionError> {
118+
_ctx: DistributedSessionBuilderContext,
119+
) -> Result<SessionState, DataFusionError> {
120+
let mut builder = SessionStateBuilder::new().with_default_features();
121+
118122
let mut config = self
119123
.common
120124
.config()?
@@ -145,17 +149,14 @@ impl SessionBuilder for RunOpt {
145149
builder = builder.with_physical_optimizer_rule(Arc::new(rule));
146150
}
147151

148-
Ok(builder
152+
let state = builder
149153
.with_config(config)
150-
.with_runtime_env(rt_builder.build_arc()?))
151-
}
154+
.with_runtime_env(rt_builder.build_arc()?)
155+
.build();
152156

153-
async fn session_context(
154-
&self,
155-
ctx: SessionContext,
156-
) -> std::result::Result<SessionContext, DataFusionError> {
157+
let ctx = SessionContext::from(state);
157158
self.register_tables(&ctx).await?;
158-
Ok(ctx)
159+
Ok(ctx.state())
159160
}
160161
}
161162

src/common/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1+
#[allow(unused)]
12
pub mod ttl_map;
23
pub mod util;

src/common/ttl_map.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ where
9393
shard.insert(key);
9494
}
9595
BucketOp::Clear => {
96-
let keys_to_delete = std::mem::replace(&mut shard, HashSet::new());
96+
let keys_to_delete = std::mem::take(&mut shard);
9797
for key in keys_to_delete {
9898
data.remove(&key);
9999
}
@@ -252,14 +252,14 @@ where
252252

253253
/// run_gc_loop will continuously clear expired entries from the map, checking every `period`. The
254254
/// function never returns on its own: it loops until the future driving it is dropped or aborted.
255-
async fn run_gc_loop(time: Arc<AtomicU64>, period: Duration, buckets: &Vec<Bucket<K>>) {
255+
async fn run_gc_loop(time: Arc<AtomicU64>, period: Duration, buckets: &[Bucket<K>]) {
256256
loop {
257257
tokio::time::sleep(period).await;
258258
Self::gc(time.clone(), buckets);
259259
}
260260
}
261261

262-
fn gc(time: Arc<AtomicU64>, buckets: &Vec<Bucket<K>>) {
262+
fn gc(time: Arc<AtomicU64>, buckets: &[Bucket<K>]) {
263263
let index = time.load(std::sync::atomic::Ordering::SeqCst) % buckets.len() as u64;
264264
buckets[index as usize].clear();
265265
time.fetch_add(1, std::sync::atomic::Ordering::SeqCst);

src/config_extension_ext.rs

Lines changed: 50 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ pub trait ConfigExtensionExt {
2727
/// # use async_trait::async_trait;
2828
/// # use datafusion::common::{extensions_options, DataFusionError};
2929
/// # use datafusion::config::ConfigExtension;
30-
/// # use datafusion::execution::SessionState;
30+
/// # use datafusion::execution::{SessionState, SessionStateBuilder};
3131
/// # use datafusion::prelude::SessionConfig;
32-
/// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder};
32+
/// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext};
3333
///
3434
/// extensions_options! {
3535
/// pub struct CustomExtension {
@@ -52,11 +52,13 @@ pub trait ConfigExtensionExt {
5252
/// struct MyCustomSessionBuilder;
5353
///
5454
/// #[async_trait]
55-
/// impl SessionBuilder for MyCustomSessionBuilder {
56-
/// async fn session_state(&self, mut state: SessionState) -> Result<SessionState, DataFusionError> {
55+
/// impl DistributedSessionBuilder for MyCustomSessionBuilder {
56+
/// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result<SessionState, DataFusionError> {
57+
/// let mut state = SessionStateBuilder::new().build();
58+
///
5759
/// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will
5860
/// // know how to deserialize the CustomExtension from the gRPC metadata.
59-
/// state.retrieve_distributed_option_extension::<CustomExtension>()?;
61+
/// state.retrieve_distributed_option_extension::<CustomExtension>(&ctx.headers)?;
6062
/// Ok(state)
6163
/// }
6264
/// }
@@ -76,9 +78,9 @@ pub trait ConfigExtensionExt {
7678
/// # use async_trait::async_trait;
7779
/// # use datafusion::common::{extensions_options, DataFusionError};
7880
/// # use datafusion::config::ConfigExtension;
79-
/// # use datafusion::execution::SessionState;
81+
/// # use datafusion::execution::{SessionState, SessionStateBuilder};
8082
/// # use datafusion::prelude::SessionConfig;
81-
/// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder};
83+
/// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext};
8284
///
8385
/// extensions_options! {
8486
/// pub struct CustomExtension {
@@ -101,17 +103,19 @@ pub trait ConfigExtensionExt {
101103
/// struct MyCustomSessionBuilder;
102104
///
103105
/// #[async_trait]
104-
/// impl SessionBuilder for MyCustomSessionBuilder {
105-
/// async fn session_state(&self, mut state: SessionState) -> Result<SessionState, DataFusionError> {
106+
/// impl DistributedSessionBuilder for MyCustomSessionBuilder {
107+
/// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result<SessionState, DataFusionError> {
108+
/// let mut state = SessionStateBuilder::new().build();
106109
/// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will
107110
/// // know how to deserialize the CustomExtension from the gRPC metadata.
108-
/// state.retrieve_distributed_option_extension::<CustomExtension>()?;
111+
/// state.retrieve_distributed_option_extension::<CustomExtension>(&ctx.headers)?;
109112
/// Ok(state)
110113
/// }
111114
/// }
112115
/// ```
113116
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(
114117
&mut self,
118+
headers: &HeaderMap,
115119
) -> Result<(), DataFusionError>;
116120
}
117121

@@ -153,20 +157,17 @@ impl ConfigExtensionExt for SessionConfig {
153157

154158
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(
155159
&mut self,
160+
headers: &HeaderMap,
156161
) -> Result<(), DataFusionError> {
157-
let Some(flight_metadata) = self.get_extension::<ContextGrpcMetadata>() else {
158-
return Ok(());
159-
};
160-
161162
let mut result = T::default();
162163
let mut found_some = false;
163-
for (k, v) in flight_metadata.0.iter() {
164+
for (k, v) in headers.iter() {
164165
let key = k.as_str().trim_start_matches(FLIGHT_METADATA_CONFIG_PREFIX);
165166
let prefix = format!("{}.", T::PREFIX);
166167
if key.starts_with(&prefix) {
167168
found_some = true;
168169
result.set(
169-
&key.trim_start_matches(&prefix),
170+
key.trim_start_matches(&prefix),
170171
v.to_str().map_err(|err| {
171172
internal_datafusion_err!("Cannot parse header value: {err}")
172173
})?,
@@ -185,7 +186,7 @@ impl ConfigExtensionExt for SessionStateBuilder {
185186
delegate! {
186187
to self.config().get_or_insert_default() {
187188
fn add_distributed_option_extension<T: ConfigExtension + Default>(&mut self, t: T) -> Result<(), DataFusionError>;
188-
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self) -> Result<(), DataFusionError>;
189+
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>;
189190
}
190191
}
191192
}
@@ -194,7 +195,7 @@ impl ConfigExtensionExt for SessionState {
194195
delegate! {
195196
to self.config_mut() {
196197
fn add_distributed_option_extension<T: ConfigExtension + Default>(&mut self, t: T) -> Result<(), DataFusionError>;
197-
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self) -> Result<(), DataFusionError>;
198+
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>;
198199
}
199200
}
200201
}
@@ -203,7 +204,7 @@ impl ConfigExtensionExt for SessionContext {
203204
delegate! {
204205
to self.state_ref().write().config_mut() {
205206
fn add_distributed_option_extension<T: ConfigExtension + Default>(&mut self, t: T) -> Result<(), DataFusionError>;
206-
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self) -> Result<(), DataFusionError>;
207+
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>;
207208
}
208209
}
209210
}
@@ -212,17 +213,6 @@ impl ConfigExtensionExt for SessionContext {
212213
pub(crate) struct ContextGrpcMetadata(pub HeaderMap);
213214

214215
impl ContextGrpcMetadata {
215-
pub(crate) fn from_headers(metadata: HeaderMap) -> Self {
216-
let mut new = HeaderMap::new();
217-
for (k, v) in metadata.into_iter() {
218-
let Some(k) = k else { continue };
219-
if k.as_str().starts_with(FLIGHT_METADATA_CONFIG_PREFIX) {
220-
new.insert(k, v);
221-
}
222-
}
223-
Self(new)
224-
}
225-
226216
fn merge(mut self, other: Self) -> Self {
227217
for (k, v) in other.0.into_iter() {
228218
let Some(k) = k else { continue };
@@ -246,16 +236,16 @@ mod tests {
246236
fn test_propagation() -> Result<(), Box<dyn std::error::Error>> {
247237
let mut config = SessionConfig::new();
248238

249-
let mut opt = CustomExtension::default();
250-
opt.foo = "foo".to_string();
251-
opt.bar = 1;
252-
opt.baz = true;
239+
let opt = CustomExtension {
240+
foo: "".to_string(),
241+
bar: 0,
242+
baz: false,
243+
};
253244

254245
config.add_distributed_option_extension(opt)?;
255-
246+
let metadata = config.get_extension::<ContextGrpcMetadata>().unwrap();
256247
let mut new_config = SessionConfig::new();
257-
new_config.set_extension(config.get_extension::<ContextGrpcMetadata>().unwrap());
258-
new_config.retrieve_distributed_option_extension::<CustomExtension>()?;
248+
new_config.retrieve_distributed_option_extension::<CustomExtension>(&metadata.0)?;
259249

260250
let opt = get_ext::<CustomExtension>(&config);
261251
let new_opt = get_ext::<CustomExtension>(&new_config);
@@ -294,12 +284,16 @@ mod tests {
294284
fn test_new_extension_overwrites_previous() -> Result<(), Box<dyn std::error::Error>> {
295285
let mut config = SessionConfig::new();
296286

297-
let mut opt1 = CustomExtension::default();
298-
opt1.foo = "first".to_string();
287+
let opt1 = CustomExtension {
288+
foo: "first".to_string(),
289+
..Default::default()
290+
};
299291
config.add_distributed_option_extension(opt1)?;
300292

301-
let mut opt2 = CustomExtension::default();
302-
opt2.bar = 42;
293+
let opt2 = CustomExtension {
294+
bar: 42,
295+
..Default::default()
296+
};
303297
config.add_distributed_option_extension(opt2)?;
304298

305299
let flight_metadata = config.get_extension::<ContextGrpcMetadata>().unwrap();
@@ -317,7 +311,7 @@ mod tests {
317311
fn test_propagate_no_metadata() -> Result<(), Box<dyn std::error::Error>> {
318312
let mut config = SessionConfig::new();
319313

320-
config.retrieve_distributed_option_extension::<CustomExtension>()?;
314+
config.retrieve_distributed_option_extension::<CustomExtension>(&Default::default())?;
321315

322316
let extension = config.options().extensions.get::<CustomExtension>();
323317
assert!(extension.is_none());
@@ -330,13 +324,11 @@ mod tests {
330324
let mut config = SessionConfig::new();
331325
let mut header_map = HeaderMap::new();
332326
header_map.insert(
333-
HeaderName::from_str("x-datafusion-distributed-other.setting").unwrap(),
327+
HeaderName::from_str("x-datafusion-distributed-config-other.setting").unwrap(),
334328
HeaderValue::from_str("value").unwrap(),
335329
);
336330

337-
let flight_metadata = ContextGrpcMetadata::from_headers(header_map);
338-
config.set_extension(std::sync::Arc::new(flight_metadata));
339-
config.retrieve_distributed_option_extension::<CustomExtension>()?;
331+
config.retrieve_distributed_option_extension::<CustomExtension>(&header_map)?;
340332

341333
let extension = config.options().extensions.get::<CustomExtension>();
342334
assert!(extension.is_none());
@@ -348,13 +340,17 @@ mod tests {
348340
fn test_multiple_extensions_different_prefixes() -> Result<(), Box<dyn std::error::Error>> {
349341
let mut config = SessionConfig::new();
350342

351-
let mut custom_opt = CustomExtension::default();
352-
custom_opt.foo = "custom_value".to_string();
353-
custom_opt.bar = 123;
343+
let custom_opt = CustomExtension {
344+
foo: "custom_value".to_string(),
345+
bar: 123,
346+
..Default::default()
347+
};
354348

355-
let mut another_opt = AnotherExtension::default();
356-
another_opt.setting1 = "other".to_string();
357-
another_opt.setting2 = 456;
349+
let another_opt = AnotherExtension {
350+
setting1: "other".to_string(),
351+
setting2: 456,
352+
..Default::default()
353+
};
358354

359355
config.add_distributed_option_extension(custom_opt)?;
360356
config.add_distributed_option_extension(another_opt)?;
@@ -384,9 +380,8 @@ mod tests {
384380
);
385381

386382
let mut new_config = SessionConfig::new();
387-
new_config.set_extension(flight_metadata);
388-
new_config.retrieve_distributed_option_extension::<CustomExtension>()?;
389-
new_config.retrieve_distributed_option_extension::<AnotherExtension>()?;
383+
new_config.retrieve_distributed_option_extension::<CustomExtension>(metadata)?;
384+
new_config.retrieve_distributed_option_extension::<AnotherExtension>(metadata)?;
390385

391386
let propagated_custom = get_ext::<CustomExtension>(&new_config);
392387
let propagated_another = get_ext::<AnotherExtension>(&new_config);

0 commit comments

Comments
 (0)