@@ -19,7 +19,10 @@ use super::*;
1919use arrow:: compute:: add;
2020use datafusion:: {
2121 logical_plan:: { create_udaf, FunctionRegistry , LogicalPlanBuilder } ,
22- physical_plan:: { expressions:: AvgAccumulator , functions:: make_scalar_function} ,
22+ physical_plan:: {
23+ expressions:: { AvgAccumulator , MaxAccumulator } ,
24+ functions:: make_scalar_function,
25+ } ,
2326} ;
2427
2528/// test that casting happens on udfs.
@@ -144,15 +147,15 @@ async fn scalar_udf() -> Result<()> {
144147/// tests the creation, registration and usage of a UDAF
145148#[ tokio:: test]
146149async fn simple_udaf ( ) -> Result < ( ) > {
147- let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int32 , false ) ] ) ;
150+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Float64 , false ) ] ) ;
148151
149152 let batch1 = RecordBatch :: try_new (
150153 Arc :: new ( schema. clone ( ) ) ,
151- vec ! [ Arc :: new( Int32Array :: from_slice( [ 1 , 2 , 3 ] ) ) ] ,
154+ vec ! [ Arc :: new( Float64Array :: from_slice( [ 5.0 , 10.0 , 15.0 ] ) ) ] ,
152155 ) ?;
153156 let batch2 = RecordBatch :: try_new (
154157 Arc :: new ( schema. clone ( ) ) ,
155- vec ! [ Arc :: new( Int32Array :: from_slice( [ 4 , 5 ] ) ) ] ,
158+ vec ! [ Arc :: new( Float64Array :: from_slice( [ 10.0 , 15.0 ] ) ) ] ,
156159 ) ?;
157160
158161 let mut ctx = SessionContext :: new ( ) ;
@@ -166,8 +169,22 @@ async fn simple_udaf() -> Result<()> {
166169 DataType :: Float64 ,
167170 Arc :: new ( DataType :: Float64 ) ,
168171 Volatility :: Immutable ,
169- Arc :: new ( || Ok ( Box :: new ( AvgAccumulator :: try_new ( & DataType :: Float64 ) ?) ) ) ,
170- Arc :: new ( vec ! [ DataType :: UInt64 , DataType :: Float64 ] ) ,
172+ Arc :: new ( |distinct| {
173+ if distinct {
174+ // Use MAX function when DISTINCT is specified as an example
175+ Ok ( Box :: new ( MaxAccumulator :: try_new ( & DataType :: Float64 ) ?) )
176+ } else {
177+ Ok ( Box :: new ( AvgAccumulator :: try_new ( & DataType :: Float64 ) ?) )
178+ }
179+ } ) ,
180+ Arc :: new ( |data_type, distinct| {
181+ if distinct {
182+ // When DISTINCT is specified, use state type for MAX function
183+ Ok ( Arc :: new ( vec ! [ data_type. clone( ) ] ) )
184+ } else {
185+ Ok ( Arc :: new ( vec ! [ DataType :: UInt64 , data_type. clone( ) ] ) )
186+ }
187+ } ) ,
171188 ) ;
172189
173190 ctx. register_udaf ( my_avg) ;
@@ -178,10 +195,22 @@ async fn simple_udaf() -> Result<()> {
178195 "+-------------+" ,
179196 "| my_avg(t.a) |" ,
180197 "+-------------+" ,
181- "| 3 |" ,
198+ "| 11 |" ,
182199 "+-------------+" ,
183200 ] ;
184201 assert_batches_eq ! ( expected, & result) ;
185202
203+ // also test DISTINCT. in this case it makes MY_AVG act like MAX function
204+ let result = plan_and_collect ( & ctx, "SELECT MY_AVG(DISTINCT a) FROM t" ) . await ?;
205+
206+ let expected = vec ! [
207+ "+----------------------+" ,
208+ "| my_avg(DISTINCT t.a) |" ,
209+ "+----------------------+" ,
210+ "| 15 |" ,
211+ "+----------------------+" ,
212+ ] ;
213+ assert_batches_eq ! ( expected, & result) ;
214+
186215 Ok ( ( ) )
187216}
0 commit comments