@@ -24,6 +24,93 @@ pub(crate) mod groups_accumulator {
24
24
accumulate:: NullState , GroupsAccumulatorAdapter ,
25
25
} ;
26
26
}
27
+
28
+ #[ cfg( test) ]
29
+ mod tests {
30
+ use super :: * ;
31
+ use crate :: expressions:: Literal ;
32
+ use arrow:: datatypes:: { DataType , Field , Schema } ;
33
+ use datafusion_common:: ScalarValue ;
34
+ use datafusion_expr:: { AggregateUDF , AggregateUDFImpl , Signature , Volatility } ;
35
+ use std:: any:: Any ;
36
+ use std:: sync:: Arc ;
37
+
38
+ #[ derive( Debug ) ]
39
+ struct DummyUdf {
40
+ signature : Signature ,
41
+ }
42
+
43
+ impl DummyUdf {
44
+ fn new ( ) -> Self {
45
+ Self {
46
+ signature : Signature :: any ( 1 , Volatility :: Immutable ) ,
47
+ }
48
+ }
49
+ }
50
+
51
+ impl AggregateUDFImpl for DummyUdf {
52
+ fn as_any ( & self ) -> & dyn Any {
53
+ self
54
+ }
55
+ fn name ( & self ) -> & str {
56
+ "dummy"
57
+ }
58
+ fn signature ( & self ) -> & Signature {
59
+ & self . signature
60
+ }
61
+ fn return_type ( & self , _args : & [ DataType ] ) -> Result < DataType > {
62
+ Ok ( DataType :: UInt64 )
63
+ }
64
+ fn accumulator ( & self , _args : AccumulatorArgs ) -> Result < Box < dyn Accumulator > > {
65
+ unimplemented ! ( )
66
+ }
67
+ fn state_fields ( & self , _args : StateFieldsArgs ) -> Result < Vec < FieldRef > > {
68
+ unimplemented ! ( )
69
+ }
70
+ fn groups_accumulator_supported ( & self , _args : AccumulatorArgs ) -> bool {
71
+ true
72
+ }
73
+ fn create_groups_accumulator (
74
+ & self ,
75
+ _args : AccumulatorArgs ,
76
+ ) -> Result < Box < dyn GroupsAccumulator > > {
77
+ unimplemented ! ( )
78
+ }
79
+ }
80
+
81
+ #[ test]
82
+ fn test_args_schema_and_groups_path ( ) {
83
+ // literal-only: empty physical schema synthesizes schema from literal expr
84
+ let udf = Arc :: new ( AggregateUDF :: from ( DummyUdf :: new ( ) ) ) ;
85
+ let lit_expr =
86
+ Arc :: new ( Literal :: new ( ScalarValue :: UInt32 ( Some ( 1 ) ) ) ) as Arc < dyn PhysicalExpr > ;
87
+ let agg = AggregateExprBuilder :: new ( udf. clone ( ) , vec ! [ lit_expr. clone( ) ] )
88
+ . alias ( "x" )
89
+ . schema ( Arc :: new ( Schema :: empty ( ) ) )
90
+ . build ( )
91
+ . unwrap ( ) ;
92
+ match agg. args_schema ( ) {
93
+ Cow :: Owned ( s) => assert_eq ! ( s. field( 0 ) . name( ) , "lit" ) ,
94
+ _ => panic ! ( "expected owned schema" ) ,
95
+ }
96
+ assert ! ( agg. groups_accumulator_supported( ) ) ;
97
+
98
+ // non-empty physical schema should be borrowed
99
+ let f = Field :: new ( "b" , DataType :: Int32 , false ) ;
100
+ let phys_schema = Schema :: new ( vec ! [ f. clone( ) ] ) ;
101
+ let col_expr = Arc :: new ( Column :: new ( "b" , 0 ) ) as Arc < dyn PhysicalExpr > ;
102
+ let agg2 = AggregateExprBuilder :: new ( udf, vec ! [ col_expr] )
103
+ . alias ( "x" )
104
+ . schema ( Arc :: new ( phys_schema. clone ( ) ) )
105
+ . build ( )
106
+ . unwrap ( ) ;
107
+ match agg2. args_schema ( ) {
108
+ Cow :: Borrowed ( s) => assert_eq ! ( s. field( 0 ) . name( ) , "b" ) ,
109
+ _ => panic ! ( "expected borrowed schema" ) ,
110
+ }
111
+ assert ! ( agg2. groups_accumulator_supported( ) ) ;
112
+ }
113
+ }
27
114
pub ( crate ) mod stats {
28
115
pub use datafusion_functions_aggregate_common:: stats:: StatsType ;
29
116
}
@@ -404,7 +491,7 @@ impl AggregateFunctionExpr {
404
491
}
405
492
}
406
493
/// Construct AccumulatorArgs for this aggregate using a given schema slice.
407
- fn make_acc_args ( & self , schema : & Schema ) -> AccumulatorArgs < ' _ > {
494
+ fn make_acc_args < ' a > ( & ' a self , schema : & ' a Schema ) -> AccumulatorArgs < ' a > {
408
495
AccumulatorArgs {
409
496
return_field : Arc :: clone ( & self . return_field ) ,
410
497
schema,
0 commit comments