Skip to content

Commit 7a881d3

Browse files
committed
Make DistributedExec handle udfs
1 parent d7d86d9 commit 7a881d3

File tree

4 files changed

+212
-37
lines changed

4 files changed

+212
-37
lines changed

src/execution_plans/distributed.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ impl ExecutionPlan for DistributedExec {
138138
}
139139

140140
let channel_resolver = get_distributed_channel_resolver(context.session_config())?;
141-
let codec = DistributedCodec::new_combined_with_user(context.session_config());
141+
let codec = DistributedCodec::new_combined_with_user(Arc::clone(&context));
142142

143143
let prepared = self.prepare_plan(&channel_resolver.get_urls()?, &codec)?;
144144
{

src/flight_service/do_get.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ impl ArrowFlightEndpoint {
8181
.await
8282
.map_err(|err| datafusion_error_to_tonic_status(&err))?;
8383

84-
let codec = DistributedCodec::new_combined_with_user(session_state.config());
85-
let ctx = SessionContext::new_with_state(session_state.clone());
84+
let codec = DistributedCodec::new_combined_with_user(session_state.task_ctx());
8685

8786
// There's only 1 `StageExec` responsible for all requests that share the same `stage_key`,
8887
// so here we either retrieve the existing one or create a new one if it does not exist.
@@ -94,6 +93,7 @@ impl ArrowFlightEndpoint {
9493
let stage_data = once
9594
.get_or_try_init(|| async {
9695
let proto_node = PhysicalPlanNode::try_decode(doget.plan_proto.as_ref())?;
96+
let ctx = SessionContext::from(session_state.clone());
9797
let plan = proto_node.try_into_physical_plan(&ctx, &self.runtime, &codec)?;
9898

9999
// Initialize partition count to the number of partitions in the stage

src/protobuf/distributed_codec.rs

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,40 @@ use crate::{NetworkShuffleExec, PartitionIsolatorExec};
66
use bytes::Bytes;
77
use datafusion::arrow::datatypes::Schema;
88
use datafusion::arrow::datatypes::SchemaRef;
9-
use datafusion::common::{internal_datafusion_err, not_impl_err};
9+
use datafusion::common::internal_datafusion_err;
1010
use datafusion::error::DataFusionError;
11-
use datafusion::execution::FunctionRegistry;
11+
use datafusion::execution::{FunctionRegistry, TaskContext};
1212
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
1313
use datafusion::physical_expr::EquivalenceProperties;
1414
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
1515
use datafusion::physical_plan::{ExecutionPlan, Partitioning, PlanProperties};
16-
use datafusion::prelude::{SessionConfig, SessionContext};
16+
use datafusion::prelude::SessionContext;
1717
use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning;
1818
use datafusion_proto::physical_plan::to_proto::serialize_partitioning;
1919
use datafusion_proto::physical_plan::{ComposedPhysicalExtensionCodec, PhysicalExtensionCodec};
2020
use datafusion_proto::protobuf;
2121
use datafusion_proto::protobuf::proto_error;
2222
use prost::Message;
23+
use std::fmt::{Debug, Formatter};
2324
use std::sync::Arc;
2425
use url::Url;
2526

2627
/// DataFusion [PhysicalExtensionCodec] implementation that allows serializing and
2728
/// deserializing the custom ExecutionPlans in this project
28-
#[derive(Debug)]
29-
pub struct DistributedCodec;
29+
#[derive(Clone, Default)]
30+
pub struct DistributedCodec(Arc<TaskContext>);
31+
32+
impl Debug for DistributedCodec {
33+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34+
write!(f, "DistributedCodec")
35+
}
36+
}
3037

3138
impl DistributedCodec {
32-
pub fn new_combined_with_user(cfg: &SessionConfig) -> impl PhysicalExtensionCodec + use<> {
33-
let mut codecs: Vec<Arc<dyn PhysicalExtensionCodec>> = vec![Arc::new(DistributedCodec {})];
34-
codecs.extend(get_distributed_user_codecs(cfg));
39+
pub fn new_combined_with_user(ctx: Arc<TaskContext>) -> impl PhysicalExtensionCodec {
40+
let mut codecs: Vec<Arc<dyn PhysicalExtensionCodec>> =
41+
vec![Arc::new(DistributedCodec(Arc::clone(&ctx)))];
42+
codecs.extend(get_distributed_user_codecs(ctx.session_config()));
3543
ComposedPhysicalExtensionCodec::new(codecs)
3644
}
3745
}
@@ -104,13 +112,9 @@ impl PhysicalExtensionCodec for DistributedCodec {
104112
.map(|s| s.try_into())
105113
.ok_or(proto_error("NetworkShuffleExec is missing schema"))??;
106114

107-
let partitioning = parse_protobuf_partitioning(
108-
partitioning.as_ref(),
109-
&ctx,
110-
&schema,
111-
&DistributedCodec {},
112-
)?
113-
.ok_or(proto_error("NetworkShuffleExec is missing partitioning"))?;
115+
let partitioning =
116+
parse_protobuf_partitioning(partitioning.as_ref(), &ctx, &schema, self)?
117+
.ok_or(proto_error("NetworkShuffleExec is missing partitioning"))?;
114118

115119
Ok(Arc::new(new_network_hash_shuffle_exec(
116120
partitioning,
@@ -128,13 +132,9 @@ impl PhysicalExtensionCodec for DistributedCodec {
128132
.map(|s| s.try_into())
129133
.ok_or(proto_error("NetworkCoalesceExec is missing schema"))??;
130134

131-
let partitioning = parse_protobuf_partitioning(
132-
partitioning.as_ref(),
133-
&ctx,
134-
&schema,
135-
&DistributedCodec {},
136-
)?
137-
.ok_or(proto_error("NetworkCoalesceExec is missing partitioning"))?;
135+
let partitioning =
136+
parse_protobuf_partitioning(partitioning.as_ref(), &ctx, &schema, self)?
137+
.ok_or(proto_error("NetworkCoalesceExec is missing partitioning"))?;
138138

139139
Ok(Arc::new(new_network_coalesce_tasks_exec(
140140
partitioning,
@@ -185,7 +185,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
185185
schema: Some(node.schema().try_into()?),
186186
partitioning: Some(serialize_partitioning(
187187
node.properties().output_partitioning(),
188-
&DistributedCodec {},
188+
self,
189189
)?),
190190
input_stage: Some(encode_stage_proto(node.input_stage())?),
191191
};
@@ -200,7 +200,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
200200
schema: Some(node.schema().try_into()?),
201201
partitioning: Some(serialize_partitioning(
202202
node.properties().output_partitioning(),
203-
&DistributedCodec {},
203+
self,
204204
)?),
205205
input_stage: Some(encode_stage_proto(node.input_stage())?),
206206
};
@@ -230,16 +230,28 @@ impl PhysicalExtensionCodec for DistributedCodec {
230230
}
231231
}
232232

233-
fn try_encode_udf(&self, _: &ScalarUDF, _: &mut Vec<u8>) -> datafusion::common::Result<()> {
234-
not_impl_err!("DistributedCodec does not encode UDFs")
233+
fn try_decode_udf(
234+
&self,
235+
name: &str,
236+
_buf: &[u8],
237+
) -> datafusion::common::Result<Arc<ScalarUDF>> {
238+
self.0.udf(name)
235239
}
236240

237-
fn try_encode_udaf(&self, _: &AggregateUDF, _: &mut Vec<u8>) -> datafusion::common::Result<()> {
238-
not_impl_err!("DistributedCodec does not encode UDAFs")
241+
fn try_decode_udaf(
242+
&self,
243+
name: &str,
244+
_buf: &[u8],
245+
) -> datafusion::common::Result<Arc<AggregateUDF>> {
246+
self.0.udaf(name)
239247
}
240248

241-
fn try_encode_udwf(&self, _: &WindowUDF, _: &mut Vec<u8>) -> datafusion::common::Result<()> {
242-
not_impl_err!("DistributedCodec does not encode UDWFs")
249+
fn try_decode_udwf(
250+
&self,
251+
name: &str,
252+
_buf: &[u8],
253+
) -> datafusion::common::Result<Arc<WindowUDF>> {
254+
self.0.udwf(name)
243255
}
244256
}
245257

@@ -424,7 +436,7 @@ mod tests {
424436

425437
#[test]
426438
fn test_roundtrip_single_flight() -> datafusion::common::Result<()> {
427-
let codec = DistributedCodec;
439+
let codec = DistributedCodec::default();
428440
let registry = MemoryFunctionRegistry::new();
429441

430442
let schema = schema_i32("a");
@@ -443,7 +455,7 @@ mod tests {
443455

444456
#[test]
445457
fn test_roundtrip_isolator_flight() -> datafusion::common::Result<()> {
446-
let codec = DistributedCodec;
458+
let codec = DistributedCodec::default();
447459
let registry = MemoryFunctionRegistry::new();
448460

449461
let schema = schema_i32("b");
@@ -467,7 +479,7 @@ mod tests {
467479

468480
#[test]
469481
fn test_roundtrip_isolator_union() -> datafusion::common::Result<()> {
470-
let codec = DistributedCodec;
482+
let codec = DistributedCodec::default();
471483
let registry = MemoryFunctionRegistry::new();
472484

473485
let schema = schema_i32("c");
@@ -497,7 +509,7 @@ mod tests {
497509

498510
#[test]
499511
fn test_roundtrip_isolator_sort_flight() -> datafusion::common::Result<()> {
500-
let codec = DistributedCodec;
512+
let codec = DistributedCodec::default();
501513
let registry = MemoryFunctionRegistry::new();
502514

503515
let schema = schema_i32("d");

tests/udfs.rs

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
#[cfg(all(feature = "integration", test))]
2+
mod tests {
3+
use arrow::util::pretty::pretty_format_batches;
4+
use datafusion::arrow::datatypes::DataType;
5+
use datafusion::common::not_impl_err;
6+
use datafusion::error::DataFusionError;
7+
use datafusion::execution::{FunctionRegistry, SessionState, SessionStateBuilder};
8+
use datafusion::logical_expr::{
9+
AggregateUDF, ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
10+
Volatility, WindowUDF,
11+
};
12+
use datafusion::logical_expr_common::dyn_eq::DynEq;
13+
use datafusion::physical_optimizer::PhysicalOptimizerRule;
14+
use datafusion::physical_plan::{ExecutionPlan, execute_stream};
15+
use datafusion_distributed::test_utils::localhost::start_localhost_context;
16+
use datafusion_distributed::test_utils::parquet::register_parquet_tables;
17+
use datafusion_distributed::{
18+
DistributedExt, DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext,
19+
assert_snapshot,
20+
};
21+
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
22+
use futures::TryStreamExt;
23+
use std::any::Any;
24+
use std::error::Error;
25+
use std::sync::Arc;
26+
27+
#[tokio::test]
28+
async fn test_udfs() -> Result<(), Box<dyn Error>> {
29+
async fn build_state(
30+
ctx: DistributedSessionBuilderContext,
31+
) -> Result<SessionState, DataFusionError> {
32+
Ok(SessionStateBuilder::new()
33+
.with_runtime_env(ctx.runtime_env)
34+
.with_default_features()
35+
.with_distributed_user_codec(UdfCodec)
36+
.build())
37+
}
38+
39+
let (ctx, _guard) = start_localhost_context(3, build_state).await;
40+
register_parquet_tables(&ctx).await?;
41+
42+
let df = ctx
43+
.sql(r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#)
44+
.await?;
45+
let physical = df.create_physical_plan().await?;
46+
47+
let physical_distributed = DistributedPhysicalOptimizerRule::default()
48+
.with_network_shuffle_tasks(2)
49+
.with_network_coalesce_tasks(2)
50+
.optimize(physical.clone(), &Default::default())?;
51+
52+
let batches = pretty_format_batches(
53+
&execute_stream(physical_distributed, ctx.task_ctx())?
54+
.try_collect::<Vec<_>>()
55+
.await?,
56+
)?;
57+
58+
assert_snapshot!(batches, @r"
59+
+----------+-----------+
60+
| count(*) | RainToday |
61+
+----------+-----------+
62+
| 66 | Yes |
63+
| 300 | No |
64+
+----------+-----------+
65+
");
66+
Ok(())
67+
}
68+
69+
#[derive(Debug, PartialEq, Eq, Hash)]
70+
pub struct Udf(Signature);
71+
72+
impl Udf {
73+
fn new() -> Self {
74+
Self(Signature::any(1, Volatility::Immutable))
75+
}
76+
}
77+
78+
impl ScalarUDFImpl for Udf {
79+
fn as_any(&self) -> &dyn Any {
80+
self
81+
}
82+
83+
fn name(&self) -> &str {
84+
"test_udf"
85+
}
86+
87+
fn signature(&self) -> &Signature {
88+
&self.0
89+
}
90+
91+
fn return_type(&self, arg_types: &[DataType]) -> datafusion::common::Result<DataType> {
92+
Ok(arg_types[0].clone())
93+
}
94+
95+
fn invoke_with_args(
96+
&self,
97+
mut args: ScalarFunctionArgs,
98+
) -> datafusion::common::Result<ColumnarValue> {
99+
Ok(args.args.remove(0))
100+
}
101+
}
102+
103+
#[derive(Debug)]
104+
struct UdfCodec;
105+
106+
impl PhysicalExtensionCodec for UdfCodec {
107+
fn try_decode(
108+
&self,
109+
_: &[u8],
110+
_: &[Arc<dyn ExecutionPlan>],
111+
_registry: &dyn FunctionRegistry,
112+
) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
113+
not_impl_err!("not implemented")
114+
}
115+
116+
fn try_encode(
117+
&self,
118+
_: Arc<dyn ExecutionPlan>,
119+
_: &mut Vec<u8>,
120+
) -> datafusion::common::Result<()> {
121+
not_impl_err!("not implemented")
122+
}
123+
124+
fn try_encode_udaf(
125+
&self,
126+
_: &AggregateUDF,
127+
_: &mut Vec<u8>,
128+
) -> datafusion::common::Result<()> {
129+
not_impl_err!("not implemented")
130+
}
131+
132+
fn try_encode_udwf(
133+
&self,
134+
_: &WindowUDF,
135+
_: &mut Vec<u8>,
136+
) -> datafusion::common::Result<()> {
137+
not_impl_err!("not implemented")
138+
}
139+
140+
fn try_encode_udf(
141+
&self,
142+
node: &ScalarUDF,
143+
_: &mut Vec<u8>,
144+
) -> datafusion::common::Result<()> {
145+
if node.dyn_eq(node) {
146+
return Ok(());
147+
};
148+
not_impl_err!("not implemented")
149+
}
150+
151+
fn try_decode_udf(
152+
&self,
153+
name: &str,
154+
_: &[u8],
155+
) -> datafusion::common::Result<Arc<ScalarUDF>> {
156+
let udf = Udf::new();
157+
if name == udf.name() {
158+
return Ok(Arc::new(ScalarUDF::new_from_impl(udf)));
159+
}
160+
not_impl_err!("not implemented")
161+
}
162+
}
163+
}

0 commit comments

Comments
 (0)