@@ -24,6 +24,7 @@ use std::sync::Arc;
2424use crate :: error:: { DataFusionError , Result } ;
2525use crate :: physical_plan:: groups_accumulator:: GroupsAccumulator ;
2626use crate :: physical_plan:: groups_accumulator_flat_adapter:: GroupsAccumulatorFlatAdapter ;
27+ use crate :: physical_plan:: groups_accumulator_prim_op:: PrimitiveGroupsAccumulator ;
2728use crate :: physical_plan:: { Accumulator , AggregateExpr , PhysicalExpr } ;
2829use crate :: scalar:: ScalarValue ;
2930use arrow:: compute;
@@ -49,6 +50,7 @@ use smallvec::SmallVec;
4950pub struct Sum {
5051 name : String ,
5152 data_type : DataType ,
53+ input_data_type : DataType ,
5254 expr : Arc < dyn PhysicalExpr > ,
5355 nullable : bool ,
5456}
@@ -80,11 +82,16 @@ impl Sum {
8082 expr : Arc < dyn PhysicalExpr > ,
8183 name : impl Into < String > ,
8284 data_type : DataType ,
85+ input_data_type : & DataType ,
8386 ) -> Self {
87+ // Note: data_type = sum_return_type(input_data_type) in the actual caller, so we don't
88+ // really need two params. But, we keep the four params to break symmetry with other
89+ // accumulators and any code that might use 3 params, such as the generic_test_op macro.
8490 Self {
8591 name : name. into ( ) ,
8692 expr,
8793 data_type,
94+ input_data_type : input_data_type. clone ( ) ,
8895 nullable : true ,
8996 }
9097 }
@@ -127,12 +134,64 @@ impl AggregateExpr for Sum {
127134 fn create_groups_accumulator (
128135 & self ,
129136 ) -> arrow:: error:: Result < Option < Box < dyn GroupsAccumulator > > > {
130- let data_type = self . data_type . clone ( ) ;
131- Ok ( Some ( Box :: new (
132- GroupsAccumulatorFlatAdapter :: < SumAccumulator > :: new ( move || {
133- SumAccumulator :: try_new ( & data_type)
134- } ) ,
135- ) ) )
137+ use arrow:: datatypes:: ArrowPrimitiveType ;
138+
139+ macro_rules! make_accumulator {
140+ ( $T: ty, $U: ty) => { Box :: new( PrimitiveGroupsAccumulator :: <
141+ $T,
142+ $U,
143+ _,
144+ _,
145+ >:: new( & <$T as ArrowPrimitiveType >:: DATA_TYPE , |x: & mut <$T as ArrowPrimitiveType >:: Native , y: <$U as ArrowPrimitiveType >:: Native | {
146+ * x = * x + ( y as <$T as ArrowPrimitiveType >:: Native ) ;
147+ } , |x: & mut <$T as ArrowPrimitiveType >:: Native , y: <$T as ArrowPrimitiveType >:: Native | { * x = * x + y; } ) ) } ;
148+ }
149+
150+ // Note that upstream uses x.add_wrapping(y) for the sum functions -- but here we just mimic
151+ // the current datafusion Sum accumulator implementation using native +. (That native +
152+ // specifically is the one in the expressions *x = *x + ... above.)
153+ Ok ( Some ( match ( & self . data_type , & self . input_data_type ) {
154+ ( DataType :: Int64 , DataType :: Int64 ) => make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int64Type ) ,
155+ ( DataType :: Int64 , DataType :: Int32 ) => make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int32Type ) ,
156+ ( DataType :: Int64 , DataType :: Int16 ) => make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int16Type ) ,
157+ ( DataType :: Int64 , DataType :: Int8 ) => make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int8Type ) ,
158+
159+ ( DataType :: Int96 , DataType :: Int96 ) => make_accumulator ! ( arrow:: datatypes:: Int96Type , arrow:: datatypes:: Int96Type ) ,
160+
161+ ( DataType :: Int64Decimal ( 0 ) , DataType :: Int64Decimal ( 0 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal0Type , arrow:: datatypes:: Int64Decimal0Type ) ,
162+ ( DataType :: Int64Decimal ( 1 ) , DataType :: Int64Decimal ( 1 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal1Type , arrow:: datatypes:: Int64Decimal1Type ) ,
163+ ( DataType :: Int64Decimal ( 2 ) , DataType :: Int64Decimal ( 2 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal2Type , arrow:: datatypes:: Int64Decimal2Type ) ,
164+ ( DataType :: Int64Decimal ( 3 ) , DataType :: Int64Decimal ( 3 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal3Type , arrow:: datatypes:: Int64Decimal3Type ) ,
165+ ( DataType :: Int64Decimal ( 4 ) , DataType :: Int64Decimal ( 4 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal4Type , arrow:: datatypes:: Int64Decimal4Type ) ,
166+ ( DataType :: Int64Decimal ( 5 ) , DataType :: Int64Decimal ( 5 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal5Type , arrow:: datatypes:: Int64Decimal5Type ) ,
167+ ( DataType :: Int64Decimal ( 10 ) , DataType :: Int64Decimal ( 10 ) ) => make_accumulator ! ( arrow:: datatypes:: Int64Decimal10Type , arrow:: datatypes:: Int64Decimal10Type ) ,
168+
169+ ( DataType :: Int96Decimal ( 0 ) , DataType :: Int96Decimal ( 0 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal0Type , arrow:: datatypes:: Int96Decimal0Type ) ,
170+ ( DataType :: Int96Decimal ( 1 ) , DataType :: Int96Decimal ( 1 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal1Type , arrow:: datatypes:: Int96Decimal1Type ) ,
171+ ( DataType :: Int96Decimal ( 2 ) , DataType :: Int96Decimal ( 2 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal2Type , arrow:: datatypes:: Int96Decimal2Type ) ,
172+ ( DataType :: Int96Decimal ( 3 ) , DataType :: Int96Decimal ( 3 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal3Type , arrow:: datatypes:: Int96Decimal3Type ) ,
173+ ( DataType :: Int96Decimal ( 4 ) , DataType :: Int96Decimal ( 4 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal4Type , arrow:: datatypes:: Int96Decimal4Type ) ,
174+ ( DataType :: Int96Decimal ( 5 ) , DataType :: Int96Decimal ( 5 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal5Type , arrow:: datatypes:: Int96Decimal5Type ) ,
175+ ( DataType :: Int96Decimal ( 10 ) , DataType :: Int96Decimal ( 10 ) ) => make_accumulator ! ( arrow:: datatypes:: Int96Decimal10Type , arrow:: datatypes:: Int96Decimal10Type ) ,
176+
177+ ( DataType :: UInt64 , DataType :: UInt64 ) => make_accumulator ! ( arrow:: datatypes:: UInt64Type , arrow:: datatypes:: UInt64Type ) ,
178+ ( DataType :: UInt64 , DataType :: UInt32 ) => make_accumulator ! ( arrow:: datatypes:: UInt64Type , arrow:: datatypes:: UInt32Type ) ,
179+ ( DataType :: UInt64 , DataType :: UInt16 ) => make_accumulator ! ( arrow:: datatypes:: UInt64Type , arrow:: datatypes:: UInt16Type ) ,
180+ ( DataType :: UInt64 , DataType :: UInt8 ) => make_accumulator ! ( arrow:: datatypes:: UInt64Type , arrow:: datatypes:: UInt8Type ) ,
181+
182+ ( DataType :: Float32 , DataType :: Float32 ) => make_accumulator ! ( arrow:: datatypes:: Float32Type , arrow:: datatypes:: Float32Type ) ,
183+ ( DataType :: Float64 , DataType :: Float64 ) => make_accumulator ! ( arrow:: datatypes:: Float64Type , arrow:: datatypes:: Float64Type ) ,
184+
185+ _ => {
186+ // This case should never be reached because we've handled all sum_return_type
187+ // arg_type values. Nonetheless:
188+ let data_type = self . data_type . clone ( ) ;
189+
190+ Box :: new ( GroupsAccumulatorFlatAdapter :: < SumAccumulator > :: new (
191+ move || SumAccumulator :: try_new ( & data_type) ,
192+ ) )
193+ }
194+ } ) )
136195 }
137196
138197 fn name ( & self ) -> & str {
@@ -416,13 +475,25 @@ mod tests {
416475 use arrow:: datatypes:: * ;
417476 use arrow:: record_batch:: RecordBatch ;
418477
478+ // A wrapper to make Sum::new, which now has an input_type argument, work with
479+ // generic_test_op!.
480+ struct SumTestStandin ;
481+ impl SumTestStandin {
482+ fn new ( expr : Arc < dyn PhysicalExpr > ,
483+ name : impl Into < String > ,
484+ data_type : DataType ) -> Sum {
485+ Sum :: new ( expr, name, data_type. clone ( ) , & data_type)
486+ }
487+ }
488+
419489 #[ test]
420490 fn sum_i32 ( ) -> Result < ( ) > {
421491 let a: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ 1 , 2 , 3 , 4 , 5 ] ) ) ;
492+
422493 generic_test_op ! (
423494 a,
424495 DataType :: Int32 ,
425- Sum ,
496+ SumTestStandin ,
426497 ScalarValue :: from( 15i64 ) ,
427498 DataType :: Int64
428499 )
@@ -440,7 +511,7 @@ mod tests {
440511 generic_test_op ! (
441512 a,
442513 DataType :: Int32 ,
443- Sum ,
514+ SumTestStandin ,
444515 ScalarValue :: from( 13i64 ) ,
445516 DataType :: Int64
446517 )
@@ -452,7 +523,7 @@ mod tests {
452523 generic_test_op ! (
453524 a,
454525 DataType :: Int32 ,
455- Sum ,
526+ SumTestStandin ,
456527 ScalarValue :: Int64 ( None ) ,
457528 DataType :: Int64
458529 )
@@ -465,7 +536,7 @@ mod tests {
465536 generic_test_op ! (
466537 a,
467538 DataType :: UInt32 ,
468- Sum ,
539+ SumTestStandin ,
469540 ScalarValue :: from( 15u64 ) ,
470541 DataType :: UInt64
471542 )
@@ -478,7 +549,7 @@ mod tests {
478549 generic_test_op ! (
479550 a,
480551 DataType :: Float32 ,
481- Sum ,
552+ SumTestStandin ,
482553 ScalarValue :: from( 15_f32 ) ,
483554 DataType :: Float32
484555 )
@@ -491,7 +562,7 @@ mod tests {
491562 generic_test_op ! (
492563 a,
493564 DataType :: Float64 ,
494- Sum ,
565+ SumTestStandin ,
495566 ScalarValue :: from( 15_f64 ) ,
496567 DataType :: Float64
497568 )
0 commit comments