diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index a5b073b147ea..c339716248cc 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -1010,6 +1010,51 @@ impl Accumulator for MetadataBasedAccumulator { } } +#[derive(Debug, PartialEq, Eq, Hash)] +struct SchemaBasedAggregateUdf { + signature: Signature, +} + +impl SchemaBasedAggregateUdf { + fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for SchemaBasedAggregateUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "schema_based_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::UInt64) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let field = acc_args.schema.field(0).clone(); + let double_output = field + .metadata() + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false); + + Ok(Box::new(MetadataBasedAccumulator { + double_output, + curr_sum: 0, + })) + } +} + #[tokio::test] async fn test_metadata_based_aggregate() -> Result<()> { let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; @@ -1166,3 +1211,34 @@ async fn test_metadata_based_aggregate_as_window() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_schema_based_aggregate_udf_metadata() -> Result<()> { + use datafusion_expr::{expr::FieldMetadata, lit_with_metadata}; + use std::collections::BTreeMap; + + let ctx = SessionContext::new(); + let udf = AggregateUDF::from(SchemaBasedAggregateUdf::new()); + + let metadata = FieldMetadata::new(BTreeMap::from([( + "modify_values".to_string(), + "double_output".to_string(), + )])); + + let expr = udf + .call(vec![lit_with_metadata( + ScalarValue::UInt64(Some(1)), + Some(metadata), + )]) + .alias("res"); + + let plan = LogicalPlanBuilder::empty(true) + .aggregate(Vec::::new(), vec![expr])? + .build()?; + + let df = DataFrame::new(ctx.state(), plan); + let batches = df.collect().await?; + let array = batches[0].column(0).as_primitive::(); + assert_eq!(array.value(0), 2); + Ok(()) +} diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index e1d36da9c4db..189e82859adf 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -752,7 +752,24 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync { /// group during query execution. /// /// acc_args: [`AccumulatorArgs`] contains information about how the - /// aggregate function was called. + /// aggregate function was called. Use `acc_args.exprs` together with + /// `acc_args.schema` to inspect the [`FieldRef`] of each input. + /// + /// Example: retrieving metadata and return field for input `i`: + /// ```ignore + /// let metadata = acc_args.schema.field(i).metadata(); + /// let field = acc_args.exprs[i].return_field(&acc_args.schema)?; + /// ``` + /// Multi-argument functions: `exprs[i]` corresponds to `schema.field(i)`. + /// Mixed inputs (columns and literals): the physical input schema is used + /// when not empty; `acc_args.schema` is only synthesized from literals when + /// the physical schema is empty. + /// + /// When an + /// aggregate is invoked with literal values only, `acc_args.schema` is + /// synthesized from those literals so that any field metadata (for + /// example Arrow extension types) is available to the accumulator + /// implementation. fn accumulator(&self, acc_args: AccumulatorArgs) -> Result>; /// Return the fields used to store the intermediate state of this accumulator. diff --git a/datafusion/functions-aggregate-common/src/accumulator.rs b/datafusion/functions-aggregate-common/src/accumulator.rs index e0f7af1fb38e..c11b705037ec 100644 --- a/datafusion/functions-aggregate-common/src/accumulator.rs +++ b/datafusion/functions-aggregate-common/src/accumulator.rs @@ -30,7 +30,14 @@ pub struct AccumulatorArgs<'a> { /// The return field of the aggregate function. pub return_field: FieldRef, - /// The schema of the input arguments + /// The schema of the input arguments. + /// + /// This schema contains the fields corresponding to the function’s input + /// expressions (`exprs`). When an aggregate is invoked with only literal + /// values, this schema is synthesized from those literals to preserve + /// field-level metadata (such as Arrow extension types). In mixed column + /// and literal inputs, metadata from the physical schema takes precedence; + /// synthesized metadata is only used when the physical schema is empty. pub schema: &'a Schema, /// Whether to ignore nulls. @@ -65,7 +72,15 @@ pub struct AccumulatorArgs<'a> { /// ``` pub is_distinct: bool, - /// The physical expression of arguments the aggregate function takes. + /// The physical expressions for the aggregate function's arguments. + /// Use these expressions together with [`Self::schema`] to obtain the + /// [`FieldRef`] of each input via `expr.return_field(schema)`. + /// + /// Example: + /// ```ignore + /// let input_field = exprs[i].return_field(&schema)?; + /// ``` + /// Note: physical schema metadata takes precedence in mixed inputs. pub exprs: &'a [Arc], } diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index 19d2ecc924dd..cef7ed0cb19a 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -24,6 +24,92 @@ pub(crate) mod groups_accumulator { accumulate::NullState, GroupsAccumulatorAdapter, }; } + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::Literal; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::ScalarValue; + use datafusion_expr::{AggregateUDF, AggregateUDFImpl, Signature, Volatility}; + use std::{any::Any, sync::Arc}; + #[derive(Debug, PartialEq, Eq, Hash)] + struct DummyUdf { + signature: Signature, + } + + impl DummyUdf { + fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + } + } + } + + impl AggregateUDFImpl for DummyUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "dummy" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::UInt64) + } + fn accumulator(&self, _args: AccumulatorArgs) -> Result> { + unimplemented!() + } + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + unimplemented!() + } + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + unimplemented!() + } + } + + #[test] + fn test_args_schema_and_groups_path() { + // literal-only: empty physical schema synthesizes schema from literal expr + let udf = Arc::new(AggregateUDF::from(DummyUdf::new())); + let lit_expr = + Arc::new(Literal::new(ScalarValue::UInt32(Some(1)))) as Arc; + let agg = + AggregateExprBuilder::new(Arc::clone(&udf), vec![Arc::clone(&lit_expr)]) + .alias("x") + .schema(Arc::new(Schema::empty())) + .build() + .unwrap(); + match agg.args_schema() { + Cow::Owned(s) => assert_eq!(s.field(0).name(), "lit"), + _ => panic!("expected owned schema"), + } + assert!(agg.groups_accumulator_supported()); + + // non-empty physical schema should be borrowed + let f = Field::new("b", DataType::Int32, false); + let phys_schema = Schema::new(vec![f.clone()]); + let col_expr = Arc::new(Column::new("b", 0)) as Arc; + let agg2 = AggregateExprBuilder::new(udf, vec![col_expr]) + .alias("x") + .schema(Arc::new(phys_schema)) + .build() + .unwrap(); + match agg2.args_schema() { + Cow::Borrowed(s) => assert_eq!(s.field(0).name(), "b"), + _ => panic!("expected borrowed schema"), + } + assert!(agg2.groups_accumulator_supported()); + } +} pub(crate) mod stats { pub use datafusion_functions_aggregate_common::stats::StatsType; } @@ -33,25 +119,25 @@ pub mod utils { DecimalAverager, Hashable, }; } - -use std::fmt::Debug; -use std::sync::Arc; - use crate::expressions::Column; - -use arrow::compute::SortOptions; -use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef}; +use arrow::{ + compute::SortOptions, + datatypes::{DataType, FieldRef, Schema, SchemaRef}, +}; use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr::{AggregateUDF, ReversedUDAF, SetMonotonicity}; -use datafusion_expr_common::accumulator::Accumulator; -use datafusion_expr_common::groups_accumulator::GroupsAccumulator; -use datafusion_expr_common::type_coercion::aggregates::check_arg_count; -use datafusion_functions_aggregate_common::accumulator::{ - AccumulatorArgs, StateFieldsArgs, +use datafusion_expr_common::{ + accumulator::Accumulator, groups_accumulator::GroupsAccumulator, + type_coercion::aggregates::check_arg_count, +}; +use datafusion_functions_aggregate_common::{ + accumulator::{AccumulatorArgs, StateFieldsArgs}, + order::AggregateOrderSensitivity, +}; +use datafusion_physical_expr_common::{ + physical_expr::PhysicalExpr, sort_expr::PhysicalSortExpr, }; -use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use std::{borrow::Cow, fmt::Debug, sync::Arc}; /// Builder for physical [`AggregateFunctionExpr`] /// @@ -376,21 +462,52 @@ impl AggregateFunctionExpr { .into() } - /// the accumulator used to accumulate values from the expressions. - /// the accumulator expects the same number of arguments as `expressions` and must - /// return states with the same description as `state_fields` - pub fn create_accumulator(&self) -> Result> { - let acc_args = AccumulatorArgs { + /// Returns a schema containing the fields corresponding to this + /// aggregate's input expressions in the same order as `input_fields`/`exprs`. + /// + /// If the physical input schema is empty (literal-only inputs), + /// synthesizes a new schema from the literal expressions to preserve + /// field-level metadata (such as Arrow extension types). + /// Field order is guaranteed to match the order of input expressions. + /// In mixed column and literal inputs, existing physical schema fields + /// win; synthesized metadata is only applied when the physical schema + /// has no fields. + /// + /// Uses [`std::borrow::Cow`] to avoid allocation when the existing + /// schema is non-empty. For micro-optimizations, implementers may + /// cache the owned schema if multiple calls are made per instance. + fn args_schema(&self) -> Cow<'_, Schema> { + if self.schema.fields().is_empty() { + Cow::Owned(Schema::new( + self.input_fields + .iter() + .map(|f| f.as_ref().clone()) + .collect::>(), + )) + } else { + Cow::Borrowed(&self.schema) + } + } + /// Construct AccumulatorArgs for this aggregate using a given schema slice. + fn make_acc_args<'a>(&'a self, schema: &'a Schema) -> AccumulatorArgs<'a> { + AccumulatorArgs { return_field: Arc::clone(&self.return_field), - schema: &self.schema, + schema, ignore_nulls: self.ignore_nulls, order_bys: self.order_bys.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, exprs: &self.args, - }; + } + } + /// the accumulator used to accumulate values from the expressions. + /// the accumulator expects the same number of arguments as `expressions` and must + /// return states with the same description as `state_fields` + pub fn create_accumulator(&self) -> Result> { + let schema = self.args_schema(); + let acc_args = self.make_acc_args(schema.as_ref()); self.fun.accumulator(acc_args) } @@ -464,17 +581,8 @@ impl AggregateFunctionExpr { /// Creates accumulator implementation that supports retract pub fn create_sliding_accumulator(&self) -> Result> { - let args = AccumulatorArgs { - return_field: Arc::clone(&self.return_field), - schema: &self.schema, - ignore_nulls: self.ignore_nulls, - order_bys: self.order_bys.as_ref(), - is_distinct: self.is_distinct, - name: &self.name, - is_reversed: self.is_reversed, - exprs: &self.args, - }; - + let schema = self.args_schema(); + let args = self.make_acc_args(schema.as_ref()); let accumulator = self.fun.create_sliding_accumulator(args)?; // Accumulators that have window frame startings different @@ -533,16 +641,8 @@ impl AggregateFunctionExpr { /// [`GroupsAccumulator`] implementation. If this returns true, /// `[Self::create_groups_accumulator`] will be called. pub fn groups_accumulator_supported(&self) -> bool { - let args = AccumulatorArgs { - return_field: Arc::clone(&self.return_field), - schema: &self.schema, - ignore_nulls: self.ignore_nulls, - order_bys: self.order_bys.as_ref(), - is_distinct: self.is_distinct, - name: &self.name, - is_reversed: self.is_reversed, - exprs: &self.args, - }; + let schema = self.args_schema(); + let args = self.make_acc_args(schema.as_ref()); self.fun.groups_accumulator_supported(args) } @@ -552,16 +652,8 @@ impl AggregateFunctionExpr { /// For maximum performance, a [`GroupsAccumulator`] should be /// implemented in addition to [`Accumulator`]. pub fn create_groups_accumulator(&self) -> Result> { - let args = AccumulatorArgs { - return_field: Arc::clone(&self.return_field), - schema: &self.schema, - ignore_nulls: self.ignore_nulls, - order_bys: self.order_bys.as_ref(), - is_distinct: self.is_distinct, - name: &self.name, - is_reversed: self.is_reversed, - exprs: &self.args, - }; + let schema = self.args_schema(); + let args = self.make_acc_args(schema.as_ref()); self.fun.create_groups_accumulator(args) }