
Commit a94a214

Fix issue in distributed_codec (#200)
1 parent f177452 commit a94a214

2 files changed: +143 -3 lines changed

src/protobuf/distributed_codec.rs

Lines changed: 12 additions & 3 deletions
```diff
@@ -8,7 +8,7 @@ use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::common::internal_datafusion_err;
 use datafusion::error::DataFusionError;
-use datafusion::execution::FunctionRegistry;
+use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
 use datafusion::physical_expr::EquivalenceProperties;
 use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
 use datafusion::physical_plan::{ExecutionPlan, Partitioning, PlanProperties};
@@ -40,7 +40,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
         &self,
         buf: &[u8],
         inputs: &[Arc<dyn ExecutionPlan>],
-        _registry: &dyn FunctionRegistry,
+        registry: &dyn FunctionRegistry,
     ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
         let DistributedExecProto {
             node: Some(distributed_exec_node),
@@ -54,7 +54,16 @@ impl PhysicalExtensionCodec for DistributedCodec {
         // TODO: The PhysicalExtensionCodec trait doesn't provide access to session state,
         // so we create a new SessionContext which loses any custom UDFs, UDAFs, and other
         // user configurations. This is a limitation of the current trait design.
-        let ctx = SessionContext::new();
+        let state = SessionStateBuilder::new()
+            .with_scalar_functions(
+                registry
+                    .udfs()
+                    .iter()
+                    .map(|f| registry.udf(f))
+                    .collect::<Result<Vec<_>, _>>()?,
+            )
+            .build();
+        let ctx = SessionContext::from(state);
 
         fn parse_stage_proto(
             proto: Option<StageProto>,
```
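
For context, the pattern the fix introduces can be sketched in isolation: read every scalar UDF out of the `FunctionRegistry` passed into `try_decode`, and seed a fresh `SessionState` with them so the decoded plan can resolve those functions by name. The sketch below is illustrative only and uses just the `FunctionRegistry`/`SessionStateBuilder` APIs that appear in the diff; the helper name `context_with_udfs` is made up and is not part of the crate.

```rust
use datafusion::error::Result;
use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
use datafusion::prelude::SessionContext;

/// Hypothetical helper: build a SessionContext that carries over the scalar
/// UDFs registered in `registry`, so plans that reference them by name can
/// still be decoded.
fn context_with_udfs(registry: &dyn FunctionRegistry) -> Result<SessionContext> {
    // `udfs()` lists the names of all registered scalar UDFs; `udf(name)`
    // resolves each name to its `Arc<ScalarUDF>`.
    let udfs = registry
        .udfs()
        .iter()
        .map(|name| registry.udf(name))
        .collect::<Result<Vec<_>, _>>()?;
    let state = SessionStateBuilder::new()
        .with_scalar_functions(udfs)
        .build();
    Ok(SessionContext::from(state))
}
```

Note that only scalar functions are copied; aggregate and window UDFs would still be lost, which is consistent with the TODO comment left in place by the commit.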

tests/udfs.rs

Lines changed: 131 additions & 0 deletions
```diff
@@ -0,0 +1,131 @@
+#[cfg(all(feature = "integration", test))]
+mod tests {
+    use arrow::datatypes::{Field, Schema};
+    use arrow::util::pretty::pretty_format_batches;
+    use datafusion::arrow::datatypes::DataType;
+    use datafusion::error::DataFusionError;
+    use datafusion::execution::{SessionState, SessionStateBuilder};
+    use datafusion::logical_expr::{
+        ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
+    };
+    use datafusion::physical_expr::expressions::lit;
+    use datafusion::physical_expr::{Partitioning, ScalarFunctionExpr};
+    use datafusion::physical_optimizer::PhysicalOptimizerRule;
+    use datafusion::physical_plan::empty::EmptyExec;
+    use datafusion::physical_plan::repartition::RepartitionExec;
+    use datafusion::physical_plan::{ExecutionPlan, execute_stream};
+    use datafusion_distributed::test_utils::localhost::start_localhost_context;
+    use datafusion_distributed::{
+        DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, assert_snapshot,
+        display_plan_ascii,
+    };
+    use futures::TryStreamExt;
+    use std::any::Any;
+    use std::error::Error;
+    use std::sync::Arc;
+
+    #[tokio::test]
+    async fn test_udf_in_partitioning_field() -> Result<(), Box<dyn Error>> {
+        async fn build_state(
+            ctx: DistributedSessionBuilderContext,
+        ) -> Result<SessionState, DataFusionError> {
+            Ok(SessionStateBuilder::new()
+                .with_runtime_env(ctx.runtime_env)
+                .with_default_features()
+                .with_scalar_functions(vec![udf()])
+                .build())
+        }
+
+        let (ctx, _guard) = start_localhost_context(3, build_state).await;
+
+        let wrap = |input: Arc<dyn ExecutionPlan>| -> Arc<dyn ExecutionPlan> {
+            Arc::new(
+                RepartitionExec::try_new(
+                    input,
+                    Partitioning::Hash(
+                        vec![Arc::new(ScalarFunctionExpr::new(
+                            "test_udf",
+                            udf(),
+                            vec![lit(1)],
+                            Arc::new(Field::new("return", DataType::Int32, false)),
+                            Default::default(),
+                        ))],
+                        1,
+                    ),
+                )
+                .unwrap(),
+            )
+        };
+
+        let node = wrap(wrap(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))));
+
+        let physical_distributed = DistributedPhysicalOptimizerRule::default()
+            .with_network_shuffle_tasks(2)
+            .with_network_coalesce_tasks(2)
+            .optimize(node, &Default::default())?;
+
+        let physical_distributed_str = display_plan_ascii(physical_distributed.as_ref(), false);
+
+        assert_snapshot!(physical_distributed_str,
+            @r"
+        ┌───── DistributedExec ── Tasks: t0:[p0]
+        │ [Stage 2] => NetworkShuffleExec: output_partitions=1, input_tasks=2
+        └──────────────────────────────────────────────────
+        ┌───── Stage 1 ── Tasks: t0:[p0..p1] t1:[p0..p1]
+        │ RepartitionExec: partitioning=Hash([test_udf(1)], 2), input_partitions=1
+        │ EmptyExec
+        └──────────────────────────────────────────────────
+        ",
+        );
+
+        let batches = pretty_format_batches(
+            &execute_stream(physical_distributed, ctx.task_ctx())?
+                .try_collect::<Vec<_>>()
+                .await?,
+        )?;
+
+        assert_snapshot!(batches, @r"
+        ++
+        ++
+        ");
+        Ok(())
+    }
+
+    fn udf() -> Arc<ScalarUDF> {
+        Arc::new(ScalarUDF::new_from_impl(Udf::new()))
+    }
+
+    #[derive(Debug, PartialEq, Eq, Hash)]
+    struct Udf(Signature);
+
+    impl Udf {
+        fn new() -> Self {
+            Self(Signature::any(1, Volatility::Immutable))
+        }
+    }
+
+    impl ScalarUDFImpl for Udf {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+
+        fn name(&self) -> &str {
+            "test_udf"
+        }
+
+        fn signature(&self) -> &Signature {
+            &self.0
+        }
+
+        fn return_type(&self, arg_types: &[DataType]) -> datafusion::common::Result<DataType> {
+            Ok(arg_types[0].clone())
+        }
+
+        fn invoke_with_args(
+            &self,
+            mut args: ScalarFunctionArgs,
+        ) -> datafusion::common::Result<ColumnarValue> {
+            Ok(args.args.remove(0))
+        }
+    }
+}
```
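
The test depends on each worker registering `test_udf` via `build_state`, so that the decoded plan on the remote side can resolve the function by name. As a side note, that resolution step can be exercised locally in a few lines; the helper below is a minimal sketch for illustration (not part of the test suite) and relies only on the `FunctionRegistry` trait that the codec itself queries.

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::ScalarUDF;
use datafusion::prelude::SessionContext;

/// Illustrative check: a UDF registered on a SessionContext is visible through
/// the FunctionRegistry implemented by its SessionState, i.e. the same lookup
/// the distributed codec performs while decoding a plan.
fn resolve_registered_udf(udf: ScalarUDF) -> Result<Arc<ScalarUDF>> {
    let ctx = SessionContext::new();
    let name = udf.name().to_string();
    ctx.register_udf(udf);
    // SessionState implements FunctionRegistry; `udf` resolves by name.
    ctx.state().udf(&name)
}
```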
