Skip to content

Commit fb6c9a8

Browse files
committed
feat: Implement SchemaBasedAggregateUdf and enhance accumulator argument handling
1 parent 1d31dfe commit fb6c9a8

File tree

4 files changed

+122
-7
lines changed

4 files changed

+122
-7
lines changed

datafusion/core/tests/user_defined/user_defined_aggregates.rs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,51 @@ impl Accumulator for MetadataBasedAccumulator {
10141014
}
10151015
}
10161016

1017+
#[derive(Debug)]
1018+
struct SchemaBasedAggregateUdf {
1019+
signature: Signature,
1020+
}
1021+
1022+
impl SchemaBasedAggregateUdf {
1023+
fn new() -> Self {
1024+
Self {
1025+
signature: Signature::any(1, Volatility::Immutable),
1026+
}
1027+
}
1028+
}
1029+
1030+
impl AggregateUDFImpl for SchemaBasedAggregateUdf {
1031+
fn as_any(&self) -> &dyn Any {
1032+
self
1033+
}
1034+
1035+
fn name(&self) -> &str {
1036+
"schema_based_udf"
1037+
}
1038+
1039+
fn signature(&self) -> &Signature {
1040+
&self.signature
1041+
}
1042+
1043+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1044+
Ok(DataType::UInt64)
1045+
}
1046+
1047+
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1048+
let field = acc_args.schema.field(0).clone();
1049+
let double_output = field
1050+
.metadata()
1051+
.get("modify_values")
1052+
.map(|v| v == "double_output")
1053+
.unwrap_or(false);
1054+
1055+
Ok(Box::new(MetadataBasedAccumulator {
1056+
double_output,
1057+
curr_sum: 0,
1058+
}))
1059+
}
1060+
}
1061+
10171062
#[tokio::test]
10181063
async fn test_metadata_based_aggregate() -> Result<()> {
10191064
let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef;
@@ -1170,3 +1215,34 @@ async fn test_metadata_based_aggregate_as_window() -> Result<()> {
11701215

11711216
Ok(())
11721217
}
1218+
1219+
#[tokio::test]
1220+
async fn test_scalar_udaf_schema_metadata() -> Result<()> {
1221+
use datafusion_expr::{expr::FieldMetadata, lit_with_metadata};
1222+
use std::collections::BTreeMap;
1223+
1224+
let ctx = SessionContext::new();
1225+
let udf = AggregateUDF::from(SchemaBasedAggregateUdf::new());
1226+
1227+
let metadata = FieldMetadata::new(BTreeMap::from([(
1228+
"modify_values".to_string(),
1229+
"double_output".to_string(),
1230+
)]));
1231+
1232+
let expr = udf
1233+
.call(vec![lit_with_metadata(
1234+
ScalarValue::UInt64(Some(1)),
1235+
Some(metadata),
1236+
)])
1237+
.alias("res");
1238+
1239+
let plan = LogicalPlanBuilder::empty(true)
1240+
.aggregate(Vec::<Expr>::new(), vec![expr])?
1241+
.build()?;
1242+
1243+
let df = DataFrame::new(ctx.state(), plan);
1244+
let batches = df.collect().await?;
1245+
let array = batches[0].column(0).as_primitive::<UInt64Type>();
1246+
assert_eq!(array.value(0), 2);
1247+
Ok(())
1248+
}

datafusion/expr/src/udaf.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,12 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
754754
/// group during query execution.
755755
///
756756
/// acc_args: [`AccumulatorArgs`] contains information about how the
757-
/// aggregate function was called.
757+
/// aggregate function was called. Use `acc_args.exprs` together with
758+
/// `acc_args.schema` to inspect the [`FieldRef`] of each input. When an
759+
/// aggregate is invoked with literal values only, `acc_args.schema` is
760+
/// synthesized from those literals so that any field metadata (for
761+
/// example Arrow extension types) is available to the accumulator
762+
/// implementation.
758763
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;
759764

760765
/// Return the fields used to store the intermediate state of this accumulator.

datafusion/functions-aggregate-common/src/accumulator.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,13 @@ pub struct AccumulatorArgs<'a> {
3030
/// The return field of the aggregate function.
3131
pub return_field: FieldRef,
3232

33-
/// The schema of the input arguments
33+
/// The schema of the input arguments.
34+
///
35+
/// This schema contains the fields corresponding to [`Self::exprs`]. When
36+
/// an aggregate is invoked with only literal values, this schema is
37+
/// synthesized from those literals so that field-level metadata (such as
38+
/// Arrow extension types) remains available to accumulator
39+
/// implementations.
3440
pub schema: &'a Schema,
3541

3642
/// Whether to ignore nulls.
@@ -65,7 +71,9 @@ pub struct AccumulatorArgs<'a> {
6571
/// ```
6672
pub is_distinct: bool,
6773

68-
/// The physical expression of arguments the aggregate function takes.
74+
/// The physical expressions for the aggregate function's arguments.
75+
/// Use these expressions together with [`Self::schema`] to obtain the
76+
/// [`FieldRef`] of each input via `expr.return_field(schema)`.
6977
pub exprs: &'a [Arc<dyn PhysicalExpr>],
7078
}
7179

