@@ -24,10 +24,11 @@ use std::sync::Arc;
2424use crate :: error:: { DataFusionError , Result } ;
2525use crate :: physical_plan:: groups_accumulator:: GroupsAccumulator ;
2626use crate :: physical_plan:: groups_accumulator_flat_adapter:: GroupsAccumulatorFlatAdapter ;
27+ use crate :: physical_plan:: groups_accumulator_prim_op:: PrimitiveGroupsAccumulator ;
2728use crate :: physical_plan:: { Accumulator , AggregateExpr , PhysicalExpr } ;
2829use crate :: scalar:: ScalarValue ;
2930use arrow:: compute;
30- use arrow:: datatypes:: { DataType , TimeUnit } ;
31+ use arrow:: datatypes:: { ArrowPrimitiveType , DataType , TimeUnit } ;
3132use arrow:: {
3233 array:: {
3334 ArrayRef , Float32Array , Float64Array , Int16Array , Int32Array , Int64Array ,
@@ -108,12 +109,92 @@ impl AggregateExpr for Max {
108109 fn create_groups_accumulator (
109110 & self ,
110111 ) -> arrow:: error:: Result < Option < Box < dyn GroupsAccumulator > > > {
111- let data_type = self . data_type . clone ( ) ;
112- Ok ( Some ( Box :: new (
113- GroupsAccumulatorFlatAdapter :: < MaxAccumulator > :: new ( move || {
114- MaxAccumulator :: try_new ( & data_type)
115- } ) ,
116- ) ) )
112+ use arrow:: datatypes:: ArrowPrimitiveType ;
113+
114+ macro_rules! make_max_accumulator {
115+ ( $T: ty) => {
116+ Box :: new(
117+ PrimitiveGroupsAccumulator :: <$T, $T, _, _>:: new(
118+ & <$T as ArrowPrimitiveType >:: DATA_TYPE ,
119+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
120+ y: <$T as ArrowPrimitiveType >:: Native | {
121+ * x = ( * x) . max( y) ;
122+ } ,
123+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
124+ y: <$T as ArrowPrimitiveType >:: Native | {
125+ * x = ( * x) . max( y) ;
126+ } ,
127+ )
128+ . with_starting_value( <$T as ArrowPrimitiveType >:: Native :: MIN ) ,
129+ )
130+ } ;
131+ }
132+ let acc: Box < dyn GroupsAccumulator > = match & self . data_type {
133+ DataType :: Float64 => make_max_accumulator ! ( arrow:: datatypes:: Float64Type ) ,
134+ DataType :: Float32 => make_max_accumulator ! ( arrow:: datatypes:: Float32Type ) ,
135+ DataType :: Int64 => make_max_accumulator ! ( arrow:: datatypes:: Int64Type ) ,
136+ DataType :: Int96 => make_max_accumulator ! ( arrow:: datatypes:: Int96Type ) ,
137+ DataType :: Int64Decimal ( 0 ) => {
138+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal0Type )
139+ }
140+ DataType :: Int64Decimal ( 1 ) => {
141+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal1Type )
142+ }
143+ DataType :: Int64Decimal ( 2 ) => {
144+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal2Type )
145+ }
146+ DataType :: Int64Decimal ( 3 ) => {
147+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal3Type )
148+ }
149+ DataType :: Int64Decimal ( 4 ) => {
150+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal4Type )
151+ }
152+ DataType :: Int64Decimal ( 5 ) => {
153+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal5Type )
154+ }
155+ DataType :: Int64Decimal ( 10 ) => {
156+ make_max_accumulator ! ( arrow:: datatypes:: Int64Decimal10Type )
157+ }
158+ DataType :: Int96Decimal ( 0 ) => {
159+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal0Type )
160+ }
161+ DataType :: Int96Decimal ( 1 ) => {
162+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal1Type )
163+ }
164+ DataType :: Int96Decimal ( 2 ) => {
165+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal2Type )
166+ }
167+ DataType :: Int96Decimal ( 3 ) => {
168+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal3Type )
169+ }
170+ DataType :: Int96Decimal ( 4 ) => {
171+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal4Type )
172+ }
173+ DataType :: Int96Decimal ( 5 ) => {
174+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal5Type )
175+ }
176+ DataType :: Int96Decimal ( 10 ) => {
177+ make_max_accumulator ! ( arrow:: datatypes:: Int96Decimal10Type )
178+ }
179+ DataType :: Int32 => make_max_accumulator ! ( arrow:: datatypes:: Int32Type ) ,
180+ DataType :: Int16 => make_max_accumulator ! ( arrow:: datatypes:: Int16Type ) ,
181+ DataType :: Int8 => make_max_accumulator ! ( arrow:: datatypes:: Int8Type ) ,
182+ DataType :: UInt64 => make_max_accumulator ! ( arrow:: datatypes:: UInt64Type ) ,
183+ DataType :: UInt32 => make_max_accumulator ! ( arrow:: datatypes:: UInt32Type ) ,
184+ DataType :: UInt16 => make_max_accumulator ! ( arrow:: datatypes:: UInt16Type ) ,
185+ DataType :: UInt8 => make_max_accumulator ! ( arrow:: datatypes:: UInt8Type ) ,
186+ _ => {
187+ // Not all types (strings) can use primitive accumulators. And strings use
188+ // max_string as the $OP in typed_min_match_batch.
189+
190+ // Timestamps presently take this branch.
191+ let data_type = self . data_type . clone ( ) ;
192+ Box :: new ( GroupsAccumulatorFlatAdapter :: < MaxAccumulator > :: new (
193+ move || MaxAccumulator :: try_new ( & data_type) ,
194+ ) )
195+ }
196+ } ;
197+ Ok ( Some ( acc) )
117198 }
118199
119200 fn name ( & self ) -> & str {
@@ -547,12 +628,91 @@ impl AggregateExpr for Min {
547628 fn create_groups_accumulator (
548629 & self ,
549630 ) -> arrow:: error:: Result < Option < Box < dyn GroupsAccumulator > > > {
550- let data_type = self . data_type . clone ( ) ;
551- Ok ( Some ( Box :: new (
552- GroupsAccumulatorFlatAdapter :: < MinAccumulator > :: new ( move || {
553- MinAccumulator :: try_new ( & data_type)
554- } ) ,
555- ) ) )
631+ macro_rules! make_min_accumulator {
632+ ( $T: ty) => {
633+ Box :: new(
634+ PrimitiveGroupsAccumulator :: <$T, $T, _, _>:: new(
635+ & <$T as ArrowPrimitiveType >:: DATA_TYPE ,
636+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
637+ y: <$T as ArrowPrimitiveType >:: Native | {
638+ * x = ( * x) . min( y) ;
639+ } ,
640+ |x: & mut <$T as ArrowPrimitiveType >:: Native ,
641+ y: <$T as ArrowPrimitiveType >:: Native | {
642+ * x = ( * x) . min( y) ;
643+ } ,
644+ )
645+ . with_starting_value( <$T as ArrowPrimitiveType >:: Native :: MAX ) ,
646+ )
647+ } ;
648+ }
649+
650+ let acc: Box < dyn GroupsAccumulator > = match & self . data_type {
651+ DataType :: Float64 => make_min_accumulator ! ( arrow:: datatypes:: Float64Type ) ,
652+ DataType :: Float32 => make_min_accumulator ! ( arrow:: datatypes:: Float32Type ) ,
653+ DataType :: Int64 => make_min_accumulator ! ( arrow:: datatypes:: Int64Type ) ,
654+ DataType :: Int96 => make_min_accumulator ! ( arrow:: datatypes:: Int96Type ) ,
655+ DataType :: Int64Decimal ( 0 ) => {
656+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal0Type )
657+ }
658+ DataType :: Int64Decimal ( 1 ) => {
659+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal1Type )
660+ }
661+ DataType :: Int64Decimal ( 2 ) => {
662+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal2Type )
663+ }
664+ DataType :: Int64Decimal ( 3 ) => {
665+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal3Type )
666+ }
667+ DataType :: Int64Decimal ( 4 ) => {
668+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal4Type )
669+ }
670+ DataType :: Int64Decimal ( 5 ) => {
671+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal5Type )
672+ }
673+ DataType :: Int64Decimal ( 10 ) => {
674+ make_min_accumulator ! ( arrow:: datatypes:: Int64Decimal10Type )
675+ }
676+ DataType :: Int96Decimal ( 0 ) => {
677+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal0Type )
678+ }
679+ DataType :: Int96Decimal ( 1 ) => {
680+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal1Type )
681+ }
682+ DataType :: Int96Decimal ( 2 ) => {
683+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal2Type )
684+ }
685+ DataType :: Int96Decimal ( 3 ) => {
686+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal3Type )
687+ }
688+ DataType :: Int96Decimal ( 4 ) => {
689+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal4Type )
690+ }
691+ DataType :: Int96Decimal ( 5 ) => {
692+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal5Type )
693+ }
694+ DataType :: Int96Decimal ( 10 ) => {
695+ make_min_accumulator ! ( arrow:: datatypes:: Int96Decimal10Type )
696+ }
697+ DataType :: Int32 => make_min_accumulator ! ( arrow:: datatypes:: Int32Type ) ,
698+ DataType :: Int16 => make_min_accumulator ! ( arrow:: datatypes:: Int16Type ) ,
699+ DataType :: Int8 => make_min_accumulator ! ( arrow:: datatypes:: Int8Type ) ,
700+ DataType :: UInt64 => make_min_accumulator ! ( arrow:: datatypes:: UInt64Type ) ,
701+ DataType :: UInt32 => make_min_accumulator ! ( arrow:: datatypes:: UInt32Type ) ,
702+ DataType :: UInt16 => make_min_accumulator ! ( arrow:: datatypes:: UInt16Type ) ,
703+ DataType :: UInt8 => make_min_accumulator ! ( arrow:: datatypes:: UInt8Type ) ,
704+ _ => {
705+ // Not all types (strings) can use primitive accumulators. And strings use
706+ // min_string as the $OP in typed_min_match_batch.
707+
708+ // Timestamps presently take this branch.
709+ let data_type = self . data_type . clone ( ) ;
710+ Box :: new ( GroupsAccumulatorFlatAdapter :: < MinAccumulator > :: new (
711+ move || MinAccumulator :: try_new ( & data_type) ,
712+ ) )
713+ }
714+ } ;
715+ Ok ( Some ( acc) )
556716 }
557717
558718 fn name ( & self ) -> & str {
0 commit comments