Skip to content

Commit 25d24d2

Browse files
committed
Rework SessionBuilder
1 parent d35ddff commit 25d24d2

File tree

15 files changed

+218
-282
lines changed

15 files changed

+218
-282
lines changed

benchmarks/src/tpch/run.rs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,16 @@ use datafusion::datasource::listing::{
3636
};
3737
use datafusion::datasource::{MemTable, TableProvider};
3838
use datafusion::error::{DataFusionError, Result};
39-
use datafusion::execution::SessionStateBuilder;
39+
use datafusion::execution::{SessionState, SessionStateBuilder};
4040
use datafusion::physical_plan::display::DisplayableExecutionPlan;
4141
use datafusion::physical_plan::{collect, displayable};
4242
use datafusion::prelude::*;
4343

4444
use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult};
4545
use datafusion_distributed::test_utils::localhost::start_localhost_context;
46-
use datafusion_distributed::{DistributedPhysicalOptimizerRule, SessionBuilder};
46+
use datafusion_distributed::{
47+
DistributedPhysicalOptimizerRule, DistributedSessionBuilder, DistributedSessionBuilderContext,
48+
};
4749
use log::info;
4850
use structopt::StructOpt;
4951

@@ -110,11 +112,13 @@ pub struct RunOpt {
110112
}
111113

