@@ -24,6 +24,7 @@ use std::sync::Arc;
2424use crate :: error:: { DataFusionError , Result } ;
2525use crate :: physical_plan:: groups_accumulator:: GroupsAccumulator ;
2626use crate :: physical_plan:: groups_accumulator_flat_adapter:: GroupsAccumulatorFlatAdapter ;
27+ use crate :: physical_plan:: groups_accumulator_prim_op:: PrimitiveGroupsAccumulator ;
2728use crate :: physical_plan:: { Accumulator , AggregateExpr , PhysicalExpr } ;
2829use crate :: scalar:: ScalarValue ;
2930use arrow:: compute;
@@ -49,6 +50,7 @@ use smallvec::SmallVec;
// State held by the SUM aggregate expression; every field is filled in by `Sum::new`.
4950pub struct Sum {
    /// Name supplied to the constructor (stored via `name.into()`).
5051 name : String ,
    /// Type of the sum's result. Per the note in `Sum::new`, the actual caller
    /// derives it as `sum_return_type(input_data_type)`.
5152 data_type : DataType ,
    /// Type of the values being summed. Together with `data_type` it selects the
    /// specialised accumulator in `create_groups_accumulator`.
53+ input_data_type : DataType ,
    /// Physical expression supplied to the constructor.
    /// NOTE(review): presumably evaluated to produce the values to sum -- confirm.
5254 expr : Arc < dyn PhysicalExpr > ,
    /// Always initialised to `true` by `Sum::new`.
5355 nullable : bool ,
5456}
@@ -80,11 +82,16 @@ impl Sum {
8082 expr : Arc < dyn PhysicalExpr > ,
8183 name : impl Into < String > ,
8284 data_type : DataType ,
85+ input_data_type : & DataType ,
8386 ) -> Self {
87+ // Note: data_type = sum_return_type(input_data_type) in the actual caller, so we don't
88+ // really need two params. But, we keep the four params to break symmetry with other
89+ // accumulators and any code that might use 3 params, such as the generic_test_op macro.
8490 Self {
8591 name : name. into ( ) ,
8692 expr,
8793 data_type,
94+ input_data_type : input_data_type. clone ( ) ,
8895 nullable : true ,
8996 }
9097 }
@@ -127,12 +134,147 @@ impl AggregateExpr for Sum {
127134 fn create_groups_accumulator (
128135 & self ,
129136 ) -> arrow:: error:: Result < Option < Box < dyn GroupsAccumulator > > > {
130- let data_type = self . data_type . clone ( ) ;
131- Ok ( Some ( Box :: new (
132- GroupsAccumulatorFlatAdapter :: < SumAccumulator > :: new ( move || {
133- SumAccumulator :: try_new ( & data_type)
134- } ) ,
135- ) ) )
137+ use arrow:: datatypes:: ArrowPrimitiveType ;
138+
139+ macro_rules! make_accumulator {
140+ ( $T: ty, $U: ty) => {
141+ Box :: new( PrimitiveGroupsAccumulator :: <$T, $U, _, _>:: new(
142+ & <$T as ArrowPrimitiveType >:: DATA_TYPE ,
143+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
144+ y: <$U as ArrowPrimitiveType >:: Native | {
145+ * x = * x + ( y as <$T as ArrowPrimitiveType >:: Native ) ;
146+ } ,
147+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
148+ y: <$T as ArrowPrimitiveType >:: Native | {
149+ * x = * x + y;
150+ } ,
151+ ) )
152+ } ;
153+ }
154+
155+ // Note that upstream uses x.add_wrapping(y) for the sum functions -- but here we just mimic
156+ // the current datafusion Sum accumulator implementation using native +. (That native +
157+ // specifically is the one in the expressions *x = *x + ... above.)
158+ Ok ( Some ( match ( & self . data_type , & self . input_data_type ) {
159+ ( DataType :: Int64 , DataType :: Int64 ) => make_accumulator ! (
160+ arrow:: datatypes:: Int64Type ,
161+ arrow:: datatypes:: Int64Type
162+ ) ,
163+ ( DataType :: Int64 , DataType :: Int32 ) => make_accumulator ! (
164+ arrow:: datatypes:: Int64Type ,
165+ arrow:: datatypes:: Int32Type
166+ ) ,
167+ ( DataType :: Int64 , DataType :: Int16 ) => make_accumulator ! (
168+ arrow:: datatypes:: Int64Type ,
169+ arrow:: datatypes:: Int16Type
170+ ) ,
171+ ( DataType :: Int64 , DataType :: Int8 ) => {
172+ make_accumulator ! ( arrow:: datatypes:: Int64Type , arrow:: datatypes:: Int8Type )
173+ }
174+
175+ ( DataType :: Int96 , DataType :: Int96 ) => make_accumulator ! (
176+ arrow:: datatypes:: Int96Type ,
177+ arrow:: datatypes:: Int96Type
178+ ) ,
179+
180+ ( DataType :: Int64Decimal ( 0 ) , DataType :: Int64Decimal ( 0 ) ) => make_accumulator ! (
181+ arrow:: datatypes:: Int64Decimal0Type ,
182+ arrow:: datatypes:: Int64Decimal0Type
183+ ) ,
184+ ( DataType :: Int64Decimal ( 1 ) , DataType :: Int64Decimal ( 1 ) ) => make_accumulator ! (
185+ arrow:: datatypes:: Int64Decimal1Type ,
186+ arrow:: datatypes:: Int64Decimal1Type
187+ ) ,
188+ ( DataType :: Int64Decimal ( 2 ) , DataType :: Int64Decimal ( 2 ) ) => make_accumulator ! (
189+ arrow:: datatypes:: Int64Decimal2Type ,
190+ arrow:: datatypes:: Int64Decimal2Type
191+ ) ,
192+ ( DataType :: Int64Decimal ( 3 ) , DataType :: Int64Decimal ( 3 ) ) => make_accumulator ! (
193+ arrow:: datatypes:: Int64Decimal3Type ,
194+ arrow:: datatypes:: Int64Decimal3Type
195+ ) ,
196+ ( DataType :: Int64Decimal ( 4 ) , DataType :: Int64Decimal ( 4 ) ) => make_accumulator ! (
197+ arrow:: datatypes:: Int64Decimal4Type ,
198+ arrow:: datatypes:: Int64Decimal4Type
199+ ) ,
200+ ( DataType :: Int64Decimal ( 5 ) , DataType :: Int64Decimal ( 5 ) ) => make_accumulator ! (
201+ arrow:: datatypes:: Int64Decimal5Type ,
202+ arrow:: datatypes:: Int64Decimal5Type
203+ ) ,
204+ ( DataType :: Int64Decimal ( 10 ) , DataType :: Int64Decimal ( 10 ) ) => {
205+ make_accumulator ! (
206+ arrow:: datatypes:: Int64Decimal10Type ,
207+ arrow:: datatypes:: Int64Decimal10Type
208+ )
209+ }
210+
211+ ( DataType :: Int96Decimal ( 0 ) , DataType :: Int96Decimal ( 0 ) ) => make_accumulator ! (
212+ arrow:: datatypes:: Int96Decimal0Type ,
213+ arrow:: datatypes:: Int96Decimal0Type
214+ ) ,
215+ ( DataType :: Int96Decimal ( 1 ) , DataType :: Int96Decimal ( 1 ) ) => make_accumulator ! (
216+ arrow:: datatypes:: Int96Decimal1Type ,
217+ arrow:: datatypes:: Int96Decimal1Type
218+ ) ,
219+ ( DataType :: Int96Decimal ( 2 ) , DataType :: Int96Decimal ( 2 ) ) => make_accumulator ! (
220+ arrow:: datatypes:: Int96Decimal2Type ,
221+ arrow:: datatypes:: Int96Decimal2Type
222+ ) ,
223+ ( DataType :: Int96Decimal ( 3 ) , DataType :: Int96Decimal ( 3 ) ) => make_accumulator ! (
224+ arrow:: datatypes:: Int96Decimal3Type ,
225+ arrow:: datatypes:: Int96Decimal3Type
226+ ) ,
227+ ( DataType :: Int96Decimal ( 4 ) , DataType :: Int96Decimal ( 4 ) ) => make_accumulator ! (
228+ arrow:: datatypes:: Int96Decimal4Type ,
229+ arrow:: datatypes:: Int96Decimal4Type
230+ ) ,
231+ ( DataType :: Int96Decimal ( 5 ) , DataType :: Int96Decimal ( 5 ) ) => make_accumulator ! (
232+ arrow:: datatypes:: Int96Decimal5Type ,
233+ arrow:: datatypes:: Int96Decimal5Type
234+ ) ,
235+ ( DataType :: Int96Decimal ( 10 ) , DataType :: Int96Decimal ( 10 ) ) => {
236+ make_accumulator ! (
237+ arrow:: datatypes:: Int96Decimal10Type ,
238+ arrow:: datatypes:: Int96Decimal10Type
239+ )
240+ }
241+
242+ ( DataType :: UInt64 , DataType :: UInt64 ) => make_accumulator ! (
243+ arrow:: datatypes:: UInt64Type ,
244+ arrow:: datatypes:: UInt64Type
245+ ) ,
246+ ( DataType :: UInt64 , DataType :: UInt32 ) => make_accumulator ! (
247+ arrow:: datatypes:: UInt64Type ,
248+ arrow:: datatypes:: UInt32Type
249+ ) ,
250+ ( DataType :: UInt64 , DataType :: UInt16 ) => make_accumulator ! (
251+ arrow:: datatypes:: UInt64Type ,
252+ arrow:: datatypes:: UInt16Type
253+ ) ,
254+ ( DataType :: UInt64 , DataType :: UInt8 ) => make_accumulator ! (
255+ arrow:: datatypes:: UInt64Type ,
256+ arrow:: datatypes:: UInt8Type
257+ ) ,
258+
259+ ( DataType :: Float32 , DataType :: Float32 ) => make_accumulator ! (
260+ arrow:: datatypes:: Float32Type ,
261+ arrow:: datatypes:: Float32Type
262+ ) ,
263+ ( DataType :: Float64 , DataType :: Float64 ) => make_accumulator ! (
264+ arrow:: datatypes:: Float64Type ,
265+ arrow:: datatypes:: Float64Type
266+ ) ,
267+
268+ _ => {
269+ // This case should never be reached because we've handled all sum_return_type
270+ // arg_type values. Nonetheless:
271+ let data_type = self . data_type . clone ( ) ;
272+
273+ Box :: new ( GroupsAccumulatorFlatAdapter :: < SumAccumulator > :: new (
274+ move || SumAccumulator :: try_new ( & data_type) ,
275+ ) )
276+ }
277+ } ) )
136278 }
137279
138280 fn name ( & self ) -> & str {
@@ -416,13 +558,27 @@ mod tests {
416558 use arrow:: datatypes:: * ;
417559 use arrow:: record_batch:: RecordBatch ;
418560
561+ // A wrapper to make Sum::new, which now has an input_type argument, work with
562+ // generic_test_op!.
563+ struct SumTestStandin ;
564+ impl SumTestStandin {
565+ fn new (
566+ expr : Arc < dyn PhysicalExpr > ,
567+ name : impl Into < String > ,
568+ data_type : DataType ,
569+ ) -> Sum {
570+ Sum :: new ( expr, name, data_type. clone ( ) , & data_type)
571+ }
572+ }
573+
419574 #[ test]
420575 fn sum_i32 ( ) -> Result < ( ) > {
421576 let a: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ 1 , 2 , 3 , 4 , 5 ] ) ) ;
577+
422578 generic_test_op ! (
423579 a,
424580 DataType :: Int32 ,
425- Sum ,
581+ SumTestStandin ,
426582 ScalarValue :: from( 15i64 ) ,
427583 DataType :: Int64
428584 )
@@ -440,7 +596,7 @@ mod tests {
440596 generic_test_op ! (
441597 a,
442598 DataType :: Int32 ,
443- Sum ,
599+ SumTestStandin ,
444600 ScalarValue :: from( 13i64 ) ,
445601 DataType :: Int64
446602 )
@@ -452,7 +608,7 @@ mod tests {
452608 generic_test_op ! (
453609 a,
454610 DataType :: Int32 ,
455- Sum ,
611+ SumTestStandin ,
456612 ScalarValue :: Int64 ( None ) ,
457613 DataType :: Int64
458614 )
@@ -465,7 +621,7 @@ mod tests {
465621 generic_test_op ! (
466622 a,
467623 DataType :: UInt32 ,
468- Sum ,
624+ SumTestStandin ,
469625 ScalarValue :: from( 15u64 ) ,
470626 DataType :: UInt64
471627 )
@@ -478,7 +634,7 @@ mod tests {
478634 generic_test_op ! (
479635 a,
480636 DataType :: Float32 ,
481- Sum ,
637+ SumTestStandin ,
482638 ScalarValue :: from( 15_f32 ) ,
483639 DataType :: Float32
484640 )
@@ -491,7 +647,7 @@ mod tests {
491647 generic_test_op ! (
492648 a,
493649 DataType :: Float64 ,
494- Sum ,
650+ SumTestStandin ,
495651 ScalarValue :: from( 15_f64 ) ,
496652 DataType :: Float64
497653 )
0 commit comments