Skip to content

Commit 52dd753

Browse files
committed
Simplify
1 parent 05414e8 commit 52dd753

File tree

1 file changed

+14
-15
lines changed
  • datafusion/functions-aggregate-common/src/aggregate/groups_accumulator

1 file changed

+14
-15
lines changed

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,7 @@ pub struct NullState {
5959
/// If `seen_values[i]` is false, have not seen any values that
6060
/// pass the filter yet for group `i`
6161
seen_values: BooleanBufferBuilder,
62-
/// If true, all groups seen so far have seen at least one non-null value
63-
/// and no filters have been applied.
64-
all_seen: bool,
62+
/// Tracks number of total groups
6563
seen_values_size: usize,
6664
}
6765

@@ -75,7 +73,6 @@ impl NullState {
7573
pub fn new() -> Self {
7674
Self {
7775
seen_values: BooleanBufferBuilder::new(0),
78-
all_seen: true,
7976
seen_values_size: 0,
8077
}
8178
}
@@ -113,13 +110,12 @@ impl NullState {
113110
T: ArrowPrimitiveType + Send,
114111
F: FnMut(usize, T::Native) + Send,
115112
{
116-
if self.all_seen && opt_filter.is_none() && values.null_count() == 0 {
113+
if self.seen_values.capacity() == 0 && opt_filter.is_none() && values.null_count() == 0 {
117114
accumulate(group_indices, values, None, value_fn);
118115
self.seen_values_size = total_num_groups;
119116
} else {
120-
let prev_seen: bool = self.all_seen && self.seen_values_size > 0;
117+
let prev_seen: bool = self.all_seen() && self.seen_values_size > 0;
121118

122-
self.all_seen = false;
123119
let seen_values =
124120
initialize_builder(&mut self.seen_values, total_num_groups, false);
125121

@@ -159,7 +155,7 @@ impl NullState {
159155
let data = values.values();
160156
assert_eq!(data.len(), group_indices.len());
161157

162-
if self.all_seen && opt_filter.is_none() && values.null_count() == 0 {
158+
if self.seen_values.capacity() == 0 && opt_filter.is_none() && values.null_count() == 0 {
163159
group_indices
164160
.iter()
165161
.zip(data.iter())
@@ -168,9 +164,7 @@ impl NullState {
168164

169165
return;
170166
}
171-
let prev_seen: bool = self.all_seen && self.seen_values_size > 0;
172-
173-
self.all_seen = false;
167+
let prev_seen: bool = self.seen_values.capacity() == 0 && self.seen_values_size > 0;
174168
// ensure the seen_values is big enough (start everything at
175169
// "not seen" valid)
176170
let seen_values =
@@ -253,15 +247,15 @@ impl NullState {
253247
let nulls = match emit_to {
254248
EmitTo::All => {
255249
self.seen_values_size = 0;
256-
if self.all_seen {
250+
if self.all_seen() {
257251
// all groups have seen at least one non null value
258252
return None;
259253
} else {
260254
nulls
261255
}
262256
}
263257
EmitTo::First(n) => {
264-
if self.all_seen {
258+
if self.all_seen() {
265259
self.seen_values_size -= n;
266260
return None;
267261
}
@@ -275,6 +269,11 @@ impl NullState {
275269
};
276270
Some(NullBuffer::new(nulls))
277271
}
272+
273+
/// Returns true if all groups have seen at least one non null
274+
fn all_seen(&self) -> bool {
275+
self.seen_values.capacity() == 0
276+
}
278277
}
279278

280279
/// Invokes `value_fn(group_index, value)` for each non null, non
@@ -883,7 +882,7 @@ mod test {
883882
"\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
884883
);
885884
let seen_values = null_state.seen_values.finish_cloned();
886-
if null_state.all_seen {
885+
if null_state.all_seen() {
887886
assert_eq!(null_state.seen_values_size, total_num_groups);
888887
} else {
889888
mock.validate_seen_values(&seen_values);
@@ -1016,7 +1015,7 @@ mod test {
10161015

10171016
let null_buffer = null_state.build(EmitTo::All);
10181017

1019-
if !null_state.all_seen {
1018+
if !null_state.all_seen() {
10201019
assert_eq!(null_buffer, expected_null_buffer);
10211020
}
10221021
}

0 commit comments

Comments
 (0)