@@ -45,7 +45,10 @@ fn get_min_max_by_result_type(
4545 match & input_types[ 0 ] {
4646 arrow:: datatypes:: DataType :: Dictionary ( _, dict_value_type) => {
4747 // x add checker, if the value type is complex data type
48- Ok ( vec ! [ dict_value_type. deref( ) . clone( ) ] )
48+ let mut result = vec ! [ dict_value_type. deref( ) . clone( ) ] ;
49+ // Preserve all other argument types
50+ result. extend_from_slice ( & input_types[ 1 ..] ) ;
51+ Ok ( result)
4952 }
5053 _ => Ok ( input_types. to_vec ( ) ) ,
5154 }
@@ -207,3 +210,261 @@ impl logical_expr::AggregateUDFImpl for MinByFunction {
207210 Some ( Box :: new ( simplify) )
208211 }
209212}
213+
214+ #[ cfg( test) ]
215+ mod tests {
216+ use super :: * ;
217+
218+ use datafusion:: arrow:: array:: ArrayAccessor ;
219+ use datafusion:: { arrow, datasource, error, prelude} ;
220+ use std:: sync;
221+
222+ const TEST_TABLE_NAME : & str = "types" ;
223+ const STRING_COLUMN_NAME : & str = "string" ;
224+ const DICTIONARY_COLUMN_NAME : & str = "dict_string" ;
225+ const INT64_COLUMN_NAME : & str = "int64" ;
226+ const FLOAT64_COLUMN_NAME : & str = "float64" ;
227+
228+ const MIN_STRING_VALUE : & str = "a" ;
229+ const MID_STRING_VALUE : & str = "b" ;
230+ const MAX_STRING_VALUE : & str = "c" ;
231+ const MIN_FLOAT_VALUE : f64 = 0.25 ;
232+ const MID_FLOAT_VALUE : f64 = 0.5 ;
233+ const MAX_FLOAT_VALUE : f64 = 0.75 ;
234+ const MIN_INT_VALUE : i64 = -1 ;
235+ const MID_INT_VALUE : i64 = 0 ;
236+ const MAX_INT_VALUE : i64 = 1 ;
237+ const MIN_DICTIONARY_VALUE : & str = "a" ;
238+ const MID_DICTIONARY_VALUE : & str = "b" ;
239+ const MAX_DICTIONARY_VALUE : & str = "c" ;
240+
241+ fn test_schema ( ) -> sync:: Arc < arrow:: datatypes:: Schema > {
242+ sync:: Arc :: new ( arrow:: datatypes:: Schema :: new ( vec ! [
243+ arrow:: datatypes:: Field :: new(
244+ STRING_COLUMN_NAME ,
245+ arrow:: datatypes:: DataType :: Utf8 ,
246+ false ,
247+ ) ,
248+ arrow:: datatypes:: Field :: new_dictionary(
249+ DICTIONARY_COLUMN_NAME ,
250+ arrow:: datatypes:: DataType :: Int32 ,
251+ arrow:: datatypes:: DataType :: Utf8 ,
252+ false ,
253+ ) ,
254+ arrow:: datatypes:: Field :: new(
255+ INT64_COLUMN_NAME ,
256+ arrow:: datatypes:: DataType :: Int64 ,
257+ false ,
258+ ) ,
259+ arrow:: datatypes:: Field :: new(
260+ FLOAT64_COLUMN_NAME ,
261+ arrow:: datatypes:: DataType :: Float64 ,
262+ false ,
263+ ) ,
264+ ] ) )
265+ }
266+
267+ fn test_data (
268+ schema : sync:: Arc < arrow:: datatypes:: Schema > ,
269+ ) -> Vec < arrow:: record_batch:: RecordBatch > {
270+ vec ! [
271+ arrow:: record_batch:: RecordBatch :: try_new(
272+ schema,
273+ vec![
274+ sync:: Arc :: new( arrow:: array:: StringArray :: from( vec![
275+ MID_STRING_VALUE ,
276+ MIN_STRING_VALUE ,
277+ MAX_STRING_VALUE ,
278+ ] ) ) ,
279+ sync:: Arc :: new(
280+ vec![
281+ Some ( MID_DICTIONARY_VALUE ) ,
282+ Some ( MIN_DICTIONARY_VALUE ) ,
283+ Some ( MAX_DICTIONARY_VALUE ) ,
284+ ]
285+ . into_iter( )
286+ . collect:: <arrow:: array:: DictionaryArray <arrow:: datatypes:: Int32Type >>( ) ,
287+ ) ,
288+ sync:: Arc :: new( arrow:: array:: Int64Array :: from( vec![
289+ MID_INT_VALUE ,
290+ MIN_INT_VALUE ,
291+ MAX_INT_VALUE ,
292+ ] ) ) ,
293+ sync:: Arc :: new( arrow:: array:: Float64Array :: from( vec![
294+ MID_FLOAT_VALUE ,
295+ MIN_FLOAT_VALUE ,
296+ MAX_FLOAT_VALUE ,
297+ ] ) ) ,
298+ ] ,
299+ )
300+ . unwrap( ) ,
301+ ]
302+ }
303+
304+ fn test_ctx ( ) -> datafusion:: common:: Result < prelude:: SessionContext > {
305+ let schema = test_schema ( ) ;
306+ let data = test_data ( schema. clone ( ) ) ;
307+ let table = datasource:: MemTable :: try_new ( schema, vec ! [ data] ) ?;
308+ let ctx = prelude:: SessionContext :: new ( ) ;
309+ ctx. register_table ( TEST_TABLE_NAME , sync:: Arc :: new ( table) ) ?;
310+ Ok ( ctx)
311+ }
312+
313+ async fn extract_single_value < T , A > ( df : prelude:: DataFrame ) -> error:: Result < T >
314+ where
315+ A : arrow:: array:: Array + ' static ,
316+ for < ' a > & ' a A : arrow:: array:: ArrayAccessor ,
317+ for < ' a > <& ' a A as arrow:: array:: ArrayAccessor >:: Item : Into < T > ,
318+ {
319+ let results = df. collect ( ) . await ?;
320+ let col = results[ 0 ] . column ( 0 ) ;
321+ let v1 = col. as_any ( ) . downcast_ref :: < A > ( ) . unwrap ( ) ;
322+ let value = v1. value ( 0 ) . into ( ) ;
323+ Ok ( value)
324+ }
325+
326+ #[ cfg( test) ]
327+ mod max_by {
328+ use super :: * ;
329+
330+ #[ tokio:: test]
331+ async fn test_max_by_string_int ( ) -> error:: Result < ( ) > {
332+ let query = format ! (
333+ "SELECT max_by({}, {}) FROM {}" ,
334+ STRING_COLUMN_NAME , INT64_COLUMN_NAME , TEST_TABLE_NAME
335+ ) ;
336+ let df = ctx ( ) ?. sql ( & query) . await ?;
337+ let result = extract_single_value :: < String , arrow:: array:: StringArray > ( df) . await ?;
338+ assert_eq ! ( result, MAX_STRING_VALUE ) ;
339+ Ok ( ( ) )
340+ }
341+
342+ #[ tokio:: test]
343+ async fn test_max_by_string_float ( ) -> error:: Result < ( ) > {
344+ let query = format ! (
345+ "SELECT max_by({}, {}) FROM {}" ,
346+ STRING_COLUMN_NAME , FLOAT64_COLUMN_NAME , TEST_TABLE_NAME
347+ ) ;
348+ let df = ctx ( ) ?. sql ( & query) . await ?;
349+ let result = extract_single_value :: < String , arrow:: array:: StringArray > ( df) . await ?;
350+ assert_eq ! ( result, MAX_STRING_VALUE ) ;
351+ Ok ( ( ) )
352+ }
353+
354+ #[ tokio:: test]
355+ async fn test_max_by_float_string ( ) -> error:: Result < ( ) > {
356+ let query = format ! (
357+ "SELECT max_by({}, {}) FROM {}" ,
358+ FLOAT64_COLUMN_NAME , STRING_COLUMN_NAME , TEST_TABLE_NAME
359+ ) ;
360+ let df = ctx ( ) ?. sql ( & query) . await ?;
361+ let result = extract_single_value :: < f64 , arrow:: array:: Float64Array > ( df) . await ?;
362+ assert_eq ! ( result, MAX_FLOAT_VALUE ) ;
363+ Ok ( ( ) )
364+ }
365+
366+ #[ tokio:: test]
367+ async fn test_max_by_int_string ( ) -> error:: Result < ( ) > {
368+ let query = format ! (
369+ "SELECT max_by({}, {}) FROM {}" ,
370+ INT64_COLUMN_NAME , STRING_COLUMN_NAME , TEST_TABLE_NAME
371+ ) ;
372+ let df = ctx ( ) ?. sql ( & query) . await ?;
373+ let result = extract_single_value :: < i64 , arrow:: array:: Int64Array > ( df) . await ?;
374+ assert_eq ! ( result, MAX_INT_VALUE ) ;
375+ Ok ( ( ) )
376+ }
377+
378+ #[ tokio:: test]
379+ async fn test_max_by_dictionary_int ( ) -> error:: Result < ( ) > {
380+ let query = format ! (
381+ "SELECT max_by({}, {}) FROM {}" ,
382+ DICTIONARY_COLUMN_NAME , INT64_COLUMN_NAME , TEST_TABLE_NAME
383+ ) ;
384+ let df = ctx ( ) ?. sql ( & query) . await ?;
385+ let result = extract_single_value :: < String , arrow:: array:: StringArray > ( df) . await ?;
386+ assert_eq ! ( result, MAX_DICTIONARY_VALUE ) ;
387+ Ok ( ( ) )
388+ }
389+
390+ fn ctx ( ) -> error:: Result < prelude:: SessionContext > {
391+ let ctx = test_ctx ( ) ?;
392+ let max_by_udaf = MaxByFunction :: new ( ) ;
393+ ctx. register_udaf ( max_by_udaf. into ( ) ) ;
394+ Ok ( ctx)
395+ }
396+ }
397+
398+ #[ cfg( test) ]
399+ mod min_by {
400+
401+ use super :: * ;
402+
403+ #[ tokio:: test]
404+ async fn test_min_by_string_int ( ) -> error:: Result < ( ) > {
405+ let query = format ! (
406+ "SELECT min_by({}, {}) FROM {}" ,
407+ STRING_COLUMN_NAME , INT64_COLUMN_NAME , TEST_TABLE_NAME
408+ ) ;
409+ let df = ctx ( ) ?. sql ( & query) . await ?;
410+ let result = extract_single_value :: < String , arrow:: array:: StringArray > ( df) . await ?;
411+ assert_eq ! ( result, MIN_STRING_VALUE ) ;
412+ Ok ( ( ) )
413+ }
414+
415+ #[ tokio:: test]
416+ async fn test_min_by_string_float ( ) -> error:: Result < ( ) > {
417+ let query = format ! (
418+ "SELECT min_by({}, {}) FROM {}" ,
419+ STRING_COLUMN_NAME , FLOAT64_COLUMN_NAME , TEST_TABLE_NAME
420+ ) ;
421+ let df = ctx ( ) ?. sql ( & query) . await ?;
422+ let result = extract_single_value :: < String , arrow:: array:: StringArray > ( df) . await ?;
423+ assert_eq ! ( result, MIN_STRING_VALUE ) ;
424+ Ok ( ( ) )
425+ }
426+
427+ #[ tokio:: test]
428+ async fn test_min_by_float_string ( ) -> error:: Result < ( ) > {
429+ let query = format ! (
430+ "SELECT min_by({}, {}) FROM {}" ,
431+ FLOAT64_COLUMN_NAME , STRING_COLUMN_NAME , TEST_TABLE_NAME
432+ ) ;
433+ let df = ctx ( ) ?. sql ( & query) . await ?;
434+ let result = extract_single_value :: < f64 , arrow:: array:: Float64Array > ( df) . await ?;
435+ assert_eq ! ( result, MIN_FLOAT_VALUE ) ;
436+ Ok ( ( ) )
437+ }
438+
439+ #[ tokio:: test]
440+ async fn test_min_by_int_string ( ) -> error:: Result < ( ) > {
441+ let query = format ! (
442+ "SELECT min_by({}, {}) FROM {}" ,
443+ INT64_COLUMN_NAME , STRING_COLUMN_NAME , TEST_TABLE_NAME
444+ ) ;
445+ let df = ctx ( ) ?. sql ( & query) . await ?;
446+ let result = extract_single_value :: < i64 , arrow:: array:: Int64Array > ( df) . await ?;
447+ assert_eq ! ( result, MIN_INT_VALUE ) ;
448+ Ok ( ( ) )
449+ }
450+
451+ #[ tokio:: test]
452+ async fn test_min_by_dictionary_int ( ) -> error:: Result < ( ) > {
453+ let query = format ! (
454+ "SELECT min_by({}, {}) FROM {}" ,
455+ DICTIONARY_COLUMN_NAME , INT64_COLUMN_NAME , TEST_TABLE_NAME
456+ ) ;
457+ let df = ctx ( ) ?. sql ( & query) . await ?;
458+ let result = extract_single_value :: < String , arrow:: array:: StringArray > ( df) . await ?;
459+ assert_eq ! ( result, MIN_DICTIONARY_VALUE ) ;
460+ Ok ( ( ) )
461+ }
462+
463+ fn ctx ( ) -> error:: Result < prelude:: SessionContext > {
464+ let ctx = test_ctx ( ) ?;
465+ let min_by_udaf = MinByFunction :: new ( ) ;
466+ ctx. register_udaf ( min_by_udaf. into ( ) ) ;
467+ Ok ( ctx)
468+ }
469+ }
470+ }
0 commit comments