Skip to content

Commit 04728a1

Browse files
committed
Better UX for providing user defined codecs
1 parent 957bffd commit 04728a1

File tree

10 files changed

+98
-89
lines changed

10 files changed

+98
-89
lines changed

src/flight_service/do_get.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
22
use crate::errors::datafusion_error_to_tonic_status;
33
use crate::flight_service::service::ArrowFlightEndpoint;
4+
use crate::get_user_codec;
45
use crate::plan::DistributedCodec;
56
use crate::stage::{stage_from_proto, ExecutionStageProto};
67
use arrow_flight::encode::FlightDataEncoderBuilder;
@@ -48,10 +49,9 @@ impl ArrowFlightEndpoint {
4849
"FunctionRegistry not present in newly built SessionState",
4950
))?;
5051

51-
let user_codec = self.session_builder.codec();
5252
let mut combined_codec = ComposedPhysicalExtensionCodec::default();
5353
combined_codec.push(DistributedCodec);
54-
if let Some(ref user_codec) = user_codec {
54+
if let Some(ref user_codec) = get_user_codec(state.config()) {
5555
combined_codec.push_arc(Arc::clone(&user_codec));
5656
}
5757

@@ -62,9 +62,6 @@ impl ArrowFlightEndpoint {
6262
&combined_codec,
6363
)
6464
.map_err(|err| Status::invalid_argument(format!("Cannot decode stage proto: {err}")))?;
65-
if let Some(user_codec) = user_codec {
66-
stage = stage.with_user_codec(user_codec)
67-
}
6865
let inner_plan = Arc::clone(&stage.plan);
6966

7067
// Add the extensions that might be required for ExecutionPlan nodes in the plan

src/flight_service/session_builder.rs

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,22 @@
11
use datafusion::execution::SessionStateBuilder;
2-
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
3-
use std::sync::Arc;
42

53
/// Trait called by the Arrow Flight endpoint that handles distributed parts of a DataFusion
64
/// plan for building a DataFusion [datafusion::prelude::SessionContext].
75
pub trait SessionBuilder {
86
/// Takes a [SessionStateBuilder] and adds whatever is necessary for it to work, like
9-
/// custom physical optimization rules, UDFs, UDAFs, config extensions, etc...
7+
/// custom extension codecs, custom physical optimization rules, UDFs, UDAFs, config
8+
/// extensions, etc...
109
///
11-
/// Example:
10+
/// Example: adding some custom extension plan codecs
1211
///
1312
/// ```rust
14-
/// # use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
15-
/// # use datafusion_distributed::{SessionBuilder};
16-
///
17-
/// #[derive(Clone)]
18-
/// struct CustomSessionBuilder;
19-
/// impl SessionBuilder for CustomSessionBuilder {
20-
/// fn on_new_session(&self, mut builder: SessionStateBuilder) -> SessionStateBuilder {
21-
/// // add your own UDFs, optimization rules, etc...
22-
/// builder
23-
/// }
24-
/// }
25-
/// ```
26-
fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
27-
builder
28-
}
29-
30-
/// Allows users to provide their own codecs.
3113
///
32-
/// ```rust
3314
/// # use std::sync::Arc;
3415
/// # use datafusion::execution::runtime_env::RuntimeEnv;
3516
/// # use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
3617
/// # use datafusion::physical_plan::ExecutionPlan;
3718
/// # use datafusion_proto::physical_plan::PhysicalExtensionCodec;
38-
/// # use datafusion_distributed::{SessionBuilder};
19+
/// # use datafusion_distributed::{with_user_codec, SessionBuilder};
3920
///
4021
/// #[derive(Debug)]
4122
/// struct CustomExecCodec;
@@ -53,14 +34,13 @@ pub trait SessionBuilder {
5334
/// #[derive(Clone)]
5435
/// struct CustomSessionBuilder;
5536
/// impl SessionBuilder for CustomSessionBuilder {
56-
/// fn codec(&self) -> Option<Arc<dyn PhysicalExtensionCodec + 'static>> {
57-
/// Some(Arc::new(CustomExecCodec))
37+
/// fn on_new_session(&self, mut builder: SessionStateBuilder) -> SessionStateBuilder {
38+
/// // Add your UDFs, optimization rules, etc...
39+
/// with_user_codec(builder, CustomExecCodec)
5840
/// }
5941
/// }
6042
/// ```
61-
fn codec(&self) -> Option<Arc<dyn PhysicalExtensionCodec + 'static>> {
62-
None
63-
}
43+
fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder;
6444
}
6545

