1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use crate :: EvalMode ;
18+ use crate :: { arithmetic_overflow_error , EvalMode } ;
1919use arrow:: array:: {
2020 cast:: AsArray , Array , ArrayBuilder , ArrayRef , ArrowNativeTypeOp , ArrowPrimitiveType ,
2121 BooleanArray , Int64Array , PrimitiveArray ,
@@ -39,8 +39,8 @@ pub struct SumInteger {
3939
4040impl SumInteger {
4141 pub fn try_new ( data_type : DataType , eval_mode : EvalMode ) -> DFResult < Self > {
42- // The `data_type` is the SUM result type passed from Spark side
43- println ! ( "data type: {:?}" , data_type) ;
42+ // The `data_type` is the SUM result type passed from Spark side which should i64
43+ println ! ( "data type: {:?} eval_mode {:?} " , data_type, eval_mode ) ;
4444 match data_type {
4545 DataType :: Int8 | DataType :: Int16 | DataType :: Int32 | DataType :: Int64 => Ok ( Self {
4646 signature : Signature :: user_defined ( Immutable ) ,
@@ -75,14 +75,14 @@ impl AggregateUDFImpl for SumInteger {
7575 }
7676
7777 fn accumulator ( & self , acc_args : AccumulatorArgs ) -> DFResult < Box < dyn Accumulator > > {
78- Ok ( Box :: new ( SumIntegerAccumulator :: new ( ) ) )
78+ Ok ( Box :: new ( SumIntegerAccumulator :: new ( self . eval_mode ) ) )
7979 }
8080
8181 fn create_groups_accumulator (
8282 & self ,
8383 _args : AccumulatorArgs ,
8484 ) -> DFResult < Box < dyn GroupsAccumulator > > {
85- Ok ( Box :: new ( SumDecimalGroupsAccumulator :: new ( self . eval_mode ) ) )
85+ Ok ( Box :: new ( SumIntGroupsAccumulator :: new ( self . eval_mode ) ) )
8686 }
8787}
8888
@@ -94,10 +94,10 @@ struct SumIntegerAccumulator {
9494}
9595
9696impl SumIntegerAccumulator {
97- fn new ( ) -> Self {
97+ fn new ( eval_mode : EvalMode ) -> Self {
9898 Self {
9999 sum : 0 ,
100- eval_mode : EvalMode :: Legacy ,
100+ eval_mode,
101101 input_data_type : DataType :: Int64 ,
102102 }
103103 }
@@ -113,13 +113,13 @@ impl Accumulator for SumIntegerAccumulator {
113113 where
114114 T : ArrowPrimitiveType ,
115115 {
116- println ! ( "match internal function data type: {:?}" , sum) ;
117116 let len = int_array. len ( ) ;
118117 for i in 0 ..int_array. len ( ) {
119118 if !int_array. is_null ( i) {
120119 let v = int_array. value ( i) . to_i64 ( ) . ok_or_else ( || {
121120 DataFusionError :: Internal ( "Failed to convert value to i64" . to_string ( ) )
122121 } ) ?;
122+ println ! ( "sum : {:?}, v : {:?}" , sum, v) ;
123123 match eval_mode {
124124 EvalMode :: Legacy | EvalMode :: Try => {
125125 sum = v. add_wrapping ( sum) ;
@@ -128,7 +128,7 @@ impl Accumulator for SumIntegerAccumulator {
128128 match v. add_checked ( sum) {
129129 Ok ( v) => sum = v,
130130 Err ( e) => {
131- return Err ( DataFusionError :: Internal ( "error" . to_string ( ) ) )
131+ return Err ( DataFusionError :: from ( arithmetic_overflow_error ( "integer" ) ) )
132132 }
133133 } ;
134134 }
@@ -157,53 +157,40 @@ impl Accumulator for SumIntegerAccumulator {
157157 ) ;
158158 Ok ( ( ) )
159159 } else {
160- match values. data_type ( ) {
161- DataType :: Int64 => {
162- println ! ( "match data type: {:?}" , self . input_data_type) ;
163- update_sum_internal (
164- values
165- . as_any ( )
166- . downcast_ref :: < PrimitiveArray < Int64Type > > ( )
167- . unwrap ( ) ,
168- self . eval_mode ,
169- self . sum ,
170- ) ?;
171- }
172- DataType :: Int32 => {
173- println ! ( "match data type: {:?}" , self . input_data_type) ;
174- update_sum_internal (
175- values
176- . as_any ( )
177- . downcast_ref :: < PrimitiveArray < Int32Type > > ( )
178- . unwrap ( ) ,
179- self . eval_mode ,
180- self . sum ,
181- ) ?;
182- }
183- DataType :: Int16 => {
184- println ! ( "match data type: {:?}" , self . input_data_type) ;
185- update_sum_internal (
186- values
187- . as_any ( )
188- . downcast_ref :: < PrimitiveArray < Int16Type > > ( )
189- . unwrap ( ) ,
190- self . eval_mode ,
191- self . sum ,
192- ) ?;
193- }
194- DataType :: Int8 => {
195- println ! ( "match data type: {:?}" , self . input_data_type) ;
196- update_sum_internal (
197- values
198- . as_any ( )
199- . downcast_ref :: < PrimitiveArray < Int8Type > > ( )
200- . unwrap ( ) ,
201- self . eval_mode ,
202- self . sum ,
203- ) ?;
204- }
160+ self . sum = match values. data_type ( ) {
161+ DataType :: Int64 => update_sum_internal (
162+ values
163+ . as_any ( )
164+ . downcast_ref :: < PrimitiveArray < Int64Type > > ( )
165+ . unwrap ( ) ,
166+ self . eval_mode ,
167+ self . sum ,
168+ ) ?,
169+ DataType :: Int32 => update_sum_internal (
170+ values
171+ . as_any ( )
172+ . downcast_ref :: < PrimitiveArray < Int32Type > > ( )
173+ . unwrap ( ) ,
174+ self . eval_mode ,
175+ self . sum ,
176+ ) ?,
177+ DataType :: Int16 => update_sum_internal (
178+ values
179+ . as_any ( )
180+ . downcast_ref :: < PrimitiveArray < Int16Type > > ( )
181+ . unwrap ( ) ,
182+ self . eval_mode ,
183+ self . sum ,
184+ ) ?,
185+ DataType :: Int8 => update_sum_internal (
186+ values
187+ . as_any ( )
188+ . downcast_ref :: < PrimitiveArray < Int8Type > > ( )
189+ . unwrap ( ) ,
190+ self . eval_mode ,
191+ self . sum ,
192+ ) ?,
205193 _ => {
206- println ! ( "unsupported input data type: {:?}" , self . input_data_type) ;
207194 panic ! ( "Unsupported data type" )
208195 }
209196 } ;
@@ -246,19 +233,19 @@ impl Accumulator for SumIntegerAccumulator {
246233 }
247234 EvalMode :: Ansi => match self . sum . add_checked ( that_sum. value ( 0 ) ) {
248235 Ok ( v) => self . sum = v,
249- Err ( e) => return Err ( DataFusionError :: Internal ( "error" . to_string ( ) ) ) ,
236+ Err ( e) => return Err ( DataFusionError :: from ( arithmetic_overflow_error ( "integer" ) ) ) ,
250237 } ,
251238 }
252239 Ok ( ( ) )
253240 }
254241}
255242
256- struct SumDecimalGroupsAccumulator {
243+ struct SumIntGroupsAccumulator {
257244 sums : Vec < i64 > ,
258245 eval_mode : EvalMode ,
259246}
260247
261- impl SumDecimalGroupsAccumulator {
248+ impl SumIntGroupsAccumulator {
262249 fn new ( eval_mode : EvalMode ) -> Self {
263250 Self {
264251 sums : Vec :: new ( ) ,
@@ -267,7 +254,7 @@ impl SumDecimalGroupsAccumulator {
267254 }
268255}
269256
270- impl GroupsAccumulator for SumDecimalGroupsAccumulator {
257+ impl GroupsAccumulator for SumIntGroupsAccumulator {
271258 fn update_batch (
272259 & mut self ,
273260 values : & [ ArrayRef ] ,
@@ -285,13 +272,13 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
285272 for ( & group_index, & value) in iter {
286273 match self . eval_mode {
287274 EvalMode :: Legacy | EvalMode :: Try => {
288- self . sums [ group_index] . add_wrapping ( value) ;
275+ self . sums [ group_index] = self . sums [ group_index ] . add_wrapping ( value) ;
289276 }
290277 EvalMode :: Ansi => {
291278 match self . sums [ group_index] . add_checked ( value) {
292- Ok ( v) => v,
279+ Ok ( v) => self . sums [ group_index ] = v,
293280 Err ( e) => {
294- return Err ( DataFusionError :: Internal ( "integer overflow" . to_string ( ) ) )
281+ return Err ( DataFusionError :: from ( arithmetic_overflow_error ( "integer" ) ) )
295282 }
296283 } ;
297284 }
@@ -344,11 +331,11 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
344331 for ( & group_index, & value) in iter {
345332 match self . eval_mode {
346333 EvalMode :: Legacy | EvalMode :: Try => {
347- self . sums [ group_index] . add_wrapping ( value) ;
334+ self . sums [ group_index] = self . sums [ group_index ] . add_wrapping ( value) ;
348335 }
349336 EvalMode :: Ansi => {
350337 match self . sums [ group_index] . add_checked ( value) {
351- Ok ( v) => v,
338+ Ok ( v) => self . sums [ group_index ] = v,
352339 Err ( e) => {
353340 return Err ( DataFusionError :: Internal ( "integer overflow" . to_string ( ) ) )
354341 }
0 commit comments