@@ -24,10 +24,11 @@ use std::sync::Arc;
2424use crate :: error:: { DataFusionError , Result } ;
2525use crate :: physical_plan:: groups_accumulator:: GroupsAccumulator ;
2626use crate :: physical_plan:: groups_accumulator_flat_adapter:: GroupsAccumulatorFlatAdapter ;
27+ use crate :: physical_plan:: groups_accumulator_prim_op:: PrimitiveGroupsAccumulator ;
2728use crate :: physical_plan:: { Accumulator , AggregateExpr , PhysicalExpr } ;
2829use crate :: scalar:: ScalarValue ;
2930use arrow:: compute;
30- use arrow:: datatypes:: { DataType , TimeUnit } ;
31+ use arrow:: datatypes:: { ArrowPrimitiveType , DataType , TimeUnit } ;
3132use arrow:: {
3233 array:: {
3334 ArrayRef , Float32Array , Float64Array , Int16Array , Int32Array , Int64Array ,
@@ -108,12 +109,90 @@ impl AggregateExpr for Max {
108109 fn create_groups_accumulator (
109110 & self ,
110111 ) -> arrow:: error:: Result < Option < Box < dyn GroupsAccumulator > > > {
111- let data_type = self . data_type . clone ( ) ;
112- Ok ( Some ( Box :: new (
113- GroupsAccumulatorFlatAdapter :: < MaxAccumulator > :: new ( move || {
114- MaxAccumulator :: try_new ( & data_type)
115- } ) ,
116- ) ) )
112+ macro_rules! make_max_accumulator {
113+ ( $T: ty) => {
114+ Box :: new(
115+ PrimitiveGroupsAccumulator :: <$T, $T, _, _>:: new(
116+ & <$T as ArrowPrimitiveType >:: DATA_TYPE ,
117+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
118+ y: <$T as ArrowPrimitiveType >:: Native | {
119+ * x = ( * x) . max( y) ;
120+ } ,
121+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
122+ y: <$T as ArrowPrimitiveType >:: Native | {
123+ * x = ( * x) . max( y) ;
124+ } ,
125+ )
126+ . with_starting_value( <$T as ArrowPrimitiveType >:: Native :: MIN ) ,
127+ )
128+ } ;
129+ }
130+ let acc: Box < dyn GroupsAccumulator > = match & self . data_type {
131+ DataType :: Float64 => make_max_accumulator ! ( arrow:: datatypes:: Float64Type ) ,
132+ DataType :: Float32 => make_max_accumulator ! ( arrow:: datatypes:: Float32Type ) ,
133+ DataType :: Int64 => make_max_accumulator ! ( arrow:: datatypes:: Int64Type ) ,
134+ DataType :: Int96 => make_max_accumulator ! ( arrow:: datatypes:: Int96Type ) ,
135+ DataType :: Int64Decimal ( 0 ) => {
136+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal0Type )
137+ }
138+ DataType :: Int64Decimal ( 1 ) => {
139+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal1Type )
140+ }
141+ DataType :: Int64Decimal ( 2 ) => {
142+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal2Type )
143+ }
144+ DataType :: Int64Decimal ( 3 ) => {
145+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal3Type )
146+ }
147+ DataType :: Int64Decimal ( 4 ) => {
148+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal4Type )
149+ }
150+ DataType :: Int64Decimal ( 5 ) => {
151+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal5Type )
152+ }
153+ DataType :: Int64Decimal ( 10 ) => {
154+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal10Type )
155+ }
156+ DataType :: Int96Decimal ( 0 ) => {
157+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal0Type )
158+ }
159+ DataType :: Int96Decimal ( 1 ) => {
160+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal1Type )
161+ }
162+ DataType :: Int96Decimal ( 2 ) => {
163+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal2Type )
164+ }
165+ DataType :: Int96Decimal ( 3 ) => {
166+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal3Type )
167+ }
168+ DataType :: Int96Decimal ( 4 ) => {
169+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal4Type )
170+ }
171+ DataType :: Int96Decimal ( 5 ) => {
172+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal5Type )
173+ }
174+ DataType :: Int96Decimal ( 10 ) => {
175+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal10Type )
176+ }
177+ DataType :: Int32 => make_max_accumulator ! ( arrow:: datatypes:: Int32Type ) ,
178+ DataType :: Int16 => make_max_accumulator ! ( arrow:: datatypes:: Int16Type ) ,
179+ DataType :: Int8 => make_max_accumulator ! ( arrow:: datatypes:: Int8Type ) ,
180+ DataType :: UInt64 => make_max_accumulator ! ( arrow:: datatypes:: UInt64Type ) ,
181+ DataType :: UInt32 => make_max_accumulator ! ( arrow:: datatypes:: UInt32Type ) ,
182+ DataType :: UInt16 => make_max_accumulator ! ( arrow:: datatypes:: UInt16Type ) ,
183+ DataType :: UInt8 => make_max_accumulator ! ( arrow:: datatypes:: UInt8Type ) ,
184+ _ => {
185+ // Not all types (strings) can use primitive accumulators. And strings use
186+ // max_string as the $OP in typed_min_match_batch.
187+
188+ // Timestamps presently take this branch.
189+ let data_type = self . data_type . clone ( ) ;
190+ Box :: new ( GroupsAccumulatorFlatAdapter :: < MaxAccumulator > :: new (
191+ move || MaxAccumulator :: try_new ( & data_type) ,
192+ ) )
193+ }
194+ } ;
195+ Ok ( Some ( acc) )
117196 }
118197
119198 fn name ( & self ) -> & str {
@@ -547,12 +626,91 @@ impl AggregateExpr for Min {
547626 fn create_groups_accumulator (
548627 & self ,
549628 ) -> arrow:: error:: Result < Option < Box < dyn GroupsAccumulator > > > {
550- let data_type = self . data_type . clone ( ) ;
551- Ok ( Some ( Box :: new (
552- GroupsAccumulatorFlatAdapter :: < MinAccumulator > :: new ( move || {
553- MinAccumulator :: try_new ( & data_type)
554- } ) ,
555- ) ) )
629+ macro_rules! make_min_accumulator {
630+ ( $T: ty) => {
631+ Box :: new(
632+ PrimitiveGroupsAccumulator :: <$T, $T, _, _>:: new(
633+ & <$T as ArrowPrimitiveType >:: DATA_TYPE ,
634+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
635+ y: <$T as ArrowPrimitiveType >:: Native | {
636+ * x = ( * x) . min( y) ;
637+ } ,
638+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
639+ y: <$T as ArrowPrimitiveType >:: Native | {
640+ * x = ( * x) . min( y) ;
641+ } ,
642+ )
643+ . with_starting_value( <$T as ArrowPrimitiveType >:: Native :: MAX ) ,
644+ )
645+ } ;
646+ }
647+
648+ let acc: Box < dyn GroupsAccumulator > = match & self . data_type {
649+ DataType :: Float64 => make_min_accumulator ! ( arrow:: datatypes:: Float64Type ) ,
650+ DataType :: Float32 => make_min_accumulator ! ( arrow:: datatypes:: Float32Type ) ,
651+ DataType :: Int64 => make_min_accumulator ! ( arrow:: datatypes:: Int64Type ) ,
652+ DataType :: Int96 => make_min_accumulator ! ( arrow:: datatypes:: Int96Type ) ,
653+ DataType :: Int64Decimal ( 0 ) => {
654+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal0Type )
655+ }
656+ DataType :: Int64Decimal ( 1 ) => {
657+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal1Type )
658+ }
659+ DataType :: Int64Decimal ( 2 ) => {
660+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal2Type )
661+ }
662+ DataType :: Int64Decimal ( 3 ) => {
663+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal3Type )
664+ }
665+ DataType :: Int64Decimal ( 4 ) => {
666+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal4Type )
667+ }
668+ DataType :: Int64Decimal ( 5 ) => {
669+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal5Type )
670+ }
671+ DataType :: Int64Decimal ( 10 ) => {
672+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal10Type )
673+ }
674+ DataType :: Int96Decimal ( 0 ) => {
675+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal0Type )
676+ }
677+ DataType :: Int96Decimal ( 1 ) => {
678+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal1Type )
679+ }
680+ DataType :: Int96Decimal ( 2 ) => {
681+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal2Type )
682+ }
683+ DataType :: Int96Decimal ( 3 ) => {
684+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal3Type )
685+ }
686+ DataType :: Int96Decimal ( 4 ) => {
687+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal4Type )
688+ }
689+ DataType :: Int96Decimal ( 5 ) => {
690+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal5Type )
691+ }
692+ DataType :: Int96Decimal ( 10 ) => {
693+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal10Type )
694+ }
695+ DataType :: Int32 => make_min_accumulator ! ( arrow:: datatypes:: Int32Type ) ,
696+ DataType :: Int16 => make_min_accumulator ! ( arrow:: datatypes:: Int16Type ) ,
697+ DataType :: Int8 => make_min_accumulator ! ( arrow:: datatypes:: Int8Type ) ,
698+ DataType :: UInt64 => make_min_accumulator ! ( arrow:: datatypes:: UInt64Type ) ,
699+ DataType :: UInt32 => make_min_accumulator ! ( arrow:: datatypes:: UInt32Type ) ,
700+ DataType :: UInt16 => make_min_accumulator ! ( arrow:: datatypes:: UInt16Type ) ,
701+ DataType :: UInt8 => make_min_accumulator ! ( arrow:: datatypes:: UInt8Type ) ,
702+ _ => {
703+ // Not all types (strings) can use primitive accumulators. And strings use
704+ // min_string as the $OP in typed_min_match_batch.
705+
706+ // Timestamps presently take this branch.
707+ let data_type = self . data_type . clone ( ) ;
708+ Box :: new ( GroupsAccumulatorFlatAdapter :: < MinAccumulator > :: new (
709+ move || MinAccumulator :: try_new ( & data_type) ,
710+ ) )
711+ }
712+ } ;
713+ Ok ( Some ( acc) )
556714 }
557715
558716 fn name ( & self ) -> & str {
0 commit comments