@@ -1200,8 +1200,10 @@ mod tests {
12001200
12011201 use arrow:: array:: { Float64Array , UInt32Array } ;
12021202 use arrow:: compute:: { concat_batches, SortOptions } ;
1203- use arrow:: datatypes:: DataType ;
1204- use arrow_array:: { Float32Array , Int32Array } ;
1203+ use arrow:: datatypes:: { DataType , Int32Type } ;
1204+ use arrow_array:: {
1205+ DictionaryArray , Float32Array , Int32Array , StructArray , UInt64Array ,
1206+ } ;
12051207 use datafusion_common:: {
12061208 assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError ,
12071209 ScalarValue ,
@@ -1214,6 +1216,7 @@ mod tests {
12141216 use datafusion_functions_aggregate:: count:: count_udaf;
12151217 use datafusion_functions_aggregate:: first_last:: { first_value_udaf, last_value_udaf} ;
12161218 use datafusion_functions_aggregate:: median:: median_udaf;
1219+ use datafusion_functions_aggregate:: sum:: sum_udaf;
12171220 use datafusion_physical_expr:: expressions:: lit;
12181221 use datafusion_physical_expr:: PhysicalSortExpr ;
12191222
@@ -2316,6 +2319,127 @@ mod tests {
23162319 Ok ( ( ) )
23172320 }
23182321
2322+ #[ tokio:: test]
2323+ async fn test_agg_exec_struct_of_dicts ( ) -> Result < ( ) > {
2324+ let batch = RecordBatch :: try_new (
2325+ Arc :: new ( Schema :: new ( vec ! [
2326+ Field :: new(
2327+ "labels" . to_string( ) ,
2328+ DataType :: Struct (
2329+ vec![
2330+ Field :: new_dict(
2331+ "a" . to_string( ) ,
2332+ DataType :: Dictionary (
2333+ Box :: new( DataType :: Int32 ) ,
2334+ Box :: new( DataType :: Utf8 ) ,
2335+ ) ,
2336+ true ,
2337+ 0 ,
2338+ false ,
2339+ ) ,
2340+ Field :: new_dict(
2341+ "b" . to_string( ) ,
2342+ DataType :: Dictionary (
2343+ Box :: new( DataType :: Int32 ) ,
2344+ Box :: new( DataType :: Utf8 ) ,
2345+ ) ,
2346+ true ,
2347+ 0 ,
2348+ false ,
2349+ ) ,
2350+ ]
2351+ . into( ) ,
2352+ ) ,
2353+ false ,
2354+ ) ,
2355+ Field :: new( "value" , DataType :: UInt64 , false ) ,
2356+ ] ) ) ,
2357+ vec ! [
2358+ Arc :: new( StructArray :: from( vec![
2359+ (
2360+ Arc :: new( Field :: new_dict(
2361+ "a" . to_string( ) ,
2362+ DataType :: Dictionary (
2363+ Box :: new( DataType :: Int32 ) ,
2364+ Box :: new( DataType :: Utf8 ) ,
2365+ ) ,
2366+ true ,
2367+ 0 ,
2368+ false ,
2369+ ) ) ,
2370+ Arc :: new(
2371+ vec![ Some ( "a" ) , None , Some ( "a" ) ]
2372+ . into_iter( )
2373+ . collect:: <DictionaryArray <Int32Type >>( ) ,
2374+ ) as ArrayRef ,
2375+ ) ,
2376+ (
2377+ Arc :: new( Field :: new_dict(
2378+ "b" . to_string( ) ,
2379+ DataType :: Dictionary (
2380+ Box :: new( DataType :: Int32 ) ,
2381+ Box :: new( DataType :: Utf8 ) ,
2382+ ) ,
2383+ true ,
2384+ 0 ,
2385+ false ,
2386+ ) ) ,
2387+ Arc :: new(
2388+ vec![ Some ( "b" ) , Some ( "c" ) , Some ( "b" ) ]
2389+ . into_iter( )
2390+ . collect:: <DictionaryArray <Int32Type >>( ) ,
2391+ ) as ArrayRef ,
2392+ ) ,
2393+ ] ) ) ,
2394+ Arc :: new( UInt64Array :: from( vec![ 1 , 1 , 1 ] ) ) ,
2395+ ] ,
2396+ )
2397+ . expect ( "Failed to create RecordBatch" ) ;
2398+
2399+ let group_by = PhysicalGroupBy :: new_single ( vec ! [ (
2400+ col( "labels" , & batch. schema( ) ) ?,
2401+ "labels" . to_string( ) ,
2402+ ) ] ) ;
2403+
2404+ let aggr_expr = vec ! [ AggregateExprBuilder :: new(
2405+ sum_udaf( ) ,
2406+ vec![ col( "value" , & batch. schema( ) ) ?] ,
2407+ )
2408+ . schema( Arc :: clone( & batch. schema( ) ) )
2409+ . alias( String :: from( "SUM(value)" ) )
2410+ . build( ) ?] ;
2411+
2412+ let input = Arc :: new ( MemoryExec :: try_new (
2413+ & [ vec ! [ batch. clone( ) ] ] ,
2414+ Arc :: < arrow_schema:: Schema > :: clone ( & batch. schema ( ) ) ,
2415+ None ,
2416+ ) ?) ;
2417+ let aggregate_exec = Arc :: new ( AggregateExec :: try_new (
2418+ AggregateMode :: FinalPartitioned ,
2419+ group_by,
2420+ aggr_expr,
2421+ vec ! [ None ] ,
2422+ Arc :: clone ( & input) as Arc < dyn ExecutionPlan > ,
2423+ batch. schema ( ) ,
2424+ ) ?) ;
2425+
2426+ let session_config = SessionConfig :: default ( ) ;
2427+ let ctx = TaskContext :: default ( ) . with_session_config ( session_config) ;
2428+ let output = collect ( aggregate_exec. execute ( 0 , Arc :: new ( ctx) ) ?) . await ?;
2429+
2430+ let expected = [
2431+ "+--------------+------------+" ,
2432+ "| labels | SUM(value) |" ,
2433+ "+--------------+------------+" ,
2434+ "| {a: a, b: b} | 2 |" ,
2435+ "| {a: , b: c} | 1 |" ,
2436+ "+--------------+------------+" ,
2437+ ] ;
2438+ assert_batches_eq ! ( expected, & output) ;
2439+
2440+ Ok ( ( ) )
2441+ }
2442+
23192443 #[ tokio:: test]
23202444 async fn test_skip_aggregation_after_first_batch ( ) -> Result < ( ) > {
23212445 let schema = Arc :: new ( Schema :: new ( vec ! [
0 commit comments