Skip to content

Commit 957bffd

Browse files
committed
Allow passing custom codecs
1 parent 4da5c5d commit 957bffd

File tree

9 files changed

+122
-129
lines changed

9 files changed

+122
-129
lines changed

src/composed_extension_codec.rs

Lines changed: 4 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use datafusion::error::DataFusionError;
33
use datafusion::execution::FunctionRegistry;
44
use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
55
use datafusion::physical_plan::ExecutionPlan;
6-
use datafusion::prelude::SessionConfig;
76
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
87
use std::fmt::Debug;
98
use std::sync::Arc;
@@ -25,42 +24,10 @@ impl ComposedPhysicalExtensionCodec {
2524
self.codecs.push(Arc::new(codec));
2625
}
2726

28-
/// Adds a new [PhysicalExtensionCodec] from DataFusion's [SessionConfig] extensions.
29-
///
30-
/// If users have a custom [PhysicalExtensionCodec] for their own nodes, they should
31-
/// populate the config extensions with a [PhysicalExtensionCodec] so that we can use
32-
/// it while encoding/decoding plans to/from protobuf.
33-
///
34-
/// Example:
35-
/// ```rust
36-
/// # use std::sync::Arc;
37-
/// # use datafusion::execution::FunctionRegistry;
38-
/// # use datafusion::physical_plan::ExecutionPlan;
39-
/// # use datafusion::prelude::SessionConfig;
40-
/// # use datafusion_proto::physical_plan::PhysicalExtensionCodec;
41-
///
42-
/// #[derive(Debug)]
43-
/// struct CustomUserCodec {}
44-
///
45-
/// impl PhysicalExtensionCodec for CustomUserCodec {
46-
/// fn try_decode(&self, buf: &[u8], inputs: &[Arc<dyn ExecutionPlan>], registry: &dyn FunctionRegistry) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
47-
/// todo!()
48-
/// }
49-
///
50-
/// fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> datafusion::common::Result<()> {
51-
/// todo!()
52-
/// }
53-
/// }
54-
///
55-
/// let mut config = SessionConfig::new();
56-
///
57-
/// let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(CustomUserCodec {});
58-
/// config.set_extension(Arc::new(codec));
59-
/// ```
60-
pub(crate) fn push_from_config(&mut self, config: &SessionConfig) {
61-
if let Some(user_codec) = config.get_extension::<Arc<dyn PhysicalExtensionCodec>>() {
62-
self.codecs.push(user_codec.as_ref().clone());
63-
}
27+
/// Adds a new [PhysicalExtensionCodec] to the list. These codecs will be tried
28+
/// sequentially until one works.
29+
pub(crate) fn push_arc(&mut self, codec: Arc<dyn PhysicalExtensionCodec>) {
30+
self.codecs.push(codec);
6431
}
6532

6633
fn try_any<T>(

src/flight_service/do_get.rs

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
12
use crate::errors::datafusion_error_to_tonic_status;
23
use crate::flight_service::service::ArrowFlightEndpoint;
34
use crate::plan::DistributedCodec;
@@ -8,7 +9,6 @@ use arrow_flight::flight_service_server::FlightService;
89
use arrow_flight::Ticket;
910
use datafusion::execution::SessionStateBuilder;
1011
use datafusion::optimizer::OptimizerConfig;
11-
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
1212
use futures::TryStreamExt;
1313
use prost::Message;
1414
use std::sync::Arc;
@@ -48,25 +48,36 @@ impl ArrowFlightEndpoint {
4848
"FunctionRegistry not present in newly built SessionState",
4949
))?;
5050

51-
let codec = DistributedCodec {};
52-
let codec = Arc::new(codec) as Arc<dyn PhysicalExtensionCodec>;
51+
let user_codec = self.session_builder.codec();
52+
let mut combined_codec = ComposedPhysicalExtensionCodec::default();
53+
combined_codec.push(DistributedCodec);
54+
if let Some(ref user_codec) = user_codec {
55+
combined_codec.push_arc(Arc::clone(&user_codec));
56+
}
5357

54-
let stage = stage_from_proto(stage_msg, function_registry, &self.runtime.as_ref(), codec)
55-
.map(Arc::new)
56-
.map_err(|err| Status::invalid_argument(format!("Cannot decode stage proto: {err}")))?;
58+
let mut stage = stage_from_proto(
59+
stage_msg,
60+
function_registry,
61+
&self.runtime.as_ref(),
62+
&combined_codec,
63+
)
64+
.map_err(|err| Status::invalid_argument(format!("Cannot decode stage proto: {err}")))?;
65+
if let Some(user_codec) = user_codec {
66+
stage = stage.with_user_codec(user_codec)
67+
}
68+
let inner_plan = Arc::clone(&stage.plan);
5769

5870
// Add the extensions that might be required for ExecutionPlan nodes in the plan
5971
let config = state.config_mut();
6072
config.set_extension(Arc::clone(&self.channel_manager));
61-
config.set_extension(stage.clone());
73+
config.set_extension(Arc::new(stage));
6274

63-
let stream = stage
64-
.plan
75+
let stream = inner_plan
6576
.execute(doget.partition as usize, state.task_ctx())
6677
.map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
6778

6879
let flight_data_stream = FlightDataEncoderBuilder::new()
69-
.with_schema(stage.plan.schema().clone())
80+
.with_schema(inner_plan.schema().clone())
7081
.build(stream.map_err(|err| {
7182
FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
7283
}));

src/flight_service/session_builder.rs

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,35 @@
11
use datafusion::execution::SessionStateBuilder;
2+
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
3+
use std::sync::Arc;
24

35
/// Trait called by the Arrow Flight endpoint that handles distributed parts of a DataFusion
46
/// plan for building a DataFusion's [datafusion::prelude::SessionContext].
57
pub trait SessionBuilder {
68
/// Takes a [SessionStateBuilder] and adds whatever is necessary for it to work, like
7-
/// custom extension codecs, custom physical optimization rules, UDFs, UDAFs, config
8-
/// extensions, etc...
9+
/// custom physical optimization rules, UDFs, UDAFs, config extensions, etc...
910
///
10-
/// Example: adding some custom extension plan codecs
11+
/// Example:
1112
///
1213
/// ```rust
14+
/// # use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
15+
/// # use datafusion_distributed::{SessionBuilder};
1316
///
17+
/// #[derive(Clone)]
18+
/// struct CustomSessionBuilder;
19+
/// impl SessionBuilder for CustomSessionBuilder {
20+
/// fn on_new_session(&self, mut builder: SessionStateBuilder) -> SessionStateBuilder {
21+
/// // add your own UDFs, optimization rules, etc...
22+
/// builder
23+
/// }
24+
/// }
25+
/// ```
26+
fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
27+
builder
28+
}
29+
30+
/// Allows users to provide their own codecs.
31+
///
32+
/// ```rust
1433
/// # use std::sync::Arc;
1534
/// # use datafusion::execution::runtime_env::RuntimeEnv;
1635
/// # use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
@@ -34,15 +53,14 @@ pub trait SessionBuilder {
3453
/// #[derive(Clone)]
3554
/// struct CustomSessionBuilder;
3655
/// impl SessionBuilder for CustomSessionBuilder {
37-
/// fn on_new_session(&self, mut builder: SessionStateBuilder) -> SessionStateBuilder {
38-
/// let config = builder.config().get_or_insert_default();
39-
/// let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(CustomExecCodec);
40-
/// config.set_extension(Arc::new(codec));
41-
/// builder
56+
/// fn codec(&self) -> Option<Arc<dyn PhysicalExtensionCodec + 'static>> {
57+
/// Some(Arc::new(CustomExecCodec))
4258
/// }
4359
/// }
4460
/// ```
45-
fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder;
61+
fn codec(&self) -> Option<Arc<dyn PhysicalExtensionCodec + 'static>> {
62+
None
63+
}
4664
}
4765

