Skip to content

Commit 50131c8

Browse files
committed
Fix issue in distributed_codec
1 parent 7a881d3 commit 50131c8

File tree

4 files changed

+94
-142
lines changed

4 files changed

+94
-142
lines changed

src/execution_plans/distributed.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ impl ExecutionPlan for DistributedExec {
138138
}
139139

140140
let channel_resolver = get_distributed_channel_resolver(context.session_config())?;
141-
let codec = DistributedCodec::new_combined_with_user(Arc::clone(&context));
141+
let codec = DistributedCodec::new_combined_with_user(context.session_config());
142142

143143
let prepared = self.prepare_plan(&channel_resolver.get_urls()?, &codec)?;
144144
{

src/flight_service/do_get.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ impl ArrowFlightEndpoint {
8181
.await
8282
.map_err(|err| datafusion_error_to_tonic_status(&err))?;
8383

84-
let codec = DistributedCodec::new_combined_with_user(session_state.task_ctx());
84+
let codec = DistributedCodec::new_combined_with_user(session_state.config());
85+
let ctx = SessionContext::new_with_state(session_state.clone());
8586

8687
// There's only 1 `StageExec` responsible for all requests that share the same `stage_key`,
8788
// so here we either retrieve the existing one or create a new one if it does not exist.
@@ -93,7 +94,6 @@ impl ArrowFlightEndpoint {
9394
let stage_data = once
9495
.get_or_try_init(|| async {
9596
let proto_node = PhysicalPlanNode::try_decode(doget.plan_proto.as_ref())?;
96-
let ctx = SessionContext::from(session_state.clone());
9797
let plan = proto_node.try_into_physical_plan(&ctx, &self.runtime, &codec)?;
9898

9999
// Initialize partition count to the number of partitions in the stage

src/protobuf/distributed_codec.rs

Lines changed: 38 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,38 +8,29 @@ use datafusion::arrow::datatypes::Schema;
88
use datafusion::arrow::datatypes::SchemaRef;
99
use datafusion::common::internal_datafusion_err;
1010
use datafusion::error::DataFusionError;
11-
use datafusion::execution::{FunctionRegistry, TaskContext};
12-
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
11+
use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
1312
use datafusion::physical_expr::EquivalenceProperties;
1413
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
1514
use datafusion::physical_plan::{ExecutionPlan, Partitioning, PlanProperties};
16-
use datafusion::prelude::SessionContext;
15+
use datafusion::prelude::{SessionConfig, SessionContext};
1716
use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning;
1817
use datafusion_proto::physical_plan::to_proto::serialize_partitioning;
1918
use datafusion_proto::physical_plan::{ComposedPhysicalExtensionCodec, PhysicalExtensionCodec};
2019
use datafusion_proto::protobuf;
2120
use datafusion_proto::protobuf::proto_error;
2221
use prost::Message;
23-
use std::fmt::{Debug, Formatter};
2422
use std::sync::Arc;
2523
use url::Url;
2624

2725
/// DataFusion [PhysicalExtensionCodec] implementation that allows serializing and
2826
/// deserializing the custom ExecutionPlans in this project
29-
#[derive(Clone, Default)]
30-
pub struct DistributedCodec(Arc<TaskContext>);
31-
32-
impl Debug for DistributedCodec {
33-
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34-
write!(f, "DistributedCodec")
35-
}
36-
}
27+
#[derive(Debug)]
28+
pub struct DistributedCodec;
3729

3830
impl DistributedCodec {
39-
pub fn new_combined_with_user(ctx: Arc<TaskContext>) -> impl PhysicalExtensionCodec {
40-
let mut codecs: Vec<Arc<dyn PhysicalExtensionCodec>> =
41-
vec![Arc::new(DistributedCodec(Arc::clone(&ctx)))];
42-
codecs.extend(get_distributed_user_codecs(ctx.session_config()));
31+
pub fn new_combined_with_user(cfg: &SessionConfig) -> impl PhysicalExtensionCodec + use<> {
32+
let mut codecs: Vec<Arc<dyn PhysicalExtensionCodec>> = vec![Arc::new(DistributedCodec {})];
33+
codecs.extend(get_distributed_user_codecs(cfg));
4334
ComposedPhysicalExtensionCodec::new(codecs)
4435
}
4536
}
@@ -49,7 +40,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
4940
&self,
5041
buf: &[u8],
5142
inputs: &[Arc<dyn ExecutionPlan>],
52-
_registry: &dyn FunctionRegistry,
43+
registry: &dyn FunctionRegistry,
5344
) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
5445
let DistributedExecProto {
5546
node: Some(distributed_exec_node),
@@ -63,7 +54,16 @@ impl PhysicalExtensionCodec for DistributedCodec {
6354
// TODO: The PhysicalExtensionCodec trait doesn't provide access to session state,
6455
// so we create a new SessionContext which loses any custom UDFs, UDAFs, and other
6556
// user configurations. This is a limitation of the current trait design.
66-
let ctx = SessionContext::new();
57+
let state = SessionStateBuilder::new()
58+
.with_scalar_functions(
59+
registry
60+
.udfs()
61+
.iter()
62+
.map(|f| registry.udf(f))
63+
.collect::<Result<Vec<_>, _>>()?,
64+
)
65+
.build();
66+
let ctx = SessionContext::from(state);
6767

6868
fn parse_stage_proto(
6969
proto: Option<StageProto>,
@@ -112,9 +112,13 @@ impl PhysicalExtensionCodec for DistributedCodec {
112112
.map(|s| s.try_into())
113113
.ok_or(proto_error("NetworkShuffleExec is missing schema"))??;
114114

115-
let partitioning =
116-
parse_protobuf_partitioning(partitioning.as_ref(), &ctx, &schema, self)?
117-
.ok_or(proto_error("NetworkShuffleExec is missing partitioning"))?;
115+
let partitioning = parse_protobuf_partitioning(
116+
partitioning.as_ref(),
117+
&ctx,
118+
&schema,
119+
&DistributedCodec {},
120+
)?
121+
.ok_or(proto_error("NetworkShuffleExec is missing partitioning"))?;
118122

119123
Ok(Arc::new(new_network_hash_shuffle_exec(
120124
partitioning,
@@ -132,9 +136,13 @@ impl PhysicalExtensionCodec for DistributedCodec {
132136
.map(|s| s.try_into())
133137
.ok_or(proto_error("NetworkCoalesceExec is missing schema"))??;
134138

135-
let partitioning =
136-
parse_protobuf_partitioning(partitioning.as_ref(), &ctx, &schema, self)?
137-
.ok_or(proto_error("NetworkCoalesceExec is missing partitioning"))?;
139+
let partitioning = parse_protobuf_partitioning(
140+
partitioning.as_ref(),
141+
&ctx,
142+
&schema,
143+
&DistributedCodec {},
144+
)?
145+
.ok_or(proto_error("NetworkCoalesceExec is missing partitioning"))?;
138146

139147
Ok(Arc::new(new_network_coalesce_tasks_exec(
140148
partitioning,
@@ -185,7 +193,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
185193
schema: Some(node.schema().try_into()?),
186194
partitioning: Some(serialize_partitioning(
187195
node.properties().output_partitioning(),
188-
self,
196+
&DistributedCodec {},
189197
)?),
190198
input_stage: Some(encode_stage_proto(node.input_stage())?),
191199
};
@@ -200,7 +208,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
200208
schema: Some(node.schema().try_into()?),
201209
partitioning: Some(serialize_partitioning(
202210
node.properties().output_partitioning(),
203-
self,
211+
&DistributedCodec {},
204212
)?),
205213
input_stage: Some(encode_stage_proto(node.input_stage())?),
206214
};
@@ -229,30 +237,6 @@ impl PhysicalExtensionCodec for DistributedCodec {
229237
Err(proto_error(format!("Unexpected plan {}", node.name())))
230238
}
231239
}
232-
233-
fn try_decode_udf(
234-
&self,
235-
name: &str,
236-
_buf: &[u8],
237-
) -> datafusion::common::Result<Arc<ScalarUDF>> {
238-
self.0.udf(name)
239-
}
240-
241-
fn try_decode_udaf(
242-
&self,
243-
name: &str,
244-
_buf: &[u8],
245-
) -> datafusion::common::Result<Arc<AggregateUDF>> {
246-
self.0.udaf(name)
247-
}
248-
249-
fn try_decode_udwf(
250-
&self,
251-
name: &str,
252-
_buf: &[u8],
253-
) -> datafusion::common::Result<Arc<WindowUDF>> {
254-
self.0.udwf(name)
255-
}
256240
}
257241

258242
/// A key that uniquely identifies a stage in a query
@@ -436,7 +420,7 @@ mod tests {
436420

437421
#[test]
438422
fn test_roundtrip_single_flight() -> datafusion::common::Result<()> {
439-
let codec = DistributedCodec::default();
423+
let codec = DistributedCodec;
440424
let registry = MemoryFunctionRegistry::new();
441425

442426
let schema = schema_i32("a");
@@ -455,7 +439,7 @@ mod tests {
455439

456440
#[test]
457441
fn test_roundtrip_isolator_flight() -> datafusion::common::Result<()> {
458-
let codec = DistributedCodec::default();
442+
let codec = DistributedCodec;
459443
let registry = MemoryFunctionRegistry::new();
460444

461445
let schema = schema_i32("b");
@@ -479,7 +463,7 @@ mod tests {
479463

480464
#[test]
481465
fn test_roundtrip_isolator_union() -> datafusion::common::Result<()> {
482-
let codec = DistributedCodec::default();
466+
let codec = DistributedCodec;
483467
let registry = MemoryFunctionRegistry::new();
484468

485469
let schema = schema_i32("c");
@@ -509,7 +493,7 @@ mod tests {
509493

510494
#[test]
511495
fn test_roundtrip_isolator_sort_flight() -> datafusion::common::Result<()> {
512-
let codec = DistributedCodec::default();
496+
let codec = DistributedCodec;
513497
let registry = MemoryFunctionRegistry::new();
514498

515499
let schema = schema_i32("d");

0 commit comments

Comments
 (0)