21 commits
1d31dfe
Merge branch 'main' into udaf-schema-16997
kosiew Aug 8, 2025
66ee0c5
Merge branch 'main' into udaf-schema-16997
kosiew Aug 12, 2025
15e89aa
feat: Implement SchemaBasedAggregateUdf and enhance accumulator argum…
kosiew Aug 7, 2025
fea8c1c
refactor: Rename test function for schema-based aggregate UDF metadata
kosiew Aug 12, 2025
c386ba0
docs(aggregate): clarify AccumulatorArgs schema handling and usage
kosiew Aug 12, 2025
c8678e5
refactor: Extract AccumulatorArgs construction into a separate method…
kosiew Aug 12, 2025
25df8c6
test: Add unit tests for AggregateUDF implementation and args_schema …
kosiew Aug 12, 2025
2991acc
refactor: Consolidate use statements for improved readability in aggr…
kosiew Aug 12, 2025
664367b
refactor(tests): Simplify argument passing in AggregateExprBuilder tests
kosiew Aug 12, 2025
20d7a93
docs: Mark code examples as ignored in AggregateUDF and AccumulatorAr…
kosiew Aug 12, 2025
99824bd
Merge branch 'main' into udaf-schema-16997
kosiew Aug 19, 2025
8692423
Enhance DummyUdf struct by deriving PartialEq, Eq, and Hash traits
kosiew Aug 19, 2025
f2a2d51
Enhance SchemaBasedAggregateUdf struct by deriving PartialEq, Eq, and…
kosiew Aug 19, 2025
f4484be
Merge branch 'main' into udaf-schema-16997
kosiew Aug 21, 2025
35014e0
Merge branch 'main' into udaf-schema-16997
kosiew Aug 27, 2025
9e166ac
feat: refactor DummyUdf initialization and enhance args_schema handling
kosiew Aug 27, 2025
0cd5d33
feat(aggregate): refactor accumulator argument building for clarity
kosiew Aug 27, 2025
03579a1
refactor(tests): reorganize and enhance DummyUdf implementation and t…
kosiew Aug 27, 2025
7175121
Merge branch 'main' into udaf-schema-16997
kosiew Aug 27, 2025
4dd8bb2
Revert to last good point
kosiew Aug 27, 2025
fee3fe9
Merge branch 'main' into udaf-schema-16997
kosiew Aug 31, 2025
76 changes: 76 additions & 0 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
@@ -1010,6 +1010,51 @@ impl Accumulator for MetadataBasedAccumulator {
}
}

#[derive(Debug, PartialEq, Eq, Hash)]
struct SchemaBasedAggregateUdf {
signature: Signature,
}

impl SchemaBasedAggregateUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}

impl AggregateUDFImpl for SchemaBasedAggregateUdf {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"schema_based_udf"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::UInt64)
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let field = acc_args.schema.field(0).clone();
let double_output = field
.metadata()
.get("modify_values")
.map(|v| v == "double_output")
.unwrap_or(false);

Ok(Box::new(MetadataBasedAccumulator {
double_output,
curr_sum: 0,
}))
}
}

#[tokio::test]
async fn test_metadata_based_aggregate() -> Result<()> {
let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef;
@@ -1166,3 +1211,34 @@ async fn test_metadata_based_aggregate_as_window() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn test_schema_based_aggregate_udf_metadata() -> Result<()> {
use datafusion_expr::{expr::FieldMetadata, lit_with_metadata};
use std::collections::BTreeMap;

let ctx = SessionContext::new();
let udf = AggregateUDF::from(SchemaBasedAggregateUdf::new());

let metadata = FieldMetadata::new(BTreeMap::from([(
"modify_values".to_string(),
"double_output".to_string(),
)]));

let expr = udf
.call(vec![lit_with_metadata(
ScalarValue::UInt64(Some(1)),
Some(metadata),
)])
.alias("res");

let plan = LogicalPlanBuilder::empty(true)
.aggregate(Vec::<Expr>::new(), vec![expr])?
.build()?;

let df = DataFrame::new(ctx.state(), plan);
let batches = df.collect().await?;
let array = batches[0].column(0).as_primitive::<UInt64Type>();
assert_eq!(array.value(0), 2);
Ok(())
}
19 changes: 18 additions & 1 deletion datafusion/expr/src/udaf.rs
@@ -752,7 +752,24 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync {
/// group during query execution.
///
/// acc_args: [`AccumulatorArgs`] contains information about how the
/// aggregate function was called.
/// aggregate function was called. Use `acc_args.exprs` together with
/// `acc_args.schema` to inspect the [`FieldRef`] of each input.
///
/// Example: retrieving metadata and return field for input `i`:
/// ```ignore
/// let metadata = acc_args.schema.field(i).metadata();
/// let field = acc_args.exprs[i].return_field(&acc_args.schema)?;
/// ```
/// Multi-argument functions: `exprs[i]` corresponds to `schema.field(i)`.
/// Mixed inputs (columns and literals): the physical input schema is used
/// when not empty; `acc_args.schema` is only synthesized from literals when
/// the physical schema is empty.
///
/// When an aggregate is invoked with literal values only, `acc_args.schema`
/// is synthesized from those literals so that any field metadata (for
/// example Arrow extension types) is available to the accumulator
/// implementation.
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;

/// Return the fields used to store the intermediate state of this accumulator.
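As an aside (not part of the diff itself), a minimal sketch of how an `AggregateUDFImpl::accumulator` implementation might apply the documented pattern. It assumes a surrounding `impl AggregateUDFImpl for ...` block, and `MyAccumulator` is a hypothetical accumulator type used only for illustration:

```rust
// Sketch only: assumes the surrounding trait impl and a hypothetical MyAccumulator type.
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
    // Input expression i corresponds to acc_args.schema.field(i), including
    // the case where the schema was synthesized from literal-only arguments.
    let field = acc_args.exprs[0].return_field(acc_args.schema)?;
    let double_output = field
        .metadata()
        .get("modify_values")
        .map(|v| v == "double_output")
        .unwrap_or(false);
    Ok(Box::new(MyAccumulator::new(double_output)))
}
```

The same lookup works whether the field metadata came from a physical column or from a literal annotated with `lit_with_metadata`, which is what the integration test above exercises.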
19 changes: 17 additions & 2 deletions datafusion/functions-aggregate-common/src/accumulator.rs
@@ -30,7 +30,14 @@ pub struct AccumulatorArgs<'a> {
/// The return field of the aggregate function.
pub return_field: FieldRef,

/// The schema of the input arguments
/// The schema of the input arguments.
///
/// This schema contains the fields corresponding to the function’s input
/// expressions (`exprs`). When an aggregate is invoked with only literal
/// values, this schema is synthesized from those literals to preserve
/// field-level metadata (such as Arrow extension types). In mixed column
/// and literal inputs, metadata from the physical schema takes precedence;
/// synthesized metadata is only used when the physical schema is empty.
pub schema: &'a Schema,

/// Whether to ignore nulls.
@@ -65,7 +72,15 @@ pub struct AccumulatorArgs<'a> {
/// ```
pub is_distinct: bool,

/// The physical expression of arguments the aggregate function takes.
/// The physical expressions for the aggregate function's arguments.
/// Use these expressions together with [`Self::schema`] to obtain the
/// [`FieldRef`] of each input via `expr.return_field(schema)`.
///
/// Example:
/// ```ignore
/// let input_field = exprs[i].return_field(&schema)?;
/// ```
/// Note: physical schema metadata takes precedence in mixed inputs.
pub exprs: &'a [Arc<dyn PhysicalExpr>],
}

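For a concrete picture of the synthesized case (illustration only, not part of the diff), this is roughly the schema an accumulator would observe for a single literal argument carrying metadata; the field name `lit` matches what the unit test further below asserts:

```rust
use std::collections::HashMap;

use arrow::datatypes::{DataType, Field, Schema};

fn main() {
    // A schema synthesized from one literal argument, preserving its metadata.
    let lit_field = Field::new("lit", DataType::UInt64, true).with_metadata(HashMap::from([(
        "modify_values".to_string(),
        "double_output".to_string(),
    )]));
    let synthesized = Schema::new(vec![lit_field]);

    // The accumulator can read the literal's metadata from the schema field.
    assert_eq!(
        synthesized
            .field(0)
            .metadata()
            .get("modify_values")
            .map(String::as_str),
        Some("double_output"),
    );
}
```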
198 changes: 145 additions & 53 deletions datafusion/physical-expr/src/aggregate.rs
@@ -24,6 +24,92 @@ pub(crate) mod groups_accumulator {
accumulate::NullState, GroupsAccumulatorAdapter,
};
}

#[cfg(test)]
mod tests {
use super::*;
use crate::expressions::Literal;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::ScalarValue;
use datafusion_expr::{AggregateUDF, AggregateUDFImpl, Signature, Volatility};
use std::{any::Any, sync::Arc};
#[derive(Debug, PartialEq, Eq, Hash)]
struct DummyUdf {
signature: Signature,
}

impl DummyUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}

impl AggregateUDFImpl for DummyUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"dummy"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
Ok(DataType::UInt64)
}
fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
unimplemented!()
}
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
unimplemented!()
}
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}
fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
unimplemented!()
}
}

#[test]
fn test_args_schema_and_groups_path() {
// literal-only: empty physical schema synthesizes schema from literal expr
let udf = Arc::new(AggregateUDF::from(DummyUdf::new()));
let lit_expr =
Arc::new(Literal::new(ScalarValue::UInt32(Some(1)))) as Arc<dyn PhysicalExpr>;
let agg =
AggregateExprBuilder::new(Arc::clone(&udf), vec![Arc::clone(&lit_expr)])
.alias("x")
.schema(Arc::new(Schema::empty()))
.build()
.unwrap();
match agg.args_schema() {
Cow::Owned(s) => assert_eq!(s.field(0).name(), "lit"),
_ => panic!("expected owned schema"),
}
assert!(agg.groups_accumulator_supported());

// non-empty physical schema should be borrowed
let f = Field::new("b", DataType::Int32, false);
let phys_schema = Schema::new(vec![f.clone()]);
let col_expr = Arc::new(Column::new("b", 0)) as Arc<dyn PhysicalExpr>;
let agg2 = AggregateExprBuilder::new(udf, vec![col_expr])
.alias("x")
.schema(Arc::new(phys_schema))
.build()
.unwrap();
match agg2.args_schema() {
Cow::Borrowed(s) => assert_eq!(s.field(0).name(), "b"),
_ => panic!("expected borrowed schema"),
}
assert!(agg2.groups_accumulator_supported());
}
}
pub(crate) mod stats {
pub use datafusion_functions_aggregate_common::stats::StatsType;
}
@@ -33,25 +119,25 @@ pub mod utils {
DecimalAverager, Hashable,
};
}

use std::fmt::Debug;
use std::sync::Arc;

use crate::expressions::Column;

use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef};
use arrow::{
compute::SortOptions,
datatypes::{DataType, FieldRef, Schema, SchemaRef},
};
use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
use datafusion_expr::{AggregateUDF, ReversedUDAF, SetMonotonicity};
use datafusion_expr_common::accumulator::Accumulator;
use datafusion_expr_common::groups_accumulator::GroupsAccumulator;
use datafusion_expr_common::type_coercion::aggregates::check_arg_count;
use datafusion_functions_aggregate_common::accumulator::{
AccumulatorArgs, StateFieldsArgs,
use datafusion_expr_common::{
accumulator::Accumulator, groups_accumulator::GroupsAccumulator,
type_coercion::aggregates::check_arg_count,
};
use datafusion_functions_aggregate_common::{
accumulator::{AccumulatorArgs, StateFieldsArgs},
order::AggregateOrderSensitivity,
};
use datafusion_physical_expr_common::{
physical_expr::PhysicalExpr, sort_expr::PhysicalSortExpr,
};
use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
use std::{borrow::Cow, fmt::Debug, sync::Arc};

/// Builder for physical [`AggregateFunctionExpr`]
///
@@ -376,21 +462,52 @@ impl AggregateFunctionExpr {
.into()
}

/// the accumulator used to accumulate values from the expressions.
/// the accumulator expects the same number of arguments as `expressions` and must
/// return states with the same description as `state_fields`
pub fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let acc_args = AccumulatorArgs {
/// Returns a schema containing the fields corresponding to this
/// aggregate's input expressions in the same order as `input_fields`/`exprs`.
///
/// If the physical input schema is empty (literal-only inputs),
/// synthesizes a new schema from the literal expressions to preserve
/// field-level metadata (such as Arrow extension types).
/// Field order is guaranteed to match the order of input expressions.
/// In mixed column and literal inputs, existing physical schema fields
/// win; synthesized metadata is only applied when the physical schema
/// has no fields.
///
/// Uses [`std::borrow::Cow`] to avoid allocation when the existing
/// schema is non-empty. As a micro-optimization, callers may cache the
/// owned schema if it is needed multiple times per instance.
fn args_schema(&self) -> Cow<'_, Schema> {
if self.schema.fields().is_empty() {
Cow::Owned(Schema::new(
self.input_fields
.iter()
.map(|f| f.as_ref().clone())
.collect::<Vec<_>>(),
))
} else {
Cow::Borrowed(&self.schema)
}
}
/// Construct AccumulatorArgs for this aggregate using a given schema slice.
fn make_acc_args<'a>(&'a self, schema: &'a Schema) -> AccumulatorArgs<'a> {
AccumulatorArgs {
return_field: Arc::clone(&self.return_field),
schema: &self.schema,
schema,
ignore_nulls: self.ignore_nulls,
order_bys: self.order_bys.as_ref(),
is_distinct: self.is_distinct,
name: &self.name,
is_reversed: self.is_reversed,
exprs: &self.args,
};
}
}

/// the accumulator used to accumulate values from the expressions.
/// the accumulator expects the same number of arguments as `expressions` and must
/// return states with the same description as `state_fields`
pub fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let schema = self.args_schema();
let acc_args = self.make_acc_args(schema.as_ref());
self.fun.accumulator(acc_args)
}

@@ -464,17 +581,8 @@ impl AggregateFunctionExpr {

/// Creates accumulator implementation that supports retract
pub fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let args = AccumulatorArgs {
return_field: Arc::clone(&self.return_field),
schema: &self.schema,
ignore_nulls: self.ignore_nulls,
order_bys: self.order_bys.as_ref(),
is_distinct: self.is_distinct,
name: &self.name,
is_reversed: self.is_reversed,
exprs: &self.args,
};

let schema = self.args_schema();
let args = self.make_acc_args(schema.as_ref());
let accumulator = self.fun.create_sliding_accumulator(args)?;

// Accumulators that have window frame startings different
@@ -533,16 +641,8 @@ impl AggregateFunctionExpr {
/// [`GroupsAccumulator`] implementation. If this returns true,
/// [`Self::create_groups_accumulator`] will be called.
pub fn groups_accumulator_supported(&self) -> bool {
let args = AccumulatorArgs {
return_field: Arc::clone(&self.return_field),
schema: &self.schema,
ignore_nulls: self.ignore_nulls,
order_bys: self.order_bys.as_ref(),
is_distinct: self.is_distinct,
name: &self.name,
is_reversed: self.is_reversed,
exprs: &self.args,
};
let schema = self.args_schema();
let args = self.make_acc_args(schema.as_ref());
self.fun.groups_accumulator_supported(args)
}

@@ -552,16 +652,8 @@ impl AggregateFunctionExpr {
/// For maximum performance, a [`GroupsAccumulator`] should be
/// implemented in addition to [`Accumulator`].
pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
let args = AccumulatorArgs {
return_field: Arc::clone(&self.return_field),
schema: &self.schema,
ignore_nulls: self.ignore_nulls,
order_bys: self.order_bys.as_ref(),
is_distinct: self.is_distinct,
name: &self.name,
is_reversed: self.is_reversed,
exprs: &self.args,
};
let schema = self.args_schema();
let args = self.make_acc_args(schema.as_ref());
self.fun.create_groups_accumulator(args)
}
