Skip to content

Commit 25df8c6

Browse files
committed
test: Add unit tests for AggregateUDF implementation and args_schema behavior
1 parent c8678e5 commit 25df8c6

File tree

1 file changed

+88
-1
lines changed

1 file changed

+88
-1
lines changed

datafusion/physical-expr/src/aggregate.rs

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,93 @@ pub(crate) mod groups_accumulator {
2424
accumulate::NullState, GroupsAccumulatorAdapter,
2525
};
2626
}
27+
28+
#[cfg(test)]
29+
mod tests {
30+
use super::*;
31+
use crate::expressions::Literal;
32+
use arrow::datatypes::{DataType, Field, Schema};
33+
use datafusion_common::ScalarValue;
34+
use datafusion_expr::{AggregateUDF, AggregateUDFImpl, Signature, Volatility};
35+
use std::any::Any;
36+
use std::sync::Arc;
37+
38+
#[derive(Debug)]
39+
struct DummyUdf {
40+
signature: Signature,
41+
}
42+
43+
impl DummyUdf {
44+
fn new() -> Self {
45+
Self {
46+
signature: Signature::any(1, Volatility::Immutable),
47+
}
48+
}
49+
}
50+
51+
impl AggregateUDFImpl for DummyUdf {
52+
fn as_any(&self) -> &dyn Any {
53+
self
54+
}
55+
fn name(&self) -> &str {
56+
"dummy"
57+
}
58+
fn signature(&self) -> &Signature {
59+
&self.signature
60+
}
61+
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
62+
Ok(DataType::UInt64)
63+
}
64+
fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
65+
unimplemented!()
66+
}
67+
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
68+
unimplemented!()
69+
}
70+
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
71+
true
72+
}
73+
fn create_groups_accumulator(
74+
&self,
75+
_args: AccumulatorArgs,
76+
) -> Result<Box<dyn GroupsAccumulator>> {
77+
unimplemented!()
78+
}
79+
}
80+
81+
#[test]
82+
fn test_args_schema_and_groups_path() {
83+
// literal-only: empty physical schema synthesizes schema from literal expr
84+
let udf = Arc::new(AggregateUDF::from(DummyUdf::new()));
85+
let lit_expr =
86+
Arc::new(Literal::new(ScalarValue::UInt32(Some(1)))) as Arc<dyn PhysicalExpr>;
87+
let agg = AggregateExprBuilder::new(udf.clone(), vec![lit_expr.clone()])
88+
.alias("x")
89+
.schema(Arc::new(Schema::empty()))
90+
.build()
91+
.unwrap();
92+
match agg.args_schema() {
93+
Cow::Owned(s) => assert_eq!(s.field(0).name(), "lit"),
94+
_ => panic!("expected owned schema"),
95+
}
96+
assert!(agg.groups_accumulator_supported());
97+
98+
// non-empty physical schema should be borrowed
99+
let f = Field::new("b", DataType::Int32, false);
100+
let phys_schema = Schema::new(vec![f.clone()]);
101+
let col_expr = Arc::new(Column::new("b", 0)) as Arc<dyn PhysicalExpr>;
102+
let agg2 = AggregateExprBuilder::new(udf, vec![col_expr])
103+
.alias("x")
104+
.schema(Arc::new(phys_schema.clone()))
105+
.build()
106+
.unwrap();
107+
match agg2.args_schema() {
108+
Cow::Borrowed(s) => assert_eq!(s.field(0).name(), "b"),
109+
_ => panic!("expected borrowed schema"),
110+
}
111+
assert!(agg2.groups_accumulator_supported());
112+
}
113+
}
27114
pub(crate) mod stats {
28115
pub use datafusion_functions_aggregate_common::stats::StatsType;
29116
}
@@ -404,7 +491,7 @@ impl AggregateFunctionExpr {
404491
}
405492
}
406493
/// Construct AccumulatorArgs for this aggregate using a given schema slice.
407-
fn make_acc_args(&self, schema: &Schema) -> AccumulatorArgs<'_> {
494+
fn make_acc_args<'a>(&'a self, schema: &'a Schema) -> AccumulatorArgs<'a> {
408495
AccumulatorArgs {
409496
return_field: Arc::clone(&self.return_field),
410497
schema,

0 commit comments

Comments
 (0)