@@ -45,7 +45,9 @@ fn get_min_max_by_result_type(
4545 match & input_types[ 0 ] {
4646 arrow:: datatypes:: DataType :: Dictionary ( _, dict_value_type) => {
4747 // TODO: add a check for the case where the dictionary value type is a complex data type
48- Ok ( vec ! [ dict_value_type. deref( ) . clone( ) ] )
48+ let mut result = vec ! [ dict_value_type. deref( ) . clone( ) ] ;
49+ result. extend_from_slice ( & input_types[ 1 ..] ) ; // Preserve all other argument types
50+ Ok ( result)
4951 }
5052 _ => Ok ( input_types. to_vec ( ) ) ,
5153 }
@@ -207,3 +209,212 @@ impl logical_expr::AggregateUDFImpl for MinByFunction {
207209 Some ( Box :: new ( simplify) )
208210 }
209211}
212+
213+ #[ cfg( test) ]
214+ mod tests {
215+ use datafusion:: arrow:: array:: {
216+ ArrayRef , Float64Array , Int64Array , RecordBatch , StringArray , UInt64Array ,
217+ } ;
218+ use datafusion:: arrow:: datatypes:: { DataType , Field , Schema } ;
219+ use datafusion:: datasource:: MemTable ;
220+ use datafusion:: prelude:: SessionContext ;
221+ use std:: any:: Any ;
222+ use std:: sync:: Arc ;
223+
/// End-to-end SQL tests for the `max_by` aggregate, run against the `types`
/// fixture table that `test_ctx` registers.
#[cfg(test)]
mod tests_max_by {
    use crate::max_min_by::max_by_udaf;
    use crate::max_min_by::tests::{
        extract_single_float64, extract_single_int64, extract_single_string, test_ctx,
    };
    use datafusion::error::Result;
    use datafusion::prelude::SessionContext;

    /// Returns the shared fixture context with the `max_by` UDAF registered on it.
    fn ctx() -> Result<SessionContext> {
        let session = test_ctx()?;
        session.register_udaf(max_by_udaf().as_ref().clone());
        Ok(session)
    }

    /// String value selected by the largest `int64` key.
    #[tokio::test]
    async fn test_max_by_string_int() -> Result<()> {
        let frame = ctx()?
            .sql("SELECT max_by(string, int64) FROM types")
            .await?;
        let batches = frame.collect().await?;
        assert_eq!(extract_single_string(batches), "h");
        Ok(())
    }

    /// String value selected by the largest `float64` key.
    #[tokio::test]
    async fn test_max_by_string_float() -> Result<()> {
        let frame = ctx()?
            .sql("SELECT max_by(string, float64) FROM types")
            .await?;
        let batches = frame.collect().await?;
        assert_eq!(extract_single_string(batches), "h");
        Ok(())
    }

    /// Float value selected by the largest string key.
    #[tokio::test]
    async fn test_max_by_float_string() -> Result<()> {
        let frame = ctx()?
            .sql("SELECT max_by(float64, string) FROM types")
            .await?;
        let batches = frame.collect().await?;
        assert_eq!(extract_single_float64(batches), 8.0);
        Ok(())
    }

    /// Integer value selected by the largest string key.
    #[tokio::test]
    async fn test_max_by_int_string() -> Result<()> {
        let frame = ctx()?
            .sql("SELECT max_by(int64, string) FROM types")
            .await?;
        let batches = frame.collect().await?;
        assert_eq!(extract_single_int64(batches), 8);
        Ok(())
    }

    /// Dictionary-encoded string as the value argument; the result is read
    /// back as a plain string array.
    #[tokio::test]
    async fn test_max_by_dictionary_int() -> Result<()> {
        let frame = ctx()?
            .sql("SELECT max_by(dict_string, int64) FROM types")
            .await?;
        let batches = frame.collect().await?;
        assert_eq!(extract_single_string(batches), "h");
        Ok(())
    }
}
284+
/// End-to-end SQL tests for the `min_by` aggregate, run against the `types`
/// fixture table that `test_ctx` registers.
#[cfg(test)]
mod test_min_by {
    use crate::max_min_by::min_by_udaf;
    use crate::max_min_by::tests::{
        extract_single_float64, extract_single_int64, extract_single_string, test_ctx,
    };
    use datafusion::error::Result;
    use datafusion::prelude::SessionContext;

    /// Returns the shared fixture context with the `min_by` UDAF registered on it.
    fn ctx() -> Result<SessionContext> {
        let session = test_ctx()?;
        session.register_udaf(min_by_udaf().as_ref().clone());
        Ok(session)
    }

    /// String value selected by the smallest `int64` key.
    #[tokio::test]
    async fn test_min_by_string_int() -> Result<()> {
        let frame = ctx()?
            .sql("SELECT min_by(string, int64) FROM types")
            .await?;
        let batches = frame.collect().await?;
        assert_eq!(extract_single_string(batches), "a");
        Ok(())
    }

    /// String value selected by the smallest `float64` key.
    #[tokio::test]
    async fn test_min_by_string_float() -> Result<()> {
        let frame = ctx()?
            .sql("SELECT min_by(string, float64) FROM types")
            .await?;
        let batches = frame.collect().await?;
        assert_eq!(extract_single_string(batches), "a");
        Ok(())
    }

    /// Float value selected by the smallest string key.
    #[tokio::test]
    async fn test_min_by_float_string() -> Result<()> {
        let frame = ctx()?
            .sql("SELECT min_by(float64, string) FROM types")
            .await?;
        let batches = frame.collect().await?;
        assert_eq!(extract_single_float64(batches), 0.5);
        Ok(())
    }

    /// Integer value selected by the smallest string key.
    #[tokio::test]
    async fn test_min_by_int_string() -> Result<()> {
        let frame = ctx()?
            .sql("SELECT min_by(int64, string) FROM types")
            .await?;
        let batches = frame.collect().await?;
        assert_eq!(extract_single_int64(batches), 1);
        Ok(())
    }

    /// Dictionary-encoded string as the value argument; the result is read
    /// back as a plain string array.
    #[tokio::test]
    async fn test_min_by_dictionary_int() -> Result<()> {
        let frame = ctx()?
            .sql("SELECT min_by(dict_string, int64) FROM types")
            .await?;
        let batches = frame.collect().await?;
        assert_eq!(extract_single_string(batches), "a");
        Ok(())
    }
}
345+
346+ pub ( super ) fn test_schema ( ) -> Arc < Schema > {
347+ Arc :: new ( Schema :: new ( vec ! [
348+ Field :: new( "string" , DataType :: Utf8 , false ) ,
349+ Field :: new_dictionary( "dict_string" , DataType :: Int32 , DataType :: Utf8 , false ) ,
350+ Field :: new( "int64" , DataType :: Int64 , false ) ,
351+ Field :: new( "uint64" , DataType :: UInt64 , false ) ,
352+ Field :: new( "float64" , DataType :: Float64 , false ) ,
353+ ] ) )
354+ }
355+
356+ pub ( super ) fn test_data ( schema : Arc < Schema > ) -> Vec < RecordBatch > {
357+ use datafusion:: arrow:: array:: DictionaryArray ;
358+ use datafusion:: arrow:: datatypes:: Int32Type ;
359+
360+ vec ! [
361+ RecordBatch :: try_new(
362+ schema. clone( ) ,
363+ vec![
364+ Arc :: new( StringArray :: from( vec![ "a" , "b" , "c" , "d" ] ) ) ,
365+ Arc :: new(
366+ vec![ Some ( "a" ) , Some ( "b" ) , Some ( "c" ) , Some ( "d" ) ]
367+ . into_iter( )
368+ . collect:: <DictionaryArray <Int32Type >>( ) ,
369+ ) ,
370+ Arc :: new( Int64Array :: from( vec![ 1 , 2 , 3 , 4 ] ) ) ,
371+ Arc :: new( UInt64Array :: from( vec![ 1 , 2 , 3 , 4 ] ) ) ,
372+ Arc :: new( Float64Array :: from( vec![ 0.5 , 2.0 , 3.0 , 4.0 ] ) ) ,
373+ ] ,
374+ )
375+ . unwrap( ) ,
376+ RecordBatch :: try_new(
377+ schema. clone( ) ,
378+ vec![
379+ Arc :: new( StringArray :: from( vec![ "e" , "f" , "g" , "h" ] ) ) ,
380+ Arc :: new(
381+ vec![ Some ( "e" ) , Some ( "f" ) , Some ( "g" ) , Some ( "h" ) ]
382+ . into_iter( )
383+ . collect:: <DictionaryArray <Int32Type >>( ) ,
384+ ) ,
385+ Arc :: new( Int64Array :: from( vec![ 5 , 6 , 7 , 8 ] ) ) ,
386+ Arc :: new( UInt64Array :: from( vec![ 5 , 6 , 7 , 8 ] ) ) ,
387+ Arc :: new( Float64Array :: from( vec![ 5.0 , 6.0 , 7.0 , 8.0 ] ) ) ,
388+ ] ,
389+ )
390+ . unwrap( ) ,
391+ ]
392+ }
393+
394+ pub ( crate ) fn test_ctx ( ) -> datafusion:: common:: Result < SessionContext > {
395+ let schema = test_schema ( ) ;
396+ let table = MemTable :: try_new ( schema. clone ( ) , vec ! [ test_data( schema) ] ) ?;
397+ let ctx = SessionContext :: new ( ) ;
398+ ctx. register_table ( "types" , Arc :: new ( table) ) ?;
399+ Ok ( ctx)
400+ }
401+
402+ fn downcast < T : Any > ( col : & ArrayRef ) -> & T {
403+ col. as_any ( ) . downcast_ref :: < T > ( ) . unwrap ( )
404+ }
405+
406+ pub ( crate ) fn extract_single_string ( results : Vec < RecordBatch > ) -> String {
407+ let v1 = downcast :: < StringArray > ( results[ 0 ] . column ( 0 ) ) ;
408+ v1. value ( 0 ) . to_string ( )
409+ }
410+
411+ pub ( crate ) fn extract_single_int64 ( results : Vec < RecordBatch > ) -> i64 {
412+ let v1 = downcast :: < Int64Array > ( results[ 0 ] . column ( 0 ) ) ;
413+ v1. value ( 0 )
414+ }
415+
416+ pub ( crate ) fn extract_single_float64 ( results : Vec < RecordBatch > ) -> f64 {
417+ let v1 = downcast :: < Float64Array > ( results[ 0 ] . column ( 0 ) ) ;
418+ v1. value ( 0 )
419+ }
420+ }
0 commit comments