datafusion/physical-expr/src/aggregate.rs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ pub mod utils {
3434
};
3535
}
3636

37+
use std::borrow::Cow;
3738
use std::fmt::Debug;
3839
use std::sync::Arc;
3940

@@ -376,13 +377,35 @@ impl AggregateFunctionExpr {
376377
.into()
377378
}
378379

380+
/// Returns a schema containing the fields corresponding to this
381+
/// aggregate's input expressions.
382+
///
383+
/// When an aggregate is invoked with only literal values, the
384+
/// physical input schema is empty. In this case a schema is
385+
/// synthesized from the literal expressions so that field metadata
386+
/// (such as Arrow extension types) remains available to the
387+
/// [`Accumulator`].
388+
fn args_schema(&self) -> Cow<'_, Schema> {
389+
if self.schema.fields().is_empty() {
390+
Cow::Owned(Schema::new(
391+
self.input_fields
392+
.iter()
393+
.map(|f| f.as_ref().clone())
394+
.collect::<Vec<_>>(),
395+
))
396+
} else {
397+
Cow::Borrowed(&self.schema)
398+
}
399+
}
400+
379401
/// the accumulator used to accumulate values from the expressions.
380402
/// the accumulator expects the same number of arguments as `expressions` and must
381403
/// return states with the same description as `state_fields`
382404
pub fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
405+
let schema = self.args_schema();
383406
let acc_args = AccumulatorArgs {
384407
return_field: Arc::clone(&self.return_field),
385-
schema: &self.schema,
408+
schema: schema.as_ref(),
386409
ignore_nulls: self.ignore_nulls,
387410
order_bys: self.order_bys.as_ref(),
388411
is_distinct: self.is_distinct,
@@ -464,9 +487,10 @@ impl AggregateFunctionExpr {
464487

465488
/// Creates accumulator implementation that supports retract
466489
pub fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
490+
let schema = self.args_schema();
467491
let args = AccumulatorArgs {
468492
return_field: Arc::clone(&self.return_field),
469-
schema: &self.schema,
493+
schema: schema.as_ref(),
470494
ignore_nulls: self.ignore_nulls,
471495
order_bys: self.order_bys.as_ref(),
472496
is_distinct: self.is_distinct,
@@ -533,9 +557,10 @@ impl AggregateFunctionExpr {
533557
/// [`GroupsAccumulator`] implementation. If this returns true,
534558
/// `[Self::create_groups_accumulator`] will be called.
535559
pub fn groups_accumulator_supported(&self) -> bool {
560+
let schema = self.args_schema();
536561
let args = AccumulatorArgs {
537562
return_field: Arc::clone(&self.return_field),
538-
schema: &self.schema,
563+
schema: schema.as_ref(),
539564
ignore_nulls: self.ignore_nulls,
540565
order_bys: self.order_bys.as_ref(),
541566
is_distinct: self.is_distinct,
@@ -552,9 +577,10 @@ impl AggregateFunctionExpr {
552577
/// For maximum performance, a [`GroupsAccumulator`] should be
553578
/// implemented in addition to [`Accumulator`].
554579
pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
580+
let schema = self.args_schema();
555581
let args = AccumulatorArgs {
556582
return_field: Arc::clone(&self.return_field),
557-
schema: &self.schema,
583+
schema: schema.as_ref(),
558584
ignore_nulls: self.ignore_nulls,
559585
order_bys: self.order_bys.as_ref(),
560586
is_distinct: self.is_distinct,

0 commit comments

Comments
 (0)