2020//! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator
2121
2222use arrow:: array:: { Array , BooleanArray , BooleanBufferBuilder , PrimitiveArray } ;
23- use arrow:: buffer:: { BooleanBuffer , NullBuffer } ;
23+ use arrow:: buffer:: NullBuffer ;
2424use arrow:: datatypes:: ArrowPrimitiveType ;
2525
2626use datafusion_expr_common:: groups_accumulator:: EmitTo ;
27+
28+ /// If the input has nulls, then the accumulator must potentially
29+ /// handle each input null value specially (e.g. for `SUM` to mark the
30+ /// corresponding sum as null)
31+ ///
32+ /// If there are filters present, `NullState` tracks if it has seen
33+ /// *any* value for that group (as some values may be filtered
34+ /// out). Without a filter, the accumulator is only passed groups that
35+ /// had at least one value to accumulate so they do not need to track
36+ /// if they have seen values for a particular group.
37+ #[ derive( Debug ) ]
38+ pub enum SeenValues {
39+ /// All groups seen so far have seen at least one non-null value
40+ All {
41+ num_values : usize ,
42+ } ,
43+ // Some groups have not yet seen a non-null value
44+ Some {
45+ values : BooleanBufferBuilder ,
46+ } ,
47+ }
48+
49+ impl Default for SeenValues {
50+ fn default ( ) -> Self {
51+ SeenValues :: All { num_values : 0 }
52+ }
53+ }
54+
55+ impl SeenValues {
56+ /// Return a mutable reference to the `BooleanBufferBuilder` in `SeenValues::Some`.
57+ ///
58+ /// If `self` is `SeenValues::All`, it is transitioned to `SeenValues::Some`
59+ /// by creating a new `BooleanBufferBuilder` where the first `num_values` are true.
60+ ///
61+ /// The builder is then ensured to have at least `total_num_groups` length,
62+ /// with any new entries initialized to false.
63+ fn get_builder ( & mut self , total_num_groups : usize ) -> & mut BooleanBufferBuilder {
64+ match self {
65+ SeenValues :: All { num_values } => {
66+ let mut builder = BooleanBufferBuilder :: new ( total_num_groups) ;
67+ builder. append_n ( * num_values, true ) ;
68+ if total_num_groups > * num_values {
69+ builder. append_n ( total_num_groups - * num_values, false ) ;
70+ }
71+ * self = SeenValues :: Some { values : builder } ;
72+ match self {
73+ SeenValues :: Some { values } => values,
74+ _ => unreachable ! ( ) ,
75+ }
76+ }
77+ SeenValues :: Some { values } => {
78+ if values. len ( ) < total_num_groups {
79+ values. append_n ( total_num_groups - values. len ( ) , false ) ;
80+ }
81+ values
82+ }
83+ }
84+ }
85+ }
86+
2787/// Track the accumulator null state per row: if any values for that
2888/// group were null and if any values have been seen at all for that group.
2989///
@@ -53,12 +113,14 @@ use datafusion_expr_common::groups_accumulator::EmitTo;
53113pub struct NullState {
54114 /// Have we seen any non-filtered input values for `group_index`?
55115 ///
56- /// If `seen_values[i]` is true, have seen at least one non null
116+ /// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is true, have seen at least one non null
57117 /// value for group `i`
58118 ///
59- /// If `seen_values[i]` is false, have not seen any values that
119+ /// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is false, have not seen any values that
60120 /// pass the filter yet for group `i`
61- seen_values : BooleanBufferBuilder ,
121+ ///
122+ /// If `seen_values` is `SeenValues::All`, all groups have seen at least one non null value
123+ seen_values : SeenValues ,
62124}
63125
64126impl Default for NullState {
@@ -70,14 +132,16 @@ impl Default for NullState {
70132impl NullState {
71133 pub fn new ( ) -> Self {
72134 Self {
73- seen_values : BooleanBufferBuilder :: new ( 0 ) ,
135+ seen_values : SeenValues :: All { num_values : 0 } ,
74136 }
75137 }
76138
77139 /// return the size of all buffers allocated by this null state, not including self
78140 pub fn size ( & self ) -> usize {
79- // capacity is in bits, so convert to bytes
80- self . seen_values . capacity ( ) / 8
141+ match & self . seen_values {
142+ SeenValues :: All { .. } => 0 ,
143+ SeenValues :: Some { values } => values. capacity ( ) / 8 ,
144+ }
81145 }
82146
83147 /// Invokes `value_fn(group_index, value)` for each non null, non
@@ -107,10 +171,17 @@ impl NullState {
107171 T : ArrowPrimitiveType + Send ,
108172 F : FnMut ( usize , T :: Native ) + Send ,
109173 {
110- // ensure the seen_values is big enough (start everything at
111- // "not seen" valid)
112- let seen_values =
113- initialize_builder ( & mut self . seen_values , total_num_groups, false ) ;
174+ // skip null handling if no nulls in input or accumulator
175+ if let SeenValues :: All { num_values } = & mut self . seen_values
176+ && opt_filter. is_none ( )
177+ && values. null_count ( ) == 0
178+ {
179+ accumulate ( group_indices, values, None , value_fn) ;
180+ * num_values = total_num_groups;
181+ return ;
182+ }
183+
184+ let seen_values = self . seen_values . get_builder ( total_num_groups) ;
114185 accumulate ( group_indices, values, opt_filter, |group_index, value| {
115186 seen_values. set_bit ( group_index, true ) ;
116187 value_fn ( group_index, value) ;
@@ -140,10 +211,21 @@ impl NullState {
140211 let data = values. values ( ) ;
141212 assert_eq ! ( data. len( ) , group_indices. len( ) ) ;
142213
143- // ensure the seen_values is big enough (start everything at
144- // "not seen" valid)
145- let seen_values =
146- initialize_builder ( & mut self . seen_values , total_num_groups, false ) ;
214+ // skip null handling if no nulls in input or accumulator
215+ if let SeenValues :: All { num_values } = & mut self . seen_values
216+ && opt_filter. is_none ( )
217+ && values. null_count ( ) == 0
218+ {
219+ group_indices
220+ . iter ( )
221+ . zip ( data. iter ( ) )
222+ . for_each ( |( & group_index, new_value) | value_fn ( group_index, new_value) ) ;
223+ * num_values = total_num_groups;
224+
225+ return ;
226+ }
227+
228+ let seen_values = self . seen_values . get_builder ( total_num_groups) ;
147229
148230 // These could be made more performant by iterating in chunks of 64 bits at a time
149231 match ( values. null_count ( ) > 0 , opt_filter) {
@@ -211,21 +293,39 @@ impl NullState {
211293 /// for the `emit_to` rows.
212294 ///
213295 /// resets the internal state appropriately
214- pub fn build ( & mut self , emit_to : EmitTo ) -> NullBuffer {
215- let nulls: BooleanBuffer = self . seen_values . finish ( ) ;
216-
217- let nulls = match emit_to {
218- EmitTo :: All => nulls,
219- EmitTo :: First ( n) => {
220- // split off the first N values in seen_values
221- let first_n_null: BooleanBuffer = nulls. slice ( 0 , n) ;
222- // reset the existing seen buffer
223- self . seen_values
224- . append_buffer ( & nulls. slice ( n, nulls. len ( ) - n) ) ;
225- first_n_null
296+ pub fn build ( & mut self , emit_to : EmitTo ) -> Option < NullBuffer > {
297+ match emit_to {
298+ EmitTo :: All => {
299+ let old_seen = std:: mem:: take ( & mut self . seen_values ) ;
300+ match old_seen {
301+ SeenValues :: All { .. } => None ,
302+ SeenValues :: Some { mut values } => {
303+ Some ( NullBuffer :: new ( values. finish ( ) ) )
304+ }
305+ }
226306 }
227- } ;
228- NullBuffer :: new ( nulls)
307+ EmitTo :: First ( n) => match & mut self . seen_values {
308+ SeenValues :: All { num_values } => {
309+ * num_values = num_values. saturating_sub ( n) ;
310+ None
311+ }
312+ SeenValues :: Some { .. } => {
313+ let mut old_values = match std:: mem:: take ( & mut self . seen_values ) {
314+ SeenValues :: Some { values } => values,
315+ _ => unreachable ! ( ) ,
316+ } ;
317+ let nulls = old_values. finish ( ) ;
318+ let first_n_null = nulls. slice ( 0 , n) ;
319+ let remainder = nulls. slice ( n, nulls. len ( ) - n) ;
320+ let mut new_builder = BooleanBufferBuilder :: new ( remainder. len ( ) ) ;
321+ new_builder. append_buffer ( & remainder) ;
322+ self . seen_values = SeenValues :: Some {
323+ values : new_builder,
324+ } ;
325+ Some ( NullBuffer :: new ( first_n_null) )
326+ }
327+ } ,
328+ }
229329 }
230330}
231331
@@ -573,27 +673,14 @@ pub fn accumulate_indices<F>(
573673 }
574674}
575675
576- /// Ensures that `builder` contains a `BooleanBufferBuilder with at
577- /// least `total_num_groups`.
578- ///
579- /// All new entries are initialized to `default_value`
580- fn initialize_builder (
581- builder : & mut BooleanBufferBuilder ,
582- total_num_groups : usize ,
583- default_value : bool ,
584- ) -> & mut BooleanBufferBuilder {
585- if builder. len ( ) < total_num_groups {
586- let new_groups = total_num_groups - builder. len ( ) ;
587- builder. append_n ( new_groups, default_value) ;
588- }
589- builder
590- }
591-
592676#[ cfg( test) ]
593677mod test {
594678 use super :: * ;
595679
596- use arrow:: array:: { Int32Array , UInt32Array } ;
680+ use arrow:: {
681+ array:: { Int32Array , UInt32Array } ,
682+ buffer:: BooleanBuffer ,
683+ } ;
597684 use rand:: { Rng , rngs:: ThreadRng } ;
598685 use std:: collections:: HashSet ;
599686
@@ -834,15 +921,24 @@ mod test {
834921 accumulated_values, expected_values,
835922 "\n \n accumulated_values:{accumulated_values:#?}\n \n expected_values:{expected_values:#?}"
836923 ) ;
837- let seen_values = null_state. seen_values . finish_cloned ( ) ;
838- mock. validate_seen_values ( & seen_values) ;
924+
925+ match & null_state. seen_values {
926+ SeenValues :: All { num_values } => {
927+ assert_eq ! ( * num_values, total_num_groups) ;
928+ }
929+ SeenValues :: Some { values } => {
930+ let seen_values = values. finish_cloned ( ) ;
931+ mock. validate_seen_values ( & seen_values) ;
932+ }
933+ }
839934
840935 // Validate the final buffer (one value per group)
841936 let expected_null_buffer = mock. expected_null_buffer ( total_num_groups) ;
842937
843938 let null_buffer = null_state. build ( EmitTo :: All ) ;
844-
845- assert_eq ! ( null_buffer, expected_null_buffer) ;
939+ if let Some ( nulls) = & null_buffer {
940+ assert_eq ! ( * nulls, expected_null_buffer) ;
941+ }
846942 }
847943
848944 // Calls `accumulate_indices`
@@ -955,15 +1051,25 @@ mod test {
9551051 "\n \n accumulated_values:{accumulated_values:#?}\n \n expected_values:{expected_values:#?}"
9561052 ) ;
9571053
958- let seen_values = null_state. seen_values . finish_cloned ( ) ;
959- mock. validate_seen_values ( & seen_values) ;
1054+ match & null_state. seen_values {
1055+ SeenValues :: All { num_values } => {
1056+ assert_eq ! ( * num_values, total_num_groups) ;
1057+ }
1058+ SeenValues :: Some { values } => {
1059+ let seen_values = values. finish_cloned ( ) ;
1060+ mock. validate_seen_values ( & seen_values) ;
1061+ }
1062+ }
9601063
9611064 // Validate the final buffer (one value per group)
962- let expected_null_buffer = mock. expected_null_buffer ( total_num_groups) ;
1065+ let expected_null_buffer = Some ( mock. expected_null_buffer ( total_num_groups) ) ;
9631066
1067+ let is_all_seen = matches ! ( null_state. seen_values, SeenValues :: All { .. } ) ;
9641068 let null_buffer = null_state. build ( EmitTo :: All ) ;
9651069
966- assert_eq ! ( null_buffer, expected_null_buffer) ;
1070+ if !is_all_seen {
1071+ assert_eq ! ( null_buffer, expected_null_buffer) ;
1072+ }
9671073 }
9681074 }
9691075
0 commit comments