
Commit 84026d6

Support user provided codecs (#81)
1 parent 8971edd commit 84026d6

File tree

12 files changed (+222, -160 lines)


src/composed_extension_codec.rs

Lines changed: 4 additions & 37 deletions

@@ -3,7 +3,6 @@ use datafusion::error::DataFusionError;
 use datafusion::execution::FunctionRegistry;
 use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
 use datafusion::physical_plan::ExecutionPlan;
-use datafusion::prelude::SessionConfig;
 use datafusion_proto::physical_plan::PhysicalExtensionCodec;
 use std::fmt::Debug;
 use std::sync::Arc;
@@ -25,42 +24,10 @@ impl ComposedPhysicalExtensionCodec {
         self.codecs.push(Arc::new(codec));
     }
 
-    /// Adds a new [PhysicalExtensionCodec] from DataFusion's [SessionConfig] extensions.
-    ///
-    /// If users have a custom [PhysicalExtensionCodec] for their own nodes, they should
-    /// populate the config extensions with a [PhysicalExtensionCodec] so that we can use
-    /// it while encoding/decoding plans to/from protobuf.
-    ///
-    /// Example:
-    /// ```rust
-    /// # use std::sync::Arc;
-    /// # use datafusion::execution::FunctionRegistry;
-    /// # use datafusion::physical_plan::ExecutionPlan;
-    /// # use datafusion::prelude::SessionConfig;
-    /// # use datafusion_proto::physical_plan::PhysicalExtensionCodec;
-    ///
-    /// #[derive(Debug)]
-    /// struct CustomUserCodec {}
-    ///
-    /// impl PhysicalExtensionCodec for CustomUserCodec {
-    ///     fn try_decode(&self, buf: &[u8], inputs: &[Arc<dyn ExecutionPlan>], registry: &dyn FunctionRegistry) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
-    ///         todo!()
-    ///     }
-    ///
-    ///     fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> datafusion::common::Result<()> {
-    ///         todo!()
-    ///     }
-    /// }
-    ///
-    /// let mut config = SessionConfig::new();
-    ///
-    /// let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(CustomUserCodec {});
-    /// config.set_extension(Arc::new(codec));
-    /// ```
-    pub(crate) fn push_from_config(&mut self, config: &SessionConfig) {
-        if let Some(user_codec) = config.get_extension::<Arc<dyn PhysicalExtensionCodec>>() {
-            self.codecs.push(user_codec.as_ref().clone());
-        }
+    /// Adds a new [PhysicalExtensionCodec] to the list. These codecs will be tried
+    /// sequentially until one works.
+    pub(crate) fn push_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>) {
+        self.codecs.push(codec);
     }
 
     fn try_any<T>(
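
The `try_any` helper that follows in this file is what makes the composition work: each registered codec is tried in order until one succeeds. A minimal sketch of that pattern (not the crate's exact implementation; error handling is simplified):

```rust
use std::sync::Arc;
use datafusion::common::Result;
use datafusion::error::DataFusionError;
use datafusion::execution::FunctionRegistry;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;

#[derive(Debug, Default)]
struct ComposedCodecSketch {
    codecs: Vec<Arc<dyn PhysicalExtensionCodec>>,
}

impl PhysicalExtensionCodec for ComposedCodecSketch {
    fn try_decode(
        &self,
        buf: &[u8],
        inputs: &[Arc<dyn ExecutionPlan>],
        registry: &dyn FunctionRegistry,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Try every codec in registration order; the first one that succeeds wins.
        for codec in &self.codecs {
            if let Ok(plan) = codec.try_decode(buf, inputs, registry) {
                return Ok(plan);
            }
        }
        Err(DataFusionError::Internal("no codec could decode this node".into()))
    }

    fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> Result<()> {
        for codec in &self.codecs {
            if codec.try_encode(Arc::clone(&node), buf).is_ok() {
                return Ok(());
            }
        }
        Err(DataFusionError::Internal("no codec could encode this node".into()))
    }
}
```

The existing `push` wraps its argument in an `Arc`, while the new `push_arc` accepts an already-shared codec, which is the form the user-provided codec arrives in when it is read back out of the session config.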

src/flight_service/do_get.rs

Lines changed: 18 additions & 10 deletions

@@ -1,14 +1,15 @@
+use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
 use crate::errors::datafusion_error_to_tonic_status;
 use crate::flight_service::service::ArrowFlightEndpoint;
 use crate::plan::DistributedCodec;
 use crate::stage::{stage_from_proto, ExecutionStageProto};
+use crate::user_provided_codec::get_user_codec;
 use arrow_flight::encode::FlightDataEncoderBuilder;
 use arrow_flight::error::FlightError;
 use arrow_flight::flight_service_server::FlightService;
 use arrow_flight::Ticket;
 use datafusion::execution::SessionStateBuilder;
 use datafusion::optimizer::OptimizerConfig;
-use datafusion_proto::physical_plan::PhysicalExtensionCodec;
 use futures::TryStreamExt;
 use prost::Message;
 use std::sync::Arc;
@@ -48,25 +49,32 @@ impl ArrowFlightEndpoint {
             "FunctionRegistry not present in newly built SessionState",
         ))?;
 
-        let codec = DistributedCodec {};
-        let codec = Arc::new(codec) as Arc<dyn PhysicalExtensionCodec>;
+        let mut combined_codec = ComposedPhysicalExtensionCodec::default();
+        combined_codec.push(DistributedCodec);
+        if let Some(ref user_codec) = get_user_codec(state.config()) {
+            combined_codec.push_arc(Arc::clone(&user_codec));
+        }
 
-        let stage = stage_from_proto(stage_msg, function_registry, &self.runtime.as_ref(), codec)
-            .map(Arc::new)
-            .map_err(|err| Status::invalid_argument(format!("Cannot decode stage proto: {err}")))?;
+        let mut stage = stage_from_proto(
+            stage_msg,
+            function_registry,
+            &self.runtime.as_ref(),
+            &combined_codec,
+        )
+        .map_err(|err| Status::invalid_argument(format!("Cannot decode stage proto: {err}")))?;
+        let inner_plan = Arc::clone(&stage.plan);
 
         // Add the extensions that might be required for ExecutionPlan nodes in the plan
         let config = state.config_mut();
         config.set_extension(Arc::clone(&self.channel_manager));
-        config.set_extension(stage.clone());
+        config.set_extension(Arc::new(stage));
 
-        let stream = stage
-            .plan
+        let stream = inner_plan
            .execute(doget.partition as usize, state.task_ctx())
            .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
 
         let flight_data_stream = FlightDataEncoderBuilder::new()
-            .with_schema(stage.plan.schema().clone())
+            .with_schema(inner_plan.schema().clone())
            .build(stream.map_err(|err| {
                FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
            }));
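
Note that the decoded stage is now stored as a type-keyed `SessionConfig` extension (`config.set_extension(Arc::new(stage))`) so that `ExecutionPlan` nodes inside the plan can find it at execution time. Retrieval would look roughly like this (a sketch using DataFusion's extension API; the actual lookup sites are not part of this diff):

```rust
use std::sync::Arc;
use datafusion::execution::TaskContext;
use datafusion_distributed::stage::ExecutionStage;

// Sketch: an ExecutionPlan node fetching the stage registered by the endpoint above.
fn stage_from_ctx(ctx: &TaskContext) -> Option<Arc<ExecutionStage>> {
    // SessionConfig extensions are keyed by type; get_extension returns the same
    // Arc that was stored with set_extension, if present.
    ctx.session_config().get_extension::<ExecutionStage>()
}
```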

src/flight_service/session_builder.rs

Lines changed: 3 additions & 5 deletions

@@ -16,7 +16,7 @@ pub trait SessionBuilder {
 /// # use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
 /// # use datafusion::physical_plan::ExecutionPlan;
 /// # use datafusion_proto::physical_plan::PhysicalExtensionCodec;
-/// # use datafusion_distributed::{SessionBuilder};
+/// # use datafusion_distributed::{with_user_codec, SessionBuilder};
 ///
 /// #[derive(Debug)]
 /// struct CustomExecCodec;
@@ -35,10 +35,8 @@ pub trait SessionBuilder {
 /// struct CustomSessionBuilder;
 /// impl SessionBuilder for CustomSessionBuilder {
 ///     fn on_new_session(&self, mut builder: SessionStateBuilder) -> SessionStateBuilder {
-///         let config = builder.config().get_or_insert_default();
-///         let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(CustomExecCodec);
-///         config.set_extension(Arc::new(codec));
-///         builder
+///         // Add your UDFs, optimization rules, etc...
+///         with_user_codec(builder, CustomExecCodec)
 ///     }
 /// }
 /// ```
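
Putting the new API together: a user with custom `ExecutionPlan` nodes implements `PhysicalExtensionCodec` for them and hands the codec to `with_user_codec` from their `SessionBuilder`. A sketch along the lines of the doc example above (codec bodies left as stubs):

```rust
use std::sync::Arc;
use datafusion::common::Result;
use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
use datafusion::physical_plan::ExecutionPlan;
use datafusion_distributed::{with_user_codec, SessionBuilder};
use datafusion_proto::physical_plan::PhysicalExtensionCodec;

#[derive(Debug)]
struct CustomExecCodec;

impl PhysicalExtensionCodec for CustomExecCodec {
    fn try_decode(
        &self,
        buf: &[u8],
        inputs: &[Arc<dyn ExecutionPlan>],
        registry: &dyn FunctionRegistry,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Decode your custom node from `buf` here.
        todo!()
    }

    fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> Result<()> {
        // Encode your custom node into `buf` here.
        todo!()
    }
}

#[derive(Debug)]
struct CustomSessionBuilder;

impl SessionBuilder for CustomSessionBuilder {
    fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
        // Register UDFs, optimizer rules, etc. here, then attach the codec.
        with_user_codec(builder, CustomExecCodec)
    }
}
```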

src/lib.rs

Lines changed: 3 additions & 0 deletions

@@ -10,6 +10,9 @@ mod test_utils;
 pub mod physical_optimizer;
 pub mod stage;
 pub mod task;
+mod user_provided_codec;
+
 pub use channel_manager::{BoxCloneSyncChannel, ChannelManager, ChannelResolver};
 pub use flight_service::{ArrowFlightEndpoint, SessionBuilder};
 pub use plan::ArrowFlightReadExec;
+pub use user_provided_codec::{add_user_codec, with_user_codec};
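
The body of the new `user_provided_codec` module is not included in this commit excerpt. Judging from how it is used (`with_user_codec` takes a `SessionStateBuilder` plus a codec, and the crate-internal `get_user_codec` reads a codec back out of a `SessionConfig`), one plausible shape is a type-keyed config extension. The following is an assumption-laden sketch, not the actual source:

```rust
// Hypothetical sketch of src/user_provided_codec.rs; the real module is not
// included in the diffs above, so names and details here are assumptions.
use std::sync::Arc;
use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::SessionConfig;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;

// Newtype so the codec can be stored as a type-keyed SessionConfig extension.
struct UserProvidedCodec(Arc<dyn PhysicalExtensionCodec>);

pub fn with_user_codec(
    mut builder: SessionStateBuilder,
    codec: impl PhysicalExtensionCodec + 'static,
) -> SessionStateBuilder {
    let config = builder.config().get_or_insert_default();
    config.set_extension(Arc::new(UserProvidedCodec(Arc::new(codec))));
    builder
}

pub(crate) fn get_user_codec(config: &SessionConfig) -> Option<Arc<dyn PhysicalExtensionCodec>> {
    config
        .get_extension::<UserProvidedCodec>()
        .map(|ext| Arc::clone(&ext.0))
}
```

`add_user_codec` is also re-exported, but its signature is not visible in this diff, so it is omitted from the sketch.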

src/physical_optimizer.rs

Lines changed: 1 addition & 17 deletions

@@ -1,5 +1,6 @@
 use std::sync::Arc;
 
+use super::stage::ExecutionStage;
 use crate::{plan::PartitionIsolatorExec, ArrowFlightReadExec};
 use datafusion::common::tree_node::TreeNodeRecursion;
 use datafusion::error::DataFusionError;
@@ -15,15 +16,9 @@ use datafusion::{
         displayable, repartition::RepartitionExec, ExecutionPlan, ExecutionPlanProperties,
     },
 };
-use datafusion_proto::physical_plan::PhysicalExtensionCodec;
-
-use super::stage::ExecutionStage;
 
 #[derive(Debug, Default)]
 pub struct DistributedPhysicalOptimizerRule {
-    /// Optional codec to assist in serializing and deserializing any custom
-    /// ExecutionPlan nodes
-    codec: Option<Arc<dyn PhysicalExtensionCodec>>,
     /// maximum number of partitions per task. This is used to determine how many
     /// tasks to create for each stage
     partitions_per_task: Option<usize>,
@@ -32,18 +27,10 @@ pub struct DistributedPhysicalOptimizerRule {
 impl DistributedPhysicalOptimizerRule {
     pub fn new() -> Self {
         DistributedPhysicalOptimizerRule {
-            codec: None,
             partitions_per_task: None,
         }
     }
 
-    /// Set a codec to use to assist in serializing and deserializing
-    /// custom ExecutionPlan nodes.
-    pub fn with_codec(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self {
-        self.codec = Some(codec);
-        self
-    }
-
     /// Set the maximum number of partitions per task. This is used to determine how many
     /// tasks to create for each stage.
     ///
@@ -156,9 +143,6 @@ impl DistributedPhysicalOptimizerRule {
         if let Some(partitions_per_task) = self.partitions_per_task {
            stage = stage.with_maximum_partitions_per_task(partitions_per_task);
         }
-        if let Some(codec) = self.codec.as_ref() {
-            stage = stage.with_codec(codec.clone());
-        }
         stage.depth = depth;
 
         Ok(stage)

src/plan/arrow_flight_read.rs

Lines changed: 13 additions & 7 deletions

@@ -1,8 +1,11 @@
 use super::combined::CombinedRecordBatchStream;
 use crate::channel_manager::ChannelManager;
+use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
 use crate::errors::tonic_status_to_datafusion_error;
 use crate::flight_service::DoGet;
-use crate::stage::{ExecutionStage, ExecutionStageProto};
+use crate::plan::DistributedCodec;
+use crate::stage::{proto_from_stage, ExecutionStage};
+use crate::user_provided_codec::get_user_codec;
 use arrow_flight::decode::FlightRecordBatchStream;
 use arrow_flight::error::FlightError;
 use arrow_flight::flight_service_client::FlightServiceClient;
@@ -170,12 +173,14 @@ impl ExecutionPlan for ArrowFlightReadExec {
             this.stage_num
         ))?;
 
-        let child_stage_tasks = child_stage.tasks.clone();
-        let child_stage_proto = ExecutionStageProto::try_from(child_stage).map_err(|e| {
-            internal_datafusion_err!(
-                "ArrowFlightReadExec: failed to convert stage to proto: {}",
-                e
-            )
+        let mut combined_codec = ComposedPhysicalExtensionCodec::default();
+        combined_codec.push(DistributedCodec {});
+        if let Some(ref user_codec) = get_user_codec(context.session_config()) {
+            combined_codec.push_arc(Arc::clone(user_codec));
+        }
+
+        let child_stage_proto = proto_from_stage(child_stage, &combined_codec).map_err(|e| {
+            internal_datafusion_err!("ArrowFlightReadExec: failed to convert stage to proto: {e}")
         })?;
 
         let ticket_bytes = DoGet {
@@ -191,6 +196,7 @@ impl ExecutionPlan for ArrowFlightReadExec {
 
         let schema = child_stage.plan.schema();
 
+        let child_stage_tasks = child_stage.tasks.clone();
         let stream = async move {
             let futs = child_stage_tasks.iter().enumerate().map(|(i, task)| async {
                 let url = task.url()?.ok_or(internal_datafusion_err!(
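
Both sides of the wire now assemble the same composed codec: `ArrowFlightReadExec` when encoding the child stage here, and the Flight endpoint in `do_get.rs` when decoding it. The order matters: the crate's `DistributedCodec` first, then the user codec. A small sketch of how that shared assembly could be factored out (the helper is hypothetical; in this commit the two call sites each repeat the few lines shown in the diffs):

```rust
use datafusion::prelude::SessionConfig;
// Crate-internal paths, shown for illustration only.
use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
use crate::plan::DistributedCodec;
use crate::user_provided_codec::get_user_codec;

// Hypothetical helper shared by the encode (ArrowFlightReadExec) and decode
// (ArrowFlightEndpoint) paths, so the codec order never drifts between them.
fn combined_codec(config: &SessionConfig) -> ComposedPhysicalExtensionCodec {
    let mut codec = ComposedPhysicalExtensionCodec::default();
    // The crate's own codec for its distributed nodes goes first...
    codec.push(DistributedCodec);
    // ...followed by the user's codec, if one was registered on the session.
    if let Some(user_codec) = get_user_codec(config) {
        codec.push_arc(user_codec);
    }
    codec
}
```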

src/stage/mod.rs

Lines changed: 1 addition & 1 deletion

@@ -3,5 +3,5 @@ mod proto;
 mod stage;
 
 pub use display::display_stage_graphviz;
-pub use proto::{stage_from_proto, ExecutionStageProto};
+pub use proto::{proto_from_stage, stage_from_proto, ExecutionStageProto};
 pub use stage::ExecutionStage;

src/stage/proto.rs

Lines changed: 23 additions & 36 deletions

@@ -11,7 +11,7 @@ use datafusion_proto::{
     protobuf::PhysicalPlanNode,
 };
 
-use crate::{plan::DistributedCodec, task::ExecutionTask};
+use crate::task::ExecutionTask;
 
 use super::ExecutionStage;
 
@@ -35,54 +35,42 @@ pub struct ExecutionStageProto {
     pub tasks: Vec<ExecutionTask>,
 }
 
-impl TryFrom<&ExecutionStage> for ExecutionStageProto {
-    type Error = DataFusionError;
-
-    fn try_from(stage: &ExecutionStage) -> Result<Self, Self::Error> {
-        let codec = stage.codec.clone().unwrap_or(Arc::new(DistributedCodec {}));
-
-        let proto_plan =
-            PhysicalPlanNode::try_from_physical_plan(stage.plan.clone(), codec.as_ref())?;
-        let inputs = stage
-            .child_stages_iter()
-            .map(|s| Box::new(ExecutionStageProto::try_from(s).unwrap()))
-            .collect();
-
-        Ok(ExecutionStageProto {
-            num: stage.num as u64,
-            name: stage.name(),
-            plan: Some(Box::new(proto_plan)),
-            inputs,
-            tasks: stage.tasks.clone(),
-        })
-    }
-}
-
-impl TryFrom<ExecutionStage> for ExecutionStageProto {
-    type Error = DataFusionError;
+pub fn proto_from_stage(
+    stage: &ExecutionStage,
+    codec: &dyn PhysicalExtensionCodec,
+) -> Result<ExecutionStageProto, DataFusionError> {
+    let proto_plan = PhysicalPlanNode::try_from_physical_plan(stage.plan.clone(), codec)?;
+    let inputs = stage
+        .child_stages_iter()
+        .map(|s| Ok(Box::new(proto_from_stage(s, codec)?)))
+        .collect::<Result<Vec<_>>>()?;
 
-    fn try_from(stage: ExecutionStage) -> Result<Self, Self::Error> {
-        ExecutionStageProto::try_from(&stage)
-    }
+    Ok(ExecutionStageProto {
+        num: stage.num as u64,
+        name: stage.name(),
+        plan: Some(Box::new(proto_plan)),
+        inputs,
+        tasks: stage.tasks.clone(),
+    })
 }
 
 pub fn stage_from_proto(
     msg: ExecutionStageProto,
     registry: &dyn FunctionRegistry,
     runtime: &RuntimeEnv,
-    codec: Arc<dyn PhysicalExtensionCodec>,
+    codec: &dyn PhysicalExtensionCodec,
 ) -> Result<ExecutionStage> {
     let plan_node = msg.plan.ok_or(internal_datafusion_err!(
         "ExecutionStageMsg is missing the plan"
     ))?;
 
-    let plan = plan_node.try_into_physical_plan(registry, runtime, codec.as_ref())?;
+    let plan = plan_node.try_into_physical_plan(registry, runtime, codec)?;
 
     let inputs = msg
         .inputs
         .into_iter()
         .map(|s| {
-            stage_from_proto(*s, registry, runtime, codec.clone())
+            stage_from_proto(*s, registry, runtime, codec)
                 .map(|s| Arc::new(s) as Arc<dyn ExecutionPlan>)
         })
         .collect::<Result<Vec<_>>>()?;
@@ -93,7 +81,6 @@ pub fn stage_from_proto(
         plan,
         inputs,
         tasks: msg.tasks,
-        codec: Some(codec),
         depth: 0,
     })
 }
@@ -116,6 +103,7 @@ mod tests {
     use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
     use prost::Message;
 
+    use crate::stage::proto::proto_from_stage;
     use crate::stage::{proto::stage_from_proto, ExecutionStage, ExecutionStageProto};
 
     // create a simple mem table
@@ -158,12 +146,11 @@ mod tests {
             plan: physical_plan,
             inputs: vec![],
             tasks: vec![],
-            codec: Some(Arc::new(DefaultPhysicalExtensionCodec {})),
             depth: 0,
         };
 
         // Convert to proto message
-        let stage_msg = ExecutionStageProto::try_from(&stage)?;
+        let stage_msg = proto_from_stage(&stage, &DefaultPhysicalExtensionCodec {})?;
 
         // Serialize to bytes
         let mut buf = Vec::new();
@@ -180,7 +167,7 @@ mod tests {
             decoded_msg,
             &ctx,
             ctx.runtime_env().as_ref(),
-            Arc::new(DefaultPhysicalExtensionCodec {}),
+            &DefaultPhysicalExtensionCodec {},
        )?;
 
         // Compare original and round-tripped stages
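
For reference, the round trip the test performs can be written as a standalone function against the new free functions. A sketch (it uses `DefaultPhysicalExtensionCodec`, so it only covers plans with no custom nodes):

```rust
use datafusion::common::Result;
use datafusion::error::DataFusionError;
use datafusion::prelude::SessionContext;
use datafusion_distributed::stage::{
    proto_from_stage, stage_from_proto, ExecutionStage, ExecutionStageProto,
};
use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
use prost::Message;

fn roundtrip(stage: &ExecutionStage, ctx: &SessionContext) -> Result<ExecutionStage> {
    let codec = DefaultPhysicalExtensionCodec {};

    // ExecutionStage -> protobuf message -> bytes
    let msg = proto_from_stage(stage, &codec)?;
    let mut buf = Vec::new();
    msg.encode(&mut buf)
        .map_err(|e| DataFusionError::Internal(format!("encode failed: {e}")))?;

    // bytes -> protobuf message -> ExecutionStage
    let decoded = ExecutionStageProto::decode(buf.as_slice())
        .map_err(|e| DataFusionError::Internal(format!("decode failed: {e}")))?;
    stage_from_proto(decoded, ctx, ctx.runtime_env().as_ref(), &codec)
}
```

A distributed deployment would pass the composed codec (DistributedCodec plus any user codec) instead of the default one, as the other diffs in this commit do.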
