Skip to content

Commit 82a2c82

Browse files
committed
Extend the SessionBuilder trait to be able to operate not only at the SessionStateBuilder level, but also on SessionState and SessionContext
1 parent 17e53d2 commit 82a2c82

File tree

9 files changed

+122
-46
lines changed

9 files changed

+122
-46
lines changed

src/flight_service/do_get.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use arrow_flight::flight_service_server::FlightService;
1010
use arrow_flight::Ticket;
1111
use datafusion::execution::SessionStateBuilder;
1212
use datafusion::optimizer::OptimizerConfig;
13+
use datafusion::prelude::SessionContext;
1314
use futures::TryStreamExt;
1415
use prost::Message;
1516
use std::sync::Arc;
@@ -42,8 +43,17 @@ impl ArrowFlightEndpoint {
4243
let state_builder = SessionStateBuilder::new()
4344
.with_runtime_env(Arc::clone(&self.runtime))
4445
.with_default_features();
46+
let state_builder = self
47+
.session_builder
48+
.session_state_builder(state_builder)
49+
.map_err(|err| datafusion_error_to_tonic_status(&err))?;
4550

46-
let mut state = self.session_builder.on_new_session(state_builder).build();
51+
let state = state_builder.build();
52+
let mut state = self
53+
.session_builder
54+
.session_state(state)
55+
.await
56+
.map_err(|err| datafusion_error_to_tonic_status(&err))?;
4757

4858
let function_registry = state.function_registry().ok_or(Status::invalid_argument(
4959
"FunctionRegistry not present in newly built SessionState",
@@ -55,7 +65,7 @@ impl ArrowFlightEndpoint {
5565
combined_codec.push_arc(Arc::clone(&user_codec));
5666
}
5767

58-
let mut stage = stage_from_proto(
68+
let stage = stage_from_proto(
5969
stage_msg,
6070
function_registry,
6171
&self.runtime.as_ref(),
@@ -69,8 +79,16 @@ impl ArrowFlightEndpoint {
6979
config.set_extension(Arc::clone(&self.channel_manager));
7080
config.set_extension(Arc::new(stage));
7181

82+
let ctx = SessionContext::new_with_state(state);
83+
84+
let ctx = self
85+
.session_builder
86+
.session_context(ctx)
87+
.await
88+
.map_err(|err| datafusion_error_to_tonic_status(&err))?;
89+
7290
let stream = inner_plan
73-
.execute(doget.partition as usize, state.task_ctx())
91+
.execute(doget.partition as usize, ctx.task_ctx())
7492
.map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
7593

7694
let flight_data_stream = FlightDataEncoderBuilder::new()

src/flight_service/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ mod stream_partitioner_registry;
66
pub(crate) use do_get::DoGet;
77

88
pub use service::ArrowFlightEndpoint;
9-
pub use session_builder::SessionBuilder;
9+
pub use session_builder::{NoopSessionBuilder, SessionBuilder};

src/flight_service/session_builder.rs

Lines changed: 74 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
use datafusion::execution::SessionStateBuilder;
1+
use async_trait::async_trait;
2+
use datafusion::error::DataFusionError;
3+
use datafusion::execution::{SessionState, SessionStateBuilder};
4+
use datafusion::prelude::SessionContext;
25

36
/// Trait called by the Arrow Flight endpoint that handles distributed parts of a DataFusion
47
/// plan for building a DataFusion [datafusion::prelude::SessionContext].
8+
#[async_trait]
59
pub trait SessionBuilder {
610
/// Takes a [SessionStateBuilder] and adds whatever is necessary for it to work, like
711
/// custom extension codecs, custom physical optimization rules, UDFs, UDAFs, config
@@ -10,8 +14,9 @@ pub trait SessionBuilder {
1014
/// Example: adding some custom extension plan codecs
1115
///
1216
/// ```rust
13-
///
1417
/// # use std::sync::Arc;
18+
/// # use async_trait::async_trait;
19+
/// # use datafusion::error::DataFusionError;
1520
/// # use datafusion::execution::runtime_env::RuntimeEnv;
1621
/// # use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
1722
/// # use datafusion::physical_plan::ExecutionPlan;
@@ -33,22 +38,81 @@ pub trait SessionBuilder {
3338
///
3439
/// #[derive(Clone)]
3540
/// struct CustomSessionBuilder;
41+
///
42+
/// #[async_trait]
3643
/// impl SessionBuilder for CustomSessionBuilder {
37-
/// fn on_new_session(&self, mut builder: SessionStateBuilder) -> SessionStateBuilder {
44+
/// fn session_state_builder(&self, mut builder: SessionStateBuilder) -> Result<SessionStateBuilder, DataFusionError> {
3845
/// // Add your UDFs, optimization rules, etc...
39-
/// with_user_codec(builder, CustomExecCodec)
46+
/// Ok(with_user_codec(builder, CustomExecCodec))
47+
/// }
48+
/// }
49+
/// ```
50+
fn session_state_builder(
51+
&self,
52+
builder: SessionStateBuilder,
53+
) -> Result<SessionStateBuilder, DataFusionError> {
54+
Ok(builder)
55+
}
56+
57+
/// Modifies the [SessionState] and returns it. Same as [SessionBuilder::session_state_builder]
58+
/// but operating on an already built [SessionState].
59+
///
60+
/// Example:
61+
///
62+
/// ```rust
63+
/// # use async_trait::async_trait;
64+
/// # use datafusion::common::DataFusionError;
65+
/// # use datafusion::execution::SessionState;
66+
/// # use datafusion_distributed::SessionBuilder;
67+
///
68+
/// #[derive(Clone)]
69+
/// struct CustomSessionBuilder;
70+
///
71+
/// #[async_trait]
72+
/// impl SessionBuilder for CustomSessionBuilder {
73+
/// async fn session_state(&self, state: SessionState) -> Result<SessionState, DataFusionError> {
74+
/// // mutate the state adding any custom logic
75+
/// Ok(state)
4076
/// }
4177
/// }
4278
/// ```
43-
fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder;
79+
async fn session_state(&self, state: SessionState) -> Result<SessionState, DataFusionError> {
80+
Ok(state)
81+
}
82+
83+
/// Modifies the [SessionContext] and returns it. Same as [SessionBuilder::session_state_builder]
84+
/// or [SessionBuilder::session_state] but operating on an already built [SessionContext].
85+
///
86+
/// Example:
87+
///
88+
/// ```rust
89+
/// # use async_trait::async_trait;
90+
/// # use datafusion::common::DataFusionError;
91+
/// # use datafusion::prelude::SessionContext;
92+
/// # use datafusion_distributed::SessionBuilder;
93+
///
94+
/// #[derive(Clone)]
95+
/// struct CustomSessionBuilder;
96+
///
97+
/// #[async_trait]
98+
/// impl SessionBuilder for CustomSessionBuilder {
99+
/// async fn session_context(&self, ctx: SessionContext) -> Result<SessionContext, DataFusionError> {
100+
/// // mutate the context adding any custom logic
101+
/// Ok(ctx)
102+
/// }
103+
/// }
104+
/// ```
105+
async fn session_context(
106+
&self,
107+
ctx: SessionContext,
108+
) -> Result<SessionContext, DataFusionError> {
109+
Ok(ctx)
110+
}
44111
}
45112

46113
/// Noop implementation of the [SessionBuilder]. Used by default if no [SessionBuilder] is provided
47114
/// while building the Arrow Flight endpoint.
115+
#[derive(Debug, Clone)]
48116
pub struct NoopSessionBuilder;
49117

50-
impl SessionBuilder for NoopSessionBuilder {
51-
fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
52-
builder
53-
}
54-
}
118+
impl SessionBuilder for NoopSessionBuilder {}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ mod user_provided_codec;
1313
pub mod test_utils;
1414

1515
pub use channel_manager::{BoxCloneSyncChannel, ChannelManager, ChannelResolver};
16-
pub use flight_service::{ArrowFlightEndpoint, SessionBuilder};
16+
pub use flight_service::{ArrowFlightEndpoint, NoopSessionBuilder, SessionBuilder};
1717
pub use physical_optimizer::DistributedPhysicalOptimizerRule;
1818
pub use plan::ArrowFlightReadExec;
1919
pub use stage::{display_stage_graphviz, ExecutionStage};

src/test_utils/localhost.rs

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,6 @@ use std::time::Duration;
1414
use tonic::transport::{Channel, Server};
1515
use url::Url;
1616

17-
#[derive(Debug, Clone)]
18-
pub struct NoopSessionBuilder;
19-
impl SessionBuilder for NoopSessionBuilder {
20-
fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
21-
builder
22-
}
23-
}
24-
2517
pub async fn start_localhost_context<N, I, B>(
2618
ports: I,
2719
session_builder: B,
@@ -49,12 +41,16 @@ where
4941

5042
let config = SessionConfig::new().with_target_partitions(3);
5143

52-
let state = SessionStateBuilder::new()
44+
let builder = SessionStateBuilder::new()
5345
.with_default_features()
54-
.with_config(config)
55-
.build();
46+
.with_config(config);
47+
let builder = session_builder.session_state_builder(builder).unwrap();
48+
49+
let state = builder.build();
50+
let state = session_builder.session_state(state).await.unwrap();
5651

5752
let ctx = SessionContext::new_with_state(state);
53+
let ctx = session_builder.session_context(ctx).await.unwrap();
5854

5955
ctx.state_ref()
6056
.write()

tests/custom_extension_codec.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ mod tests {
2525
use datafusion_distributed::assert_snapshot;
2626
use datafusion_distributed::test_utils::localhost::start_localhost_context;
2727
use datafusion_distributed::{
28-
add_user_codec, with_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule,
29-
SessionBuilder,
28+
with_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder,
3029
};
3130
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
3231
use datafusion_proto::protobuf::proto_error;
@@ -42,14 +41,16 @@ mod tests {
4241
#[derive(Clone)]
4342
struct CustomSessionBuilder;
4443
impl SessionBuilder for CustomSessionBuilder {
45-
fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
46-
with_user_codec(builder, Int64ListExecCodec)
44+
fn session_state_builder(
45+
&self,
46+
builder: SessionStateBuilder,
47+
) -> Result<SessionStateBuilder, DataFusionError> {
48+
Ok(with_user_codec(builder, Int64ListExecCodec))
4749
}
4850
}
4951

50-
let (mut ctx, _guard) =
52+
let (ctx, _guard) =
5153
start_localhost_context([50050, 50051, 50052], CustomSessionBuilder).await;
52-
add_user_codec(&mut ctx, Int64ListExecCodec);
5354

5455
let single_node_plan = build_plan(false)?;
5556
assert_snapshot!(displayable(single_node_plan.as_ref()).indent(true).to_string(), @r"

tests/distributed_aggregation.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22
mod tests {
33
use datafusion::arrow::util::pretty::pretty_format_batches;
44
use datafusion::physical_plan::{displayable, execute_stream};
5-
use datafusion_distributed::assert_snapshot;
6-
use datafusion_distributed::test_utils::localhost::{
7-
start_localhost_context, NoopSessionBuilder,
8-
};
5+
use datafusion_distributed::test_utils::localhost::start_localhost_context;
96
use datafusion_distributed::test_utils::parquet::register_parquet_tables;
107
use datafusion_distributed::test_utils::plan::distribute_aggregate;
8+
use datafusion_distributed::{assert_snapshot, NoopSessionBuilder};
119
use futures::TryStreamExt;
1210
use std::error::Error;
1311

tests/error_propagation.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ mod tests {
1313
};
1414
use datafusion_distributed::test_utils::localhost::start_localhost_context;
1515
use datafusion_distributed::{
16-
add_user_codec, with_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule,
17-
SessionBuilder,
16+
with_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder,
1817
};
1918
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
2019
use datafusion_proto::protobuf::proto_error;
@@ -30,13 +29,15 @@ mod tests {
3029
#[derive(Clone)]
3130
struct CustomSessionBuilder;
3231
impl SessionBuilder for CustomSessionBuilder {
33-
fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
34-
with_user_codec(builder, ErrorExecCodec)
32+
fn session_state_builder(
33+
&self,
34+
builder: SessionStateBuilder,
35+
) -> Result<SessionStateBuilder, DataFusionError> {
36+
Ok(with_user_codec(builder, ErrorExecCodec))
3537
}
3638
}
37-
let (mut ctx, _guard) =
39+
let (ctx, _guard) =
3840
start_localhost_context([50050, 50051, 50053], CustomSessionBuilder).await;
39-
add_user_codec(&mut ctx, ErrorExecCodec);
4041

4142
let mut plan: Arc<dyn ExecutionPlan> = Arc::new(ErrorExec::new("something failed"));
4243

tests/highly_distributed_query.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22
mod tests {
33
use datafusion::physical_expr::Partitioning;
44
use datafusion::physical_plan::{displayable, execute_stream};
5-
use datafusion_distributed::test_utils::localhost::{
6-
start_localhost_context, NoopSessionBuilder,
7-
};
5+
use datafusion_distributed::test_utils::localhost::start_localhost_context;
86
use datafusion_distributed::test_utils::parquet::register_parquet_tables;
9-
use datafusion_distributed::{assert_snapshot, ArrowFlightReadExec};
7+
use datafusion_distributed::{assert_snapshot, ArrowFlightReadExec, NoopSessionBuilder};
108
use futures::TryStreamExt;
119
use std::error::Error;
1210
use std::sync::Arc;

0 commit comments

Comments
 (0)