4866
/// Noop implementation of the [SessionBuilder]. Used by default if no [SessionBuilder] is provided

src/physical_optimizer.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ impl DistributedPhysicalOptimizerRule {
3939

4040
/// Set a codec to use to assist in serializing and deserializing
4141
/// custom ExecutionPlan nodes.
42-
pub fn with_codec(mut self, codec: Arc<dyn PhysicalExtensionCodec>) -> Self {
43-
self.codec = Some(codec);
42+
pub fn with_codec(mut self, codec: impl PhysicalExtensionCodec + 'static) -> Self {
43+
self.codec = Some(Arc::new(codec));
4444
self
4545
}
4646

@@ -157,7 +157,7 @@ impl DistributedPhysicalOptimizerRule {
157157
stage = stage.with_maximum_partitions_per_task(partitions_per_task);
158158
}
159159
if let Some(codec) = self.codec.as_ref() {
160-
stage = stage.with_codec(codec.clone());
160+
stage = stage.with_user_codec(codec.clone());
161161
}
162162
stage.depth = depth;
163163

src/plan/arrow_flight_read.rs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use super::combined::CombinedRecordBatchStream;
22
use crate::channel_manager::ChannelManager;
3+
use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
34
use crate::errors::tonic_status_to_datafusion_error;
45
use crate::flight_service::DoGet;
5-
use crate::stage::{ExecutionStage, ExecutionStageProto};
6+
use crate::plan::DistributedCodec;
7+
use crate::stage::{proto_from_stage, ExecutionStage};
68
use arrow_flight::decode::FlightRecordBatchStream;
79
use arrow_flight::error::FlightError;
810
use arrow_flight::flight_service_client::FlightServiceClient;
@@ -170,12 +172,14 @@ impl ExecutionPlan for ArrowFlightReadExec {
170172
this.stage_num
171173
))?;
172174

173-
let child_stage_tasks = child_stage.tasks.clone();
174-
let child_stage_proto = ExecutionStageProto::try_from(child_stage).map_err(|e| {
175-
internal_datafusion_err!(
176-
"ArrowFlightReadExec: failed to convert stage to proto: {}",
177-
e
178-
)
175+
let mut combined_codec = ComposedPhysicalExtensionCodec::default();
176+
combined_codec.push(DistributedCodec {});
177+
if let Some(ref user_codec) = stage.user_codec {
178+
combined_codec.push_arc(Arc::clone(user_codec));
179+
}
180+
181+
let child_stage_proto = proto_from_stage(child_stage, &combined_codec).map_err(|e| {
182+
internal_datafusion_err!("ArrowFlightReadExec: failed to convert stage to proto: {e}")
179183
})?;
180184

181185
let ticket_bytes = DoGet {
@@ -191,6 +195,7 @@ impl ExecutionPlan for ArrowFlightReadExec {
191195

192196
let schema = child_stage.plan.schema();
193197

198+
let child_stage_tasks = child_stage.tasks.clone();
194199
let stream = async move {
195200
let futs = child_stage_tasks.iter().enumerate().map(|(i, task)| async {
196201
let url = task.url()?.ok_or(internal_datafusion_err!(

src/stage/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ mod proto;
33
mod stage;
44

55
pub use display::display_stage_graphviz;
6-
pub use proto::{stage_from_proto, ExecutionStageProto};
6+
pub use proto::{proto_from_stage, stage_from_proto, ExecutionStageProto};
77
pub use stage::ExecutionStage;

src/stage/proto.rs

Lines changed: 25 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use datafusion_proto::{
1111
protobuf::PhysicalPlanNode,
1212
};
1313

14-
use crate::{plan::DistributedCodec, task::ExecutionTask};
14+
use crate::task::ExecutionTask;
1515

1616
use super::ExecutionStage;
1717

@@ -35,54 +35,42 @@ pub struct ExecutionStageProto {
3535
pub tasks: Vec<ExecutionTask>,
3636
}
3737

38-
impl TryFrom<&ExecutionStage> for ExecutionStageProto {
39-
type Error = DataFusionError;
40-
41-
fn try_from(stage: &ExecutionStage) -> Result<Self, Self::Error> {
42-
let codec = stage.codec.clone().unwrap_or(Arc::new(DistributedCodec {}));
43-
44-
let proto_plan =
45-
PhysicalPlanNode::try_from_physical_plan(stage.plan.clone(), codec.as_ref())?;
46-
let inputs = stage
47-
.child_stages_iter()
48-
.map(|s| Box::new(ExecutionStageProto::try_from(s).unwrap()))
49-
.collect();
50-
51-
Ok(ExecutionStageProto {
52-
num: stage.num as u64,
53-
name: stage.name(),
54-
plan: Some(Box::new(proto_plan)),
55-
inputs,
56-
tasks: stage.tasks.clone(),
57-
})
58-
}
59-
}
60-
61-
impl TryFrom<ExecutionStage> for ExecutionStageProto {
62-
type Error = DataFusionError;
38+
pub fn proto_from_stage(
39+
stage: &ExecutionStage,
40+
codec: &dyn PhysicalExtensionCodec,
41+
) -> Result<ExecutionStageProto, DataFusionError> {
42+
let proto_plan = PhysicalPlanNode::try_from_physical_plan(stage.plan.clone(), codec)?;
43+
let inputs = stage
44+
.child_stages_iter()
45+
.map(|s| Ok(Box::new(proto_from_stage(s, codec)?)))
46+
.collect::<Result<Vec<_>>>()?;
6347

64-
fn try_from(stage: ExecutionStage) -> Result<Self, Self::Error> {
65-
ExecutionStageProto::try_from(&stage)
66-
}
48+
Ok(ExecutionStageProto {
49+
num: stage.num as u64,
50+
name: stage.name(),
51+
plan: Some(Box::new(proto_plan)),
52+
inputs,
53+
tasks: stage.tasks.clone(),
54+
})
6755
}
6856

6957
pub fn stage_from_proto(
7058
msg: ExecutionStageProto,
7159
registry: &dyn FunctionRegistry,
7260
runtime: &RuntimeEnv,
73-
codec: Arc<dyn PhysicalExtensionCodec>,
61+
codec: &dyn PhysicalExtensionCodec,
7462
) -> Result<ExecutionStage> {
7563
let plan_node = msg.plan.ok_or(internal_datafusion_err!(
7664
"ExecutionStageMsg is missing the plan"
7765
))?;
7866

79-
let plan = plan_node.try_into_physical_plan(registry, runtime, codec.as_ref())?;
67+
let plan = plan_node.try_into_physical_plan(registry, runtime, codec)?;
8068

8169
let inputs = msg
8270
.inputs
8371
.into_iter()
8472
.map(|s| {
85-
stage_from_proto(*s, registry, runtime, codec.clone())
73+
stage_from_proto(*s, registry, runtime, codec)
8674
.map(|s| Arc::new(s) as Arc<dyn ExecutionPlan>)
8775
})
8876
.collect::<Result<Vec<_>>>()?;
@@ -93,7 +81,7 @@ pub fn stage_from_proto(
9381
plan,
9482
inputs,
9583
tasks: msg.tasks,
96-
codec: Some(codec),
84+
user_codec: None,
9785
depth: 0,
9886
})
9987
}
@@ -116,6 +104,7 @@ mod tests {
116104
use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
117105
use prost::Message;
118106

107+
use crate::stage::proto::proto_from_stage;
119108
use crate::stage::{proto::stage_from_proto, ExecutionStage, ExecutionStageProto};
120109

121110
// create a simple mem table
@@ -158,12 +147,12 @@ mod tests {
158147
plan: physical_plan,
159148
inputs: vec![],
160149
tasks: vec![],
161-
codec: Some(Arc::new(DefaultPhysicalExtensionCodec {})),
150+
user_codec: Some(Arc::new(DefaultPhysicalExtensionCodec {})),
162151
depth: 0,
163152
};
164153

165154
// Convert to proto message
166-
let stage_msg = ExecutionStageProto::try_from(&stage)?;
155+
let stage_msg = proto_from_stage(&stage, &DefaultPhysicalExtensionCodec {})?;
167156

168157
// Serialize to bytes
169158
let mut buf = Vec::new();
@@ -180,7 +169,7 @@ mod tests {
180169
decoded_msg,
181170
&ctx,
182171
ctx.runtime_env().as_ref(),
183-
Arc::new(DefaultPhysicalExtensionCodec {}),
172+
&DefaultPhysicalExtensionCodec {},
184173
)?;
185174

186175
// Compare original and round-tripped stages

0 commit comments

Comments
 (0)