112114
#[async_trait]
113-
impl SessionBuilder for RunOpt {
114-
fn session_state_builder(
115+
impl DistributedSessionBuilder for RunOpt {
116+
async fn build_session_state(
115117
&self,
116-
mut builder: SessionStateBuilder,
117-
) -> Result<SessionStateBuilder, DataFusionError> {
118+
_ctx: DistributedSessionBuilderContext,
119+
) -> Result<SessionState, DataFusionError> {
120+
let mut builder = SessionStateBuilder::new().with_default_features();
121+
118122
let mut config = self
119123
.common
120124
.config()?
@@ -145,17 +149,14 @@ impl SessionBuilder for RunOpt {
145149
builder = builder.with_physical_optimizer_rule(Arc::new(rule));
146150
}
147151

148-
Ok(builder
152+
let state = builder
149153
.with_config(config)
150-
.with_runtime_env(rt_builder.build_arc()?))
151-
}
154+
.with_runtime_env(rt_builder.build_arc()?)
155+
.build();
152156

153-
async fn session_context(
154-
&self,
155-
ctx: SessionContext,
156-
) -> std::result::Result<SessionContext, DataFusionError> {
157+
let ctx = SessionContext::from(state);
157158
self.register_tables(&ctx).await?;
158-
Ok(ctx)
159+
Ok(ctx.state())
159160
}
160161
}
161162

src/common/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1+
#[allow(unused)]
12
pub mod ttl_map;
23
pub mod util;

src/config_extension_ext.rs

Lines changed: 26 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ pub trait ConfigExtensionExt {
2727
/// # use async_trait::async_trait;
2828
/// # use datafusion::common::{extensions_options, DataFusionError};
2929
/// # use datafusion::config::ConfigExtension;
30-
/// # use datafusion::execution::SessionState;
30+
/// # use datafusion::execution::{SessionState, SessionStateBuilder};
3131
/// # use datafusion::prelude::SessionConfig;
32-
/// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder};
32+
/// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext};
3333
///
3434
/// extensions_options! {
3535
/// pub struct CustomExtension {
@@ -52,11 +52,13 @@ pub trait ConfigExtensionExt {
5252
/// struct MyCustomSessionBuilder;
5353
///
5454
/// #[async_trait]
55-
/// impl SessionBuilder for MyCustomSessionBuilder {
56-
/// async fn session_state(&self, mut state: SessionState) -> Result<SessionState, DataFusionError> {
55+
/// impl DistributedSessionBuilder for MyCustomSessionBuilder {
56+
/// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result<SessionState, DataFusionError> {
57+
/// let mut state = SessionStateBuilder::new().build();
58+
///
5759
/// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will
5860
/// // know how to deserialize the CustomExtension from the gRPC metadata.
59-
/// state.retrieve_distributed_option_extension::<CustomExtension>()?;
61+
/// state.retrieve_distributed_option_extension::<CustomExtension>(&ctx.headers)?;
6062
/// Ok(state)
6163
/// }
6264
/// }
@@ -76,9 +78,9 @@ pub trait ConfigExtensionExt {
7678
/// # use async_trait::async_trait;
7779
/// # use datafusion::common::{extensions_options, DataFusionError};
7880
/// # use datafusion::config::ConfigExtension;
79-
/// # use datafusion::execution::SessionState;
81+
/// # use datafusion::execution::{SessionState, SessionStateBuilder};
8082
/// # use datafusion::prelude::SessionConfig;
81-
/// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder};
83+
/// # use datafusion_distributed::{ConfigExtensionExt, DistributedSessionBuilder, DistributedSessionBuilderContext};
8284
///
8385
/// extensions_options! {
8486
/// pub struct CustomExtension {
@@ -101,17 +103,19 @@ pub trait ConfigExtensionExt {
101103
/// struct MyCustomSessionBuilder;
102104
///
103105
/// #[async_trait]
104-
/// impl SessionBuilder for MyCustomSessionBuilder {
105-
/// async fn session_state(&self, mut state: SessionState) -> Result<SessionState, DataFusionError> {
106+
/// impl DistributedSessionBuilder for MyCustomSessionBuilder {
107+
/// async fn build_session_state(&self, ctx: DistributedSessionBuilderContext) -> Result<SessionState, DataFusionError> {
108+
/// let mut state = SessionStateBuilder::new().build();
106109
/// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will
107110
/// // know how to deserialize the CustomExtension from the gRPC metadata.
108-
/// state.retrieve_distributed_option_extension::<CustomExtension>()?;
111+
/// state.retrieve_distributed_option_extension::<CustomExtension>(&ctx.headers)?;
109112
/// Ok(state)
110113
/// }
111114
/// }
112115
/// ```
113116
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(
114117
&mut self,
118+
headers: &HeaderMap,
115119
) -> Result<(), DataFusionError>;
116120
}
117121

@@ -153,14 +157,11 @@ impl ConfigExtensionExt for SessionConfig {
153157

154158
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(
155159
&mut self,
160+
headers: &HeaderMap,
156161
) -> Result<(), DataFusionError> {
157-
let Some(flight_metadata) = self.get_extension::<ContextGrpcMetadata>() else {
158-
return Ok(());
159-
};
160-
161162
let mut result = T::default();
162163
let mut found_some = false;
163-
for (k, v) in flight_metadata.0.iter() {
164+
for (k, v) in headers.iter() {
164165
let key = k.as_str().trim_start_matches(FLIGHT_METADATA_CONFIG_PREFIX);
165166
let prefix = format!("{}.", T::PREFIX);
166167
if key.starts_with(&prefix) {
@@ -185,7 +186,7 @@ impl ConfigExtensionExt for SessionStateBuilder {
185186
delegate! {
186187
to self.config().get_or_insert_default() {
187188
fn add_distributed_option_extension<T: ConfigExtension + Default>(&mut self, t: T) -> Result<(), DataFusionError>;
188-
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self) -> Result<(), DataFusionError>;
189+
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>;
189190
}
190191
}
191192
}
@@ -194,7 +195,7 @@ impl ConfigExtensionExt for SessionState {
194195
delegate! {
195196
to self.config_mut() {
196197
fn add_distributed_option_extension<T: ConfigExtension + Default>(&mut self, t: T) -> Result<(), DataFusionError>;
197-
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self) -> Result<(), DataFusionError>;
198+
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>;
198199
}
199200
}
200201
}
@@ -203,7 +204,7 @@ impl ConfigExtensionExt for SessionContext {
203204
delegate! {
204205
to self.state_ref().write().config_mut() {
205206
fn add_distributed_option_extension<T: ConfigExtension + Default>(&mut self, t: T) -> Result<(), DataFusionError>;
206-
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self) -> Result<(), DataFusionError>;
207+
fn retrieve_distributed_option_extension<T: ConfigExtension + Default>(&mut self, h: &HeaderMap) -> Result<(), DataFusionError>;
207208
}
208209
}
209210
}
@@ -212,17 +213,6 @@ impl ConfigExtensionExt for SessionContext {
212213
pub(crate) struct ContextGrpcMetadata(pub HeaderMap);
213214

214215
impl ContextGrpcMetadata {
215-
pub(crate) fn from_headers(metadata: HeaderMap) -> Self {
216-
let mut new = HeaderMap::new();
217-
for (k, v) in metadata.into_iter() {
218-
let Some(k) = k else { continue };
219-
if k.as_str().starts_with(FLIGHT_METADATA_CONFIG_PREFIX) {
220-
new.insert(k, v);
221-
}
222-
}
223-
Self(new)
224-
}
225-
226216
fn merge(mut self, other: Self) -> Self {
227217
for (k, v) in other.0.into_iter() {
228218
let Some(k) = k else { continue };
@@ -252,10 +242,9 @@ mod tests {
252242
opt.baz = true;
253243

254244
config.add_distributed_option_extension(opt)?;
255-
245+
let metadata = config.get_extension::<ContextGrpcMetadata>().unwrap();
256246
let mut new_config = SessionConfig::new();
257-
new_config.set_extension(config.get_extension::<ContextGrpcMetadata>().unwrap());
258-
new_config.retrieve_distributed_option_extension::<CustomExtension>()?;
247+
new_config.retrieve_distributed_option_extension::<CustomExtension>(&metadata.0)?;
259248

260249
let opt = get_ext::<CustomExtension>(&config);
261250
let new_opt = get_ext::<CustomExtension>(&new_config);
@@ -317,7 +306,7 @@ mod tests {
317306
fn test_propagate_no_metadata() -> Result<(), Box<dyn std::error::Error>> {
318307
let mut config = SessionConfig::new();
319308

320-
config.retrieve_distributed_option_extension::<CustomExtension>()?;
309+
config.retrieve_distributed_option_extension::<CustomExtension>(&Default::default())?;
321310

322311
let extension = config.options().extensions.get::<CustomExtension>();
323312
assert!(extension.is_none());
@@ -330,13 +319,11 @@ mod tests {
330319
let mut config = SessionConfig::new();
331320
let mut header_map = HeaderMap::new();
332321
header_map.insert(
333-
HeaderName::from_str("x-datafusion-distributed-other.setting").unwrap(),
322+
HeaderName::from_str("x-datafusion-distributed-config-other.setting").unwrap(),
334323
HeaderValue::from_str("value").unwrap(),
335324
);
336325

337-
let flight_metadata = ContextGrpcMetadata::from_headers(header_map);
338-
config.set_extension(std::sync::Arc::new(flight_metadata));
339-
config.retrieve_distributed_option_extension::<CustomExtension>()?;
326+
config.retrieve_distributed_option_extension::<CustomExtension>(&header_map)?;
340327

341328
let extension = config.options().extensions.get::<CustomExtension>();
342329
assert!(extension.is_none());
@@ -384,9 +371,8 @@ mod tests {
384371
);
385372

386373
let mut new_config = SessionConfig::new();
387-
new_config.set_extension(flight_metadata);
388-
new_config.retrieve_distributed_option_extension::<CustomExtension>()?;
389-
new_config.retrieve_distributed_option_extension::<AnotherExtension>()?;
374+
new_config.retrieve_distributed_option_extension::<CustomExtension>(&metadata)?;
375+
new_config.retrieve_distributed_option_extension::<AnotherExtension>(&metadata)?;
390376

391377
let propagated_custom = get_ext::<CustomExtension>(&new_config);
392378
let propagated_another = get_ext::<AnotherExtension>(&new_config);

src/flight_service/do_get.rs

Lines changed: 15 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,15 @@ use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
33
use crate::config_extension_ext::ContextGrpcMetadata;
44
use crate::errors::datafusion_error_to_tonic_status;
55
use crate::flight_service::service::ArrowFlightEndpoint;
6+
use crate::flight_service::session_builder::DistributedSessionBuilderContext;
67
use crate::plan::{DistributedCodec, PartitionGroup};
78
use crate::stage::{stage_from_proto, ExecutionStage, ExecutionStageProto};
89
use crate::user_provided_codec::get_user_codec;
910
use arrow_flight::encode::FlightDataEncoderBuilder;
1011
use arrow_flight::error::FlightError;
1112
use arrow_flight::flight_service_server::FlightService;
1213
use arrow_flight::Ticket;
13-
use datafusion::execution::{SessionState, SessionStateBuilder};
14-
use datafusion::optimizer::OptimizerConfig;
15-
use datafusion::prelude::SessionConfig;
14+
use datafusion::execution::SessionState;
1615
use futures::TryStreamExt;
1716
use prost::Message;
1817
use std::sync::Arc;
@@ -90,7 +89,7 @@ impl ArrowFlightEndpoint {
9089
async fn get_state_and_stage(
9190
&self,
9291
doget: DoGet,
93-
metadata: MetadataMap,
92+
metadata_map: MetadataMap,
9493
) -> Result<(SessionState, Arc<ExecutionStage>), Status> {
9594
let key = doget
9695
.stage_key
@@ -106,54 +105,34 @@ impl ArrowFlightEndpoint {
106105
.stage_proto
107106
.ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?;
108107

109-
let mut config = SessionConfig::default();
110-
config.set_extension(Arc::new(ContextGrpcMetadata::from_headers(
111-
metadata.into_headers(),
112-
)));
113-
114-
let state_builder = SessionStateBuilder::new()
115-
.with_runtime_env(Arc::clone(&self.runtime))
116-
.with_config(config)
117-
.with_default_features();
118-
119-
let state_builder = self
120-
.session_builder
121-
.session_state_builder(state_builder)
122-
.map_err(|err| datafusion_error_to_tonic_status(&err))?;
123-
124-
let state = state_builder.build();
108+
let headers = metadata_map.into_headers();
125109
let mut state = self
126110
.session_builder
127-
.session_state(state)
111+
.build_session_state(DistributedSessionBuilderContext {
112+
runtime_env: Arc::clone(&self.runtime),
113+
headers: headers.clone(),
114+
})
128115
.await
129116
.map_err(|err| datafusion_error_to_tonic_status(&err))?;
130117

131-
let function_registry =
132-
state.function_registry().ok_or(Status::invalid_argument(
133-
"FunctionRegistry not present in newly built SessionState",
134-
))?;
135-
136118
let mut combined_codec = ComposedPhysicalExtensionCodec::default();
137119
combined_codec.push(DistributedCodec);
138120
if let Some(ref user_codec) = get_user_codec(state.config()) {
139121
combined_codec.push_arc(Arc::clone(user_codec));
140122
}
141123

142-
let stage = stage_from_proto(
143-
stage_proto,
144-
function_registry,
145-
self.runtime.as_ref(),
146-
&combined_codec,
147-
)
148-
.map(Arc::new)
149-
.map_err(|err| {
150-
Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
151-
})?;
124+
let stage =
125+
stage_from_proto(stage_proto, &state, self.runtime.as_ref(), &combined_codec)
126+
.map(Arc::new)
127+
.map_err(|err| {
128+
Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
129+
})?;
152130

153131
// Add the extensions that might be required for ExecutionPlan nodes in the plan
154132
let config = state.config_mut();
155133
config.set_extension(Arc::clone(&self.channel_manager));
156134
config.set_extension(stage.clone());
135+
config.set_extension(Arc::new(ContextGrpcMetadata(headers)));
157136

158137
Ok::<_, Status>((state, stage))
159138
})

src/flight_service/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@ mod session_builder;
55
pub(crate) use do_get::DoGet;
66

77
pub use service::{ArrowFlightEndpoint, StageKey};
8-
pub use session_builder::{NoopSessionBuilder, SessionBuilder};
8+
pub use session_builder::{
9+
DefaultSessionBuilder, DistributedSessionBuilder, DistributedSessionBuilderContext,
10+
};

src/flight_service/service.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::channel_manager::ChannelManager;
2-
use crate::flight_service::session_builder::NoopSessionBuilder;
3-
use crate::flight_service::SessionBuilder;
2+
use crate::flight_service::session_builder::DefaultSessionBuilder;
3+
use crate::flight_service::DistributedSessionBuilder;
44
use crate::stage::ExecutionStage;
55
use crate::ChannelResolver;
66
use arrow_flight::flight_service_server::FlightService;
@@ -36,7 +36,7 @@ pub struct ArrowFlightEndpoint {
3636
pub(super) runtime: Arc<RuntimeEnv>,
3737
#[allow(clippy::type_complexity)]
3838
pub(super) stages: DashMap<StageKey, Arc<OnceCell<(SessionState, Arc<ExecutionStage>)>>>,
39-
pub(super) session_builder: Arc<dyn SessionBuilder + Send + Sync>,
39+
pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
4040
}
4141

4242
impl ArrowFlightEndpoint {
@@ -45,13 +45,13 @@ impl ArrowFlightEndpoint {
4545
channel_manager: Arc::new(ChannelManager::new(channel_resolver)),
4646
runtime: Arc::new(RuntimeEnv::default()),
4747
stages: DashMap::new(),
48-
session_builder: Arc::new(NoopSessionBuilder),
48+
session_builder: Arc::new(DefaultSessionBuilder),
4949
}
5050
}
5151

5252
pub fn with_session_builder(
5353
&mut self,
54-
session_builder: impl SessionBuilder + Send + Sync + 'static,
54+
session_builder: impl DistributedSessionBuilder + Send + Sync + 'static,
5555
) {
5656
self.session_builder = Arc::new(session_builder);
5757
}

0 commit comments

Comments
 (0)