1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use arrow:: array:: ArrayAccessor ;
19- use arrow:: array:: ArrayIter ;
20- use arrow:: array:: ArrayRef ;
21- use arrow:: array:: AsArray ;
22- use arrow:: datatypes:: DataType ;
23- use datafusion:: arrow;
24- use datafusion:: common:: cast:: as_primitive_array;
25- use datafusion:: common:: cast:: as_string_array;
26- use datafusion:: error:: Result ;
27- use datafusion:: logical_expr:: Accumulator ;
28- use datafusion:: scalar:: ScalarValue ;
29- use std:: collections:: HashMap ;
18+ use datafusion:: arrow:: array:: AsArray ;
19+ use datafusion:: { arrow, common, error, logical_expr, scalar} ;
20+ use std:: collections;
3021
3122#[ derive( Debug ) ]
3223pub struct BytesModeAccumulator {
33- value_counts : HashMap < String , i64 > ,
34- data_type : DataType ,
24+ value_counts : collections :: HashMap < String , i64 > ,
25+ data_type : arrow :: datatypes :: DataType ,
3526}
3627
3728impl BytesModeAccumulator {
38- pub fn new ( data_type : & DataType ) -> Self {
29+ pub fn new ( data_type : & arrow :: datatypes :: DataType ) -> Self {
3930 Self {
40- value_counts : HashMap :: new ( ) ,
31+ value_counts : collections :: HashMap :: new ( ) ,
4132 data_type : data_type. clone ( ) ,
4233 }
4334 }
4435
4536 fn update_counts < ' a , V > ( & mut self , array : V )
4637 where
47- V : ArrayAccessor < Item = & ' a str > ,
38+ V : arrow :: array :: ArrayAccessor < Item = & ' a str > ,
4839 {
49- for value in ArrayIter :: new ( array) . flatten ( ) {
40+ for value in arrow :: array :: ArrayIter :: new ( array) . flatten ( ) {
5041 let key = value;
5142 if let Some ( count) = self . value_counts . get_mut ( key) {
5243 * count += 1 ;
@@ -57,14 +48,14 @@ impl BytesModeAccumulator {
5748 }
5849}
5950
60- impl Accumulator for BytesModeAccumulator {
61- fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
51+ impl logical_expr :: Accumulator for BytesModeAccumulator {
52+ fn update_batch ( & mut self , values : & [ arrow :: array :: ArrayRef ] ) -> error :: Result < ( ) > {
6253 if values. is_empty ( ) {
6354 return Ok ( ( ) ) ;
6455 }
6556
6657 match & self . data_type {
67- DataType :: Utf8View => {
58+ arrow :: datatypes :: DataType :: Utf8View => {
6859 let array = values[ 0 ] . as_string_view ( ) ;
6960 self . update_counts ( array) ;
7061 }
@@ -77,35 +68,36 @@ impl Accumulator for BytesModeAccumulator {
7768 Ok ( ( ) )
7869 }
7970
80- fn state ( & mut self ) -> Result < Vec < ScalarValue > > {
81- let values: Vec < ScalarValue > = self
71+ fn state ( & mut self ) -> error :: Result < Vec < scalar :: ScalarValue > > {
72+ let values: Vec < scalar :: ScalarValue > = self
8273 . value_counts
8374 . keys ( )
84- . map ( |key| ScalarValue :: Utf8 ( Some ( key. to_string ( ) ) ) )
75+ . map ( |key| scalar :: ScalarValue :: Utf8 ( Some ( key. to_string ( ) ) ) )
8576 . collect ( ) ;
8677
87- let frequencies: Vec < ScalarValue > = self
78+ let frequencies: Vec < scalar :: ScalarValue > = self
8879 . value_counts
8980 . values ( )
90- . map ( |& count| ScalarValue :: Int64 ( Some ( count) ) )
81+ . map ( |& count| scalar :: ScalarValue :: Int64 ( Some ( count) ) )
9182 . collect ( ) ;
9283
93- let values_scalar = ScalarValue :: new_list_nullable ( & values, & DataType :: Utf8 ) ;
94- let frequencies_scalar = ScalarValue :: new_list_nullable ( & frequencies, & DataType :: Int64 ) ;
84+ let values_scalar = scalar:: ScalarValue :: new_list_nullable ( & values, & arrow:: datatypes:: DataType :: Utf8 ) ;
85+ let frequencies_scalar =
86+ scalar:: ScalarValue :: new_list_nullable ( & frequencies, & arrow:: datatypes:: DataType :: Int64 ) ;
9587
9688 Ok ( vec ! [
97- ScalarValue :: List ( values_scalar) ,
98- ScalarValue :: List ( frequencies_scalar) ,
89+ scalar :: ScalarValue :: List ( values_scalar) ,
90+ scalar :: ScalarValue :: List ( frequencies_scalar) ,
9991 ] )
10092 }
10193
102- fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
94+ fn merge_batch ( & mut self , states : & [ arrow :: array :: ArrayRef ] ) -> error :: Result < ( ) > {
10395 if states. is_empty ( ) {
10496 return Ok ( ( ) ) ;
10597 }
10698
107- let values_array = as_string_array ( & states[ 0 ] ) ? ;
108- let counts_array = as_primitive_array :: < arrow:: datatypes:: Int64Type > ( & states[ 1 ] ) ?;
99+ let values_array = arrow :: array :: as_string_array ( & states[ 0 ] ) ;
100+ let counts_array = common :: cast :: as_primitive_array :: < arrow:: datatypes:: Int64Type > ( & states[ 1 ] ) ?;
109101
110102 for ( i, value_option) in values_array. iter ( ) . enumerate ( ) {
111103 if let Some ( value) = value_option {
@@ -118,11 +110,11 @@ impl Accumulator for BytesModeAccumulator {
118110 Ok ( ( ) )
119111 }
120112
121- fn evaluate ( & mut self ) -> Result < ScalarValue > {
113+ fn evaluate ( & mut self ) -> error :: Result < scalar :: ScalarValue > {
122114 if self . value_counts . is_empty ( ) {
123115 return match & self . data_type {
124- DataType :: Utf8View => Ok ( ScalarValue :: Utf8View ( None ) ) ,
125- _ => Ok ( ScalarValue :: Utf8 ( None ) ) ,
116+ arrow :: datatypes :: DataType :: Utf8View => Ok ( scalar :: ScalarValue :: Utf8View ( None ) ) ,
117+ _ => Ok ( scalar :: ScalarValue :: Utf8 ( None ) ) ,
126118 } ;
127119 }
128120
@@ -139,12 +131,12 @@ impl Accumulator for BytesModeAccumulator {
139131
140132 match mode {
141133 Some ( result) => match & self . data_type {
142- DataType :: Utf8View => Ok ( ScalarValue :: Utf8View ( Some ( result) ) ) ,
143- _ => Ok ( ScalarValue :: Utf8 ( Some ( result) ) ) ,
134+ arrow :: datatypes :: DataType :: Utf8View => Ok ( scalar :: ScalarValue :: Utf8View ( Some ( result) ) ) ,
135+ _ => Ok ( scalar :: ScalarValue :: Utf8 ( Some ( result) ) ) ,
144136 } ,
145137 None => match & self . data_type {
146- DataType :: Utf8View => Ok ( ScalarValue :: Utf8View ( None ) ) ,
147- _ => Ok ( ScalarValue :: Utf8 ( None ) ) ,
138+ arrow :: datatypes :: DataType :: Utf8View => Ok ( scalar :: ScalarValue :: Utf8View ( None ) ) ,
139+ _ => Ok ( scalar :: ScalarValue :: Utf8 ( None ) ) ,
148140 } ,
149141 }
150142 }
@@ -156,14 +148,16 @@ impl Accumulator for BytesModeAccumulator {
156148
157149#[ cfg( test) ]
158150mod tests {
151+
159152 use super :: * ;
160- use arrow:: array:: { ArrayRef , GenericByteViewArray , StringArray } ;
161- use std:: sync:: Arc ;
153+
154+ use datafusion:: logical_expr:: Accumulator ;
155+ use std:: sync;
162156
163157 #[ test]
164- fn test_mode_accumulator_single_mode_utf8 ( ) -> Result < ( ) > {
165- let mut acc = BytesModeAccumulator :: new ( & DataType :: Utf8 ) ;
166- let values: ArrayRef = Arc :: new ( StringArray :: from ( vec ! [
158+ fn test_mode_accumulator_single_mode_utf8 ( ) -> error :: Result < ( ) > {
159+ let mut acc = BytesModeAccumulator :: new ( & arrow :: datatypes :: DataType :: Utf8 ) ;
160+ let values: arrow :: array :: ArrayRef = sync :: Arc :: new ( arrow :: array :: StringArray :: from ( vec ! [
167161 Some ( "apple" ) ,
168162 Some ( "banana" ) ,
169163 Some ( "apple" ) ,
@@ -175,14 +169,14 @@ mod tests {
175169 acc. update_batch ( & [ values] ) ?;
176170 let result = acc. evaluate ( ) ?;
177171
178- assert_eq ! ( result, ScalarValue :: Utf8 ( Some ( "apple" . to_string( ) ) ) ) ;
172+ assert_eq ! ( result, scalar :: ScalarValue :: Utf8 ( Some ( "apple" . to_string( ) ) ) ) ;
179173 Ok ( ( ) )
180174 }
181175
182176 #[ test]
183- fn test_mode_accumulator_tie_utf8 ( ) -> Result < ( ) > {
184- let mut acc = BytesModeAccumulator :: new ( & DataType :: Utf8 ) ;
185- let values: ArrayRef = Arc :: new ( StringArray :: from ( vec ! [
177+ fn test_mode_accumulator_tie_utf8 ( ) -> error :: Result < ( ) > {
178+ let mut acc = BytesModeAccumulator :: new ( & arrow :: datatypes :: DataType :: Utf8 ) ;
179+ let values: arrow :: array :: ArrayRef = sync :: Arc :: new ( arrow :: array :: StringArray :: from ( vec ! [
186180 Some ( "apple" ) ,
187181 Some ( "banana" ) ,
188182 Some ( "apple" ) ,
@@ -193,26 +187,27 @@ mod tests {
193187 acc. update_batch ( & [ values] ) ?;
194188 let result = acc. evaluate ( ) ?;
195189
196- assert_eq ! ( result, ScalarValue :: Utf8 ( Some ( "apple" . to_string( ) ) ) ) ;
190+ assert_eq ! ( result, scalar :: ScalarValue :: Utf8 ( Some ( "apple" . to_string( ) ) ) ) ;
197191 Ok ( ( ) )
198192 }
199193
200194 #[ test]
201- fn test_mode_accumulator_all_nulls_utf8 ( ) -> Result < ( ) > {
202- let mut acc = BytesModeAccumulator :: new ( & DataType :: Utf8 ) ;
203- let values: ArrayRef = Arc :: new ( StringArray :: from ( vec ! [ None as Option <& str >, None , None ] ) ) ;
195+ fn test_mode_accumulator_all_nulls_utf8 ( ) -> error:: Result < ( ) > {
196+ let mut acc = BytesModeAccumulator :: new ( & arrow:: datatypes:: DataType :: Utf8 ) ;
197+ let values: arrow:: array:: ArrayRef =
198+ sync:: Arc :: new ( arrow:: array:: StringArray :: from ( vec ! [ None as Option <& str >, None , None ] ) ) ;
204199
205200 acc. update_batch ( & [ values] ) ?;
206201 let result = acc. evaluate ( ) ?;
207202
208- assert_eq ! ( result, ScalarValue :: Utf8 ( None ) ) ;
203+ assert_eq ! ( result, scalar :: ScalarValue :: Utf8 ( None ) ) ;
209204 Ok ( ( ) )
210205 }
211206
212207 #[ test]
213- fn test_mode_accumulator_with_nulls_utf8 ( ) -> Result < ( ) > {
214- let mut acc = BytesModeAccumulator :: new ( & DataType :: Utf8 ) ;
215- let values: ArrayRef = Arc :: new ( StringArray :: from ( vec ! [
208+ fn test_mode_accumulator_with_nulls_utf8 ( ) -> error :: Result < ( ) > {
209+ let mut acc = BytesModeAccumulator :: new ( & arrow :: datatypes :: DataType :: Utf8 ) ;
210+ let values: arrow :: array :: ArrayRef = sync :: Arc :: new ( arrow :: array :: StringArray :: from ( vec ! [
216211 Some ( "apple" ) ,
217212 None ,
218213 Some ( "banana" ) ,
@@ -226,14 +221,14 @@ mod tests {
226221 acc. update_batch ( & [ values] ) ?;
227222 let result = acc. evaluate ( ) ?;
228223
229- assert_eq ! ( result, ScalarValue :: Utf8 ( Some ( "apple" . to_string( ) ) ) ) ;
224+ assert_eq ! ( result, scalar :: ScalarValue :: Utf8 ( Some ( "apple" . to_string( ) ) ) ) ;
230225 Ok ( ( ) )
231226 }
232227
233228 #[ test]
234- fn test_mode_accumulator_single_mode_utf8view ( ) -> Result < ( ) > {
235- let mut acc = BytesModeAccumulator :: new ( & DataType :: Utf8View ) ;
236- let values: ArrayRef = Arc :: new ( GenericByteViewArray :: from ( vec ! [
229+ fn test_mode_accumulator_single_mode_utf8view ( ) -> error :: Result < ( ) > {
230+ let mut acc = BytesModeAccumulator :: new ( & arrow :: datatypes :: DataType :: Utf8View ) ;
231+ let values: arrow :: array :: ArrayRef = sync :: Arc :: new ( arrow :: array :: GenericByteViewArray :: from ( vec ! [
237232 Some ( "apple" ) ,
238233 Some ( "banana" ) ,
239234 Some ( "apple" ) ,
@@ -245,14 +240,14 @@ mod tests {
245240 acc. update_batch ( & [ values] ) ?;
246241 let result = acc. evaluate ( ) ?;
247242
248- assert_eq ! ( result, ScalarValue :: Utf8View ( Some ( "apple" . to_string( ) ) ) ) ;
243+ assert_eq ! ( result, scalar :: ScalarValue :: Utf8View ( Some ( "apple" . to_string( ) ) ) ) ;
249244 Ok ( ( ) )
250245 }
251246
252247 #[ test]
253- fn test_mode_accumulator_tie_utf8view ( ) -> Result < ( ) > {
254- let mut acc = BytesModeAccumulator :: new ( & DataType :: Utf8View ) ;
255- let values: ArrayRef = Arc :: new ( GenericByteViewArray :: from ( vec ! [
248+ fn test_mode_accumulator_tie_utf8view ( ) -> error :: Result < ( ) > {
249+ let mut acc = BytesModeAccumulator :: new ( & arrow :: datatypes :: DataType :: Utf8View ) ;
250+ let values: arrow :: array :: ArrayRef = sync :: Arc :: new ( arrow :: array :: GenericByteViewArray :: from ( vec ! [
256251 Some ( "apple" ) ,
257252 Some ( "banana" ) ,
258253 Some ( "apple" ) ,
@@ -263,26 +258,30 @@ mod tests {
263258 acc. update_batch ( & [ values] ) ?;
264259 let result = acc. evaluate ( ) ?;
265260
266- assert_eq ! ( result, ScalarValue :: Utf8View ( Some ( "apple" . to_string( ) ) ) ) ;
261+ assert_eq ! ( result, scalar :: ScalarValue :: Utf8View ( Some ( "apple" . to_string( ) ) ) ) ;
267262 Ok ( ( ) )
268263 }
269264
270265 #[ test]
271- fn test_mode_accumulator_all_nulls_utf8view ( ) -> Result < ( ) > {
272- let mut acc = BytesModeAccumulator :: new ( & DataType :: Utf8View ) ;
273- let values: ArrayRef = Arc :: new ( GenericByteViewArray :: from ( vec ! [ None as Option <& str >, None , None ] ) ) ;
266+ fn test_mode_accumulator_all_nulls_utf8view ( ) -> error:: Result < ( ) > {
267+ let mut acc = BytesModeAccumulator :: new ( & arrow:: datatypes:: DataType :: Utf8View ) ;
268+ let values: arrow:: array:: ArrayRef = sync:: Arc :: new ( arrow:: array:: GenericByteViewArray :: from ( vec ! [
269+ None as Option <& str >,
270+ None ,
271+ None ,
272+ ] ) ) ;
274273
275274 acc. update_batch ( & [ values] ) ?;
276275 let result = acc. evaluate ( ) ?;
277276
278- assert_eq ! ( result, ScalarValue :: Utf8View ( None ) ) ;
277+ assert_eq ! ( result, scalar :: ScalarValue :: Utf8View ( None ) ) ;
279278 Ok ( ( ) )
280279 }
281280
282281 #[ test]
283- fn test_mode_accumulator_with_nulls_utf8view ( ) -> Result < ( ) > {
284- let mut acc = BytesModeAccumulator :: new ( & DataType :: Utf8View ) ;
285- let values: ArrayRef = Arc :: new ( GenericByteViewArray :: from ( vec ! [
282+ fn test_mode_accumulator_with_nulls_utf8view ( ) -> error :: Result < ( ) > {
283+ let mut acc = BytesModeAccumulator :: new ( & arrow :: datatypes :: DataType :: Utf8View ) ;
284+ let values: arrow :: array :: ArrayRef = sync :: Arc :: new ( arrow :: array :: GenericByteViewArray :: from ( vec ! [
286285 Some ( "apple" ) ,
287286 None ,
288287 Some ( "banana" ) ,
@@ -296,7 +295,7 @@ mod tests {
296295 acc. update_batch ( & [ values] ) ?;
297296 let result = acc. evaluate ( ) ?;
298297
299- assert_eq ! ( result, ScalarValue :: Utf8View ( Some ( "apple" . to_string( ) ) ) ) ;
298+ assert_eq ! ( result, scalar :: ScalarValue :: Utf8View ( Some ( "apple" . to_string( ) ) ) ) ;
300299 Ok ( ( ) )
301300 }
302301}
0 commit comments