From c5d6114c8c61e7bf7b9a2c1d25d1a1bd8b93b010 Mon Sep 17 00:00:00 2001
From: Gabriel Musat Mestre
Date: Wed, 22 Oct 2025 10:46:54 +0200
Subject: [PATCH] Fix loss of UDFs when decoding in distributed_codec

DistributedCodec::try_decode ignored the FunctionRegistry it was given
and built a plain SessionContext, so plans referencing user-registered
scalar UDFs could not be decoded on worker nodes. Build the decoding
context from a SessionState populated with the registry's scalar UDFs
instead, and add an integration test that hash-partitions on a custom
UDF to exercise the round-trip.
---
 src/protobuf/distributed_codec.rs |  15 +++-
 tests/udfs.rs                     | 131 ++++++++++++++++++++++++++++++
 2 files changed, 143 insertions(+), 3 deletions(-)
 create mode 100644 tests/udfs.rs

diff --git a/src/protobuf/distributed_codec.rs b/src/protobuf/distributed_codec.rs
index ec51d8e..8b2f9d1 100644
--- a/src/protobuf/distributed_codec.rs
+++ b/src/protobuf/distributed_codec.rs
@@ -8,7 +8,7 @@ use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::common::internal_datafusion_err;
 use datafusion::error::DataFusionError;
-use datafusion::execution::FunctionRegistry;
+use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
 use datafusion::physical_expr::EquivalenceProperties;
 use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
 use datafusion::physical_plan::{ExecutionPlan, Partitioning, PlanProperties};
@@ -40,7 +40,7 @@ impl PhysicalExtensionCodec for DistributedCodec {
         &self,
         buf: &[u8],
         inputs: &[Arc<dyn ExecutionPlan>],
-        _registry: &dyn FunctionRegistry,
+        registry: &dyn FunctionRegistry,
     ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
         let DistributedExecProto {
             node: Some(distributed_exec_node),
@@ -54,7 +54,16 @@ impl PhysicalExtensionCodec for DistributedCodec {
         // TODO: The PhysicalExtensionCodec trait doesn't provide access to session state,
         // so we create a new SessionContext which loses any custom UDFs, UDAFs, and other
         // user configurations. This is a limitation of the current trait design.
-        let ctx = SessionContext::new();
+        let state = SessionStateBuilder::new()
+            .with_scalar_functions(
+                registry
+                    .udfs()
+                    .iter()
+                    .map(|f| registry.udf(f))
+                    .collect::<Result<Vec<_>, _>>()?,
+            )
+            .build();
+        let ctx = SessionContext::from(state);
 
         fn parse_stage_proto(
             proto: Option<StageProto>,
diff --git a/tests/udfs.rs b/tests/udfs.rs
new file mode 100644
index 0000000..1e67d05
--- /dev/null
+++ b/tests/udfs.rs
@@ -0,0 +1,131 @@
+#[cfg(all(feature = "integration", test))]
+mod tests {
+    use arrow::datatypes::{Field, Schema};
+    use arrow::util::pretty::pretty_format_batches;
+    use datafusion::arrow::datatypes::DataType;
+    use datafusion::error::DataFusionError;
+    use datafusion::execution::{SessionState, SessionStateBuilder};
+    use datafusion::logical_expr::{
+        ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
+    };
+    use datafusion::physical_expr::expressions::lit;
+    use datafusion::physical_expr::{Partitioning, ScalarFunctionExpr};
+    use datafusion::physical_optimizer::PhysicalOptimizerRule;
+    use datafusion::physical_plan::empty::EmptyExec;
+    use datafusion::physical_plan::repartition::RepartitionExec;
+    use datafusion::physical_plan::{ExecutionPlan, execute_stream};
+    use datafusion_distributed::test_utils::localhost::start_localhost_context;
+    use datafusion_distributed::{
+        DistributedPhysicalOptimizerRule, DistributedSessionBuilderContext, assert_snapshot,
+        display_plan_ascii,
+    };
+    use futures::TryStreamExt;
+    use std::any::Any;
+    use std::error::Error;
+    use std::sync::Arc;
+
+    #[tokio::test]
+    async fn test_udf_in_partitioning_field() -> Result<(), Box<dyn Error>> {
+        async fn build_state(
+            ctx: DistributedSessionBuilderContext,
+        ) -> Result<SessionState, DataFusionError> {
+            Ok(SessionStateBuilder::new()
+                .with_runtime_env(ctx.runtime_env)
+                .with_default_features()
+                .with_scalar_functions(vec![udf()])
+                .build())
+        }
+
+        let (ctx, _guard) = start_localhost_context(3, build_state).await;
+
+        let wrap = |input: Arc<dyn ExecutionPlan>| -> Arc<dyn ExecutionPlan> {
+            Arc::new(
+                RepartitionExec::try_new(
+                    input,
+                    Partitioning::Hash(
+                        vec![Arc::new(ScalarFunctionExpr::new(
+                            "test_udf",
+                            udf(),
+                            vec![lit(1)],
+                            Arc::new(Field::new("return", DataType::Int32, false)),
+                            Default::default(),
+                        ))],
+                        1,
+                    ),
+                )
+                .unwrap(),
+            )
+        };
+
+        let node = wrap(wrap(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))));
+
+        let physical_distributed = DistributedPhysicalOptimizerRule::default()
+            .with_network_shuffle_tasks(2)
+            .with_network_coalesce_tasks(2)
+            .optimize(node, &Default::default())?;
+
+        let physical_distributed_str = display_plan_ascii(physical_distributed.as_ref(), false);
+
+        assert_snapshot!(physical_distributed_str,
+            @r"
+        ┌───── DistributedExec ── Tasks: t0:[p0] 
+        │ [Stage 2] => NetworkShuffleExec: output_partitions=1, input_tasks=2
+        └──────────────────────────────────────────────────
+          ┌───── Stage 1 ── Tasks: t0:[p0..p1] t1:[p0..p1] 
+          │ RepartitionExec: partitioning=Hash([test_udf(1)], 2), input_partitions=1
+          │   EmptyExec
+          └──────────────────────────────────────────────────
+        ",
+        );
+
+        let batches = pretty_format_batches(
+            &execute_stream(physical_distributed, ctx.task_ctx())?
+                .try_collect::<Vec<_>>()
+                .await?,
+        )?;
+
+        assert_snapshot!(batches, @r"
+        ++
+        ++
+        ");
+        Ok(())
+    }
+
+    fn udf() -> Arc<ScalarUDF> {
+        Arc::new(ScalarUDF::new_from_impl(Udf::new()))
+    }
+
+    #[derive(Debug, PartialEq, Eq, Hash)]
+    struct Udf(Signature);
+
+    impl Udf {
+        fn new() -> Self {
+            Self(Signature::any(1, Volatility::Immutable))
+        }
+    }
+
+    impl ScalarUDFImpl for Udf {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+
+        fn name(&self) -> &str {
+            "test_udf"
+        }
+
+        fn signature(&self) -> &Signature {
+            &self.0
+        }
+
+        fn return_type(&self, arg_types: &[DataType]) -> datafusion::common::Result<DataType> {
+            Ok(arg_types[0].clone())
+        }
+
+        fn invoke_with_args(
+            &self,
+            mut args: ScalarFunctionArgs,
+        ) -> datafusion::common::Result<ColumnarValue> {
+            Ok(args.args.remove(0))
+        }
+    }
+}
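
Note: the decode fix above boils down to rebuilding a session from whatever
the supplied FunctionRegistry exposes. Below is a minimal standalone sketch
of that same round-trip, assuming only the public DataFusion APIs already
used in the diff (FunctionRegistry, SessionStateBuilder, SessionContext::from);
the helper name `context_with_udfs` is invented here for illustration and is
not part of the patch.

    use std::sync::Arc;
    use datafusion::common::Result;
    use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
    use datafusion::logical_expr::ScalarUDF;
    use datafusion::prelude::SessionContext;

    /// Rebuild a SessionContext carrying every scalar UDF visible in
    /// `registry`, mirroring what DistributedCodec::try_decode now does.
    fn context_with_udfs(registry: &dyn FunctionRegistry) -> Result<SessionContext> {
        // `udfs()` yields the registered function names; `udf(name)` resolves
        // each name back to its Arc<ScalarUDF>. Collecting into Result<_>
        // propagates the first lookup failure, as in the patch.
        let udfs: Vec<Arc<ScalarUDF>> = registry
            .udfs()
            .iter()
            .map(|name| registry.udf(name))
            .collect::<Result<_>>()?;
        let state = SessionStateBuilder::new()
            .with_scalar_functions(udfs)
            .build();
        Ok(SessionContext::from(state))
    }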
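The same pattern should extend to the other function types the TODO mentions:
registry.udafs()/registry.udaf(name) paired with
SessionStateBuilder::with_aggregate_functions, and
registry.udwfs()/registry.udwf(name) paired with with_window_functions. This
patch only forwards scalar UDFs, so aggregate and window UDFs registered by
the caller are still dropped during decode.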