6646
/// Noop implementation of the [SessionBuilder]. Used by default if no [SessionBuilder] is provided

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ mod test_utils;
1010
pub mod physical_optimizer;
1111
pub mod stage;
1212
pub mod task;
13+
mod user_provided_codec;
14+
1315
pub use channel_manager::{BoxCloneSyncChannel, ChannelManager, ChannelResolver};
1416
pub use flight_service::{ArrowFlightEndpoint, SessionBuilder};
1517
pub use plan::ArrowFlightReadExec;
18+
pub use user_provided_codec::{add_user_codec, get_user_codec, with_user_codec};

src/physical_optimizer.rs

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::sync::Arc;
22

3+
use super::stage::ExecutionStage;
34
use crate::{plan::PartitionIsolatorExec, ArrowFlightReadExec};
45
use datafusion::common::tree_node::TreeNodeRecursion;
56
use datafusion::error::DataFusionError;
@@ -15,15 +16,9 @@ use datafusion::{
1516
displayable, repartition::RepartitionExec, ExecutionPlan, ExecutionPlanProperties,
1617
},
1718
};
18-
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
19-
20-
use super::stage::ExecutionStage;
2119

2220
#[derive(Debug, Default)]
2321
pub struct DistributedPhysicalOptimizerRule {
24-
/// Optional codec to assist in serializing and deserializing any custom
25-
/// ExecutionPlan nodes
26-
codec: Option<Arc<dyn PhysicalExtensionCodec>>,
2722
/// maximum number of partitions per task. This is used to determine how many
2823
/// tasks to create for each stage
2924
partitions_per_task: Option<usize>,
@@ -32,18 +27,10 @@ pub struct DistributedPhysicalOptimizerRule {
3227
impl DistributedPhysicalOptimizerRule {
3328
pub fn new() -> Self {
3429
DistributedPhysicalOptimizerRule {
35-
codec: None,
3630
partitions_per_task: None,
3731
}
3832
}
3933

40-
/// Set a codec to use to assist in serializing and deserializing
41-
/// custom ExecutionPlan nodes.
42-
pub fn with_codec(mut self, codec: impl PhysicalExtensionCodec + 'static) -> Self {
43-
self.codec = Some(Arc::new(codec));
44-
self
45-
}
46-
4734
/// Set the maximum number of partitions per task. This is used to determine how many
4835
/// tasks to create for each stage.
4936
///
@@ -156,9 +143,6 @@ impl DistributedPhysicalOptimizerRule {
156143
if let Some(partitions_per_task) = self.partitions_per_task {
157144
stage = stage.with_maximum_partitions_per_task(partitions_per_task);
158145
}
159-
if let Some(codec) = self.codec.as_ref() {
160-
stage = stage.with_user_codec(codec.clone());
161-
}
162146
stage.depth = depth;
163147

164148
Ok(stage)

src/plan/arrow_flight_read.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::channel_manager::ChannelManager;
33
use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
44
use crate::errors::tonic_status_to_datafusion_error;
55
use crate::flight_service::DoGet;
6+
use crate::get_user_codec;
67
use crate::plan::DistributedCodec;
78
use crate::stage::{proto_from_stage, ExecutionStage};
89
use arrow_flight::decode::FlightRecordBatchStream;
@@ -174,7 +175,7 @@ impl ExecutionPlan for ArrowFlightReadExec {
174175

175176
let mut combined_codec = ComposedPhysicalExtensionCodec::default();
176177
combined_codec.push(DistributedCodec {});
177-
if let Some(ref user_codec) = stage.user_codec {
178+
if let Some(ref user_codec) = get_user_codec(context.session_config()) {
178179
combined_codec.push_arc(Arc::clone(user_codec));
179180
}
180181

src/stage/proto.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ pub fn stage_from_proto(
8181
plan,
8282
inputs,
8383
tasks: msg.tasks,
84-
user_codec: None,
8584
depth: 0,
8685
})
8786
}
@@ -147,7 +146,6 @@ mod tests {
147146
plan: physical_plan,
148147
inputs: vec![],
149148
tasks: vec![],
150-
user_codec: Some(Arc::new(DefaultPhysicalExtensionCodec {})),
151149
depth: 0,
152150
};
153151

src/stage/stage.rs

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ use datafusion::error::{DataFusionError, Result};
55
use datafusion::execution::TaskContext;
66
use datafusion::physical_plan::{displayable, ExecutionPlan};
77
use datafusion::prelude::SessionContext;
8-
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
98

109
use itertools::Itertools;
1110
use rand::Rng;
@@ -32,8 +31,6 @@ pub struct ExecutionStage {
3231
/// Our tasks which tell us how finely grained to execute the partitions in
3332
/// the plan
3433
pub tasks: Vec<ExecutionTask>,
35-
/// An optional codec to assist in serializing and deserializing this stage
36-
pub user_codec: Option<Arc<dyn PhysicalExtensionCodec>>,
3734
/// tree depth of our location in the stage tree, used for display only
3835
pub depth: usize,
3936
}
@@ -65,7 +62,6 @@ impl ExecutionStage {
6562
.map(|s| s as Arc<dyn ExecutionPlan>)
6663
.collect(),
6764
tasks: vec![ExecutionTask::new(partition_group)],
68-
user_codec: None,
6965
depth: 0,
7066
}
7167
}
@@ -93,13 +89,6 @@ impl ExecutionStage {
9389
self
9490
}
9591

96-
/// Sets the codec for this stage, which is used to serialize and deserialize the plan
97-
/// and its inputs.
98-
pub fn with_user_codec(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self {
99-
self.user_codec = Some(codec);
100-
self
101-
}
102-
10392
/// Returns the name of this stage
10493
pub fn name(&self) -> String {
10594
format!("Stage {:<3}", self.num)
@@ -174,7 +163,6 @@ impl ExecutionStage {
174163
plan: self.plan.clone(),
175164
inputs: assigned_children,
176165
tasks: assigned_tasks,
177-
user_codec: self.user_codec.clone(),
178166
depth: self.depth,
179167
};
180168

@@ -205,7 +193,6 @@ impl ExecutionPlan for ExecutionStage {
205193
plan: self.plan.clone(),
206194
inputs: children,
207195
tasks: self.tasks.clone(),
208-
user_codec: self.user_codec.clone(),
209196
depth: self.depth,
210197
}))
211198
}

src/user_provided_codec.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
use datafusion::execution::SessionStateBuilder;
2+
use datafusion::prelude::{SessionConfig, SessionContext};
3+
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
4+
use std::sync::Arc;
5+
6+
pub struct UserProvidedCodec(Arc<dyn PhysicalExtensionCodec>);
7+
8+
#[allow(private_bounds)]
9+
pub fn add_user_codec(
10+
transport: &mut impl UserCodecTransport,
11+
codec: impl PhysicalExtensionCodec + 'static,
12+
) {
13+
transport.set(codec);
14+
}
15+
16+
#[allow(private_bounds)]
17+
pub fn with_user_codec<T: UserCodecTransport>(
18+
mut transport: T,
19+
codec: impl PhysicalExtensionCodec + 'static,
20+
) -> T {
21+
transport.set(codec);
22+
transport
23+
}
24+
25+
#[allow(private_bounds)]
26+
pub fn get_user_codec(
27+
transport: &impl UserCodecTransport,
28+
) -> Option<Arc<dyn PhysicalExtensionCodec>> {
29+
transport.get()
30+
}
31+
32+
trait UserCodecTransport {
33+
fn set(&mut self, codec: impl PhysicalExtensionCodec + 'static);
34+
fn get(&self) -> Option<Arc<dyn PhysicalExtensionCodec>>;
35+
}
36+
37+
impl UserCodecTransport for SessionConfig {
38+
fn set(&mut self, codec: impl PhysicalExtensionCodec + 'static) {
39+
self.set_extension(Arc::new(UserProvidedCodec(Arc::new(codec))));
40+
}
41+
42+
fn get(&self) -> Option<Arc<dyn PhysicalExtensionCodec>> {
43+
Some(Arc::clone(&self.get_extension::<UserProvidedCodec>()?.0))
44+
}
45+
}
46+
47+
impl UserCodecTransport for SessionContext {
48+
fn set(&mut self, codec: impl PhysicalExtensionCodec + 'static) {
49+
self.state_ref().write().config_mut().set(codec)
50+
}
51+
52+
fn get(&self) -> Option<Arc<dyn PhysicalExtensionCodec>> {
53+
self.state_ref().read().config().get()
54+
}
55+
}
56+
57+
impl UserCodecTransport for SessionStateBuilder {
58+
fn set(&mut self, codec: impl PhysicalExtensionCodec + 'static) {
59+
self.config().get_or_insert_default().set(codec);
60+
}
61+
62+
fn get(&self) -> Option<Arc<dyn PhysicalExtensionCodec>> {
63+
// Nobody will ever want to retrieve a user codec from a SessionStateBuilder
64+
None
65+
}
66+
}

tests/custom_extension_codec.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ mod tests {
2828
displayable, execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
2929
};
3030
use datafusion_distributed::physical_optimizer::DistributedPhysicalOptimizerRule;
31-
use datafusion_distributed::{ArrowFlightReadExec, SessionBuilder};
31+
use datafusion_distributed::{
32+
add_user_codec, with_user_codec, ArrowFlightReadExec, SessionBuilder,
33+
};
3234
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
3335
use datafusion_proto::protobuf::proto_error;
3436
use futures::{stream, TryStreamExt};
@@ -40,17 +42,17 @@ mod tests {
4042
#[tokio::test]
4143
#[ignore]
4244
async fn custom_extension_codec() -> Result<(), Box<dyn std::error::Error>> {
43-
// 1. The codec should be added to the extension builder.
4445
#[derive(Clone)]
4546
struct CustomSessionBuilder;
4647
impl SessionBuilder for CustomSessionBuilder {
47-
fn codec(&self) -> Option<Arc<dyn PhysicalExtensionCodec>> {
48-
Some(Arc::new(Int64ListExecCodec))
48+
fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
49+
with_user_codec(builder, Int64ListExecCodec)
4950
}
5051
}
5152

52-
let (ctx, _guard) =
53+
let (mut ctx, _guard) =
5354
start_localhost_context([50050, 50051, 50052], CustomSessionBuilder).await;
55+
add_user_codec(&mut ctx, Int64ListExecCodec);
5456

5557
let single_node_plan = build_plan(false)?;
5658
assert_snapshot!(displayable(single_node_plan.as_ref()).indent(true).to_string(), @r"
@@ -60,10 +62,8 @@ mod tests {
6062
");
6163

6264
let distributed_plan = build_plan(true)?;
63-
let distributed_plan = DistributedPhysicalOptimizerRule::default()
64-
// 1. The codec should be added to the DistributedPhysicalOptimizerRule.
65-
.with_codec(Int64ListExecCodec)
66-
.distribute_plan(distributed_plan)?;
65+
let distributed_plan =
66+
DistributedPhysicalOptimizerRule::default().distribute_plan(distributed_plan)?;
6767

6868
assert_snapshot!(displayable(&distributed_plan).indent(true).to_string(), @r"
6969
┌───── Stage 3 Task: partitions: 0,unassigned]

tests/error_propagation.rs

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ mod tests {
1616
execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
1717
};
1818
use datafusion_distributed::physical_optimizer::DistributedPhysicalOptimizerRule;
19-
use datafusion_distributed::{ArrowFlightReadExec, SessionBuilder};
19+
use datafusion_distributed::{
20+
add_user_codec, with_user_codec, ArrowFlightReadExec, SessionBuilder,
21+
};
2022
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
2123
use datafusion_proto::protobuf::proto_error;
2224
use futures::{stream, TryStreamExt};
@@ -27,26 +29,17 @@ mod tests {
2729
use std::sync::Arc;
2830

2931
#[tokio::test]
30-
#[ignore]
3132
async fn test_error_propagation() -> Result<(), Box<dyn Error>> {
3233
#[derive(Clone)]
3334
struct CustomSessionBuilder;
3435
impl SessionBuilder for CustomSessionBuilder {
35-
fn on_new_session(&self, mut builder: SessionStateBuilder) -> SessionStateBuilder {
36-
let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(ErrorExecCodec);
37-
let config = builder.config().get_or_insert_default();
38-
config.set_extension(Arc::new(codec));
39-
builder
36+
fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
37+
with_user_codec(builder, ErrorExecCodec)
4038
}
4139
}
42-
let (ctx, _guard) =
40+
let (mut ctx, _guard) =
4341
start_localhost_context([50050, 50051, 50053], CustomSessionBuilder).await;
44-
45-
let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(ErrorExecCodec);
46-
ctx.state_ref()
47-
.write()
48-
.config_mut()
49-
.set_extension(Arc::new(codec));
42+
add_user_codec(&mut ctx, ErrorExecCodec);
5043

5144
let mut plan: Arc<dyn ExecutionPlan> = Arc::new(ErrorExec::new("something failed"));
5245

0 commit comments

Comments
 (0)