Skip to content

Commit b7091c0

Browse files
Dandandanalamb
andauthored
Optimize Nullstate / accumulators (#19625)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #19636 ## Rationale for this change Speedup accumulator code (sum, avg, count) by specializing on non-null cases. ## What changes are included in this PR? * Specialize `Nullstate` to non-null values. * Use unchecked indexing ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent c98fa56 commit b7091c0

File tree

6 files changed

+188
-69
lines changed

6 files changed

+188
-69
lines changed

datafusion-examples/examples/udf/advanced_udaf.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,12 +314,16 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
314314
let prods = emit_to.take_needed(&mut self.prods);
315315
let nulls = self.null_state.build(emit_to);
316316

317-
assert_eq!(nulls.len(), prods.len());
317+
if let Some(nulls) = &nulls {
318+
assert_eq!(nulls.len(), counts.len());
319+
}
318320
assert_eq!(counts.len(), prods.len());
319321

320322
// don't evaluate geometric mean with null inputs to avoid errors on null values
321323

322-
let array: PrimitiveArray<Float64Type> = if nulls.null_count() > 0 {
324+
let array: PrimitiveArray<Float64Type> = if let Some(nulls) = &nulls
325+
&& nulls.null_count() > 0
326+
{
323327
let mut builder = PrimitiveBuilder::<Float64Type>::with_capacity(nulls.len());
324328
let iter = prods.into_iter().zip(counts).zip(nulls.iter());
325329

@@ -337,7 +341,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
337341
.zip(counts)
338342
.map(|(prod, count)| prod.powf(1.0 / count as f64))
339343
.collect::<Vec<_>>();
340-
PrimitiveArray::new(geo_mean.into(), Some(nulls)) // no copy
344+
PrimitiveArray::new(geo_mean.into(), nulls) // no copy
341345
.with_data_type(self.return_data_type.clone())
342346
};
343347

@@ -347,7 +351,6 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
347351
// return arrays for counts and prods
348352
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
349353
let nulls = self.null_state.build(emit_to);
350-
let nulls = Some(nulls);
351354

352355
let counts = emit_to.take_needed(&mut self.counts);
353356
let counts = UInt32Array::new(counts.into(), nulls.clone()); // zero copy

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs

Lines changed: 160 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,70 @@
2020
//! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator
2121
2222
use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
23-
use arrow::buffer::{BooleanBuffer, NullBuffer};
23+
use arrow::buffer::NullBuffer;
2424
use arrow::datatypes::ArrowPrimitiveType;
2525

2626
use datafusion_expr_common::groups_accumulator::EmitTo;
27+
28+
/// If the input has nulls, then the accumulator must potentially
29+
/// handle each input null value specially (e.g. for `SUM` to mark the
30+
/// corresponding sum as null)
31+
///
32+
/// If there are filters present, `NullState` tracks if it has seen
33+
/// *any* value for that group (as some values may be filtered
34+
/// out). Without a filter, the accumulator is only passed groups that
35+
/// had at least one value to accumulate so they do not need to track
36+
/// if they have seen values for a particular group.
37+
#[derive(Debug)]
38+
pub enum SeenValues {
39+
/// All groups seen so far have seen at least one non-null value
40+
All {
41+
num_values: usize,
42+
},
43+
// Some groups have not yet seen a non-null value
44+
Some {
45+
values: BooleanBufferBuilder,
46+
},
47+
}
48+
49+
impl Default for SeenValues {
50+
fn default() -> Self {
51+
SeenValues::All { num_values: 0 }
52+
}
53+
}
54+
55+
impl SeenValues {
56+
/// Return a mutable reference to the `BooleanBufferBuilder` in `SeenValues::Some`.
57+
///
58+
/// If `self` is `SeenValues::All`, it is transitioned to `SeenValues::Some`
59+
/// by creating a new `BooleanBufferBuilder` where the first `num_values` are true.
60+
///
61+
/// The builder is then ensured to have at least `total_num_groups` length,
62+
/// with any new entries initialized to false.
63+
fn get_builder(&mut self, total_num_groups: usize) -> &mut BooleanBufferBuilder {
64+
match self {
65+
SeenValues::All { num_values } => {
66+
let mut builder = BooleanBufferBuilder::new(total_num_groups);
67+
builder.append_n(*num_values, true);
68+
if total_num_groups > *num_values {
69+
builder.append_n(total_num_groups - *num_values, false);
70+
}
71+
*self = SeenValues::Some { values: builder };
72+
match self {
73+
SeenValues::Some { values } => values,
74+
_ => unreachable!(),
75+
}
76+
}
77+
SeenValues::Some { values } => {
78+
if values.len() < total_num_groups {
79+
values.append_n(total_num_groups - values.len(), false);
80+
}
81+
values
82+
}
83+
}
84+
}
85+
}
86+
2787
/// Track the accumulator null state per row: if any values for that
2888
/// group were null and if any values have been seen at all for that group.
2989
///
@@ -53,12 +113,14 @@ use datafusion_expr_common::groups_accumulator::EmitTo;
53113
pub struct NullState {
54114
/// Have we seen any non-filtered input values for `group_index`?
55115
///
56-
/// If `seen_values[i]` is true, have seen at least one non null
116+
/// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is true, have seen at least one non null
57117
/// value for group `i`
58118
///
59-
/// If `seen_values[i]` is false, have not seen any values that
119+
/// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is false, have not seen any values that
60120
/// pass the filter yet for group `i`
61-
seen_values: BooleanBufferBuilder,
121+
///
122+
/// If `seen_values` is `SeenValues::All`, all groups have seen at least one non null value
123+
seen_values: SeenValues,
62124
}
63125

64126
impl Default for NullState {
@@ -70,14 +132,16 @@ impl Default for NullState {
70132
impl NullState {
71133
pub fn new() -> Self {
72134
Self {
73-
seen_values: BooleanBufferBuilder::new(0),
135+
seen_values: SeenValues::All { num_values: 0 },
74136
}
75137
}
76138

77139
/// return the size of all buffers allocated by this null state, not including self
78140
pub fn size(&self) -> usize {
79-
// capacity is in bits, so convert to bytes
80-
self.seen_values.capacity() / 8
141+
match &self.seen_values {
142+
SeenValues::All { .. } => 0,
143+
SeenValues::Some { values } => values.capacity() / 8,
144+
}
81145
}
82146

83147
/// Invokes `value_fn(group_index, value)` for each non null, non
@@ -107,10 +171,17 @@ impl NullState {
107171
T: ArrowPrimitiveType + Send,
108172
F: FnMut(usize, T::Native) + Send,
109173
{
110-
// ensure the seen_values is big enough (start everything at
111-
// "not seen" valid)
112-
let seen_values =
113-
initialize_builder(&mut self.seen_values, total_num_groups, false);
174+
// skip null handling if no nulls in input or accumulator
175+
if let SeenValues::All { num_values } = &mut self.seen_values
176+
&& opt_filter.is_none()
177+
&& values.null_count() == 0
178+
{
179+
accumulate(group_indices, values, None, value_fn);
180+
*num_values = total_num_groups;
181+
return;
182+
}
183+
184+
let seen_values = self.seen_values.get_builder(total_num_groups);
114185
accumulate(group_indices, values, opt_filter, |group_index, value| {
115186
seen_values.set_bit(group_index, true);
116187
value_fn(group_index, value);
@@ -140,10 +211,21 @@ impl NullState {
140211
let data = values.values();
141212
assert_eq!(data.len(), group_indices.len());
142213

143-
// ensure the seen_values is big enough (start everything at
144-
// "not seen" valid)
145-
let seen_values =
146-
initialize_builder(&mut self.seen_values, total_num_groups, false);
214+
// skip null handling if no nulls in input or accumulator
215+
if let SeenValues::All { num_values } = &mut self.seen_values
216+
&& opt_filter.is_none()
217+
&& values.null_count() == 0
218+
{
219+
group_indices
220+
.iter()
221+
.zip(data.iter())
222+
.for_each(|(&group_index, new_value)| value_fn(group_index, new_value));
223+
*num_values = total_num_groups;
224+
225+
return;
226+
}
227+
228+
let seen_values = self.seen_values.get_builder(total_num_groups);
147229

148230
// These could be made more performant by iterating in chunks of 64 bits at a time
149231
match (values.null_count() > 0, opt_filter) {
@@ -211,21 +293,39 @@ impl NullState {
211293
/// for the `emit_to` rows.
212294
///
213295
/// resets the internal state appropriately
214-
pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer {
215-
let nulls: BooleanBuffer = self.seen_values.finish();
216-
217-
let nulls = match emit_to {
218-
EmitTo::All => nulls,
219-
EmitTo::First(n) => {
220-
// split off the first N values in seen_values
221-
let first_n_null: BooleanBuffer = nulls.slice(0, n);
222-
// reset the existing seen buffer
223-
self.seen_values
224-
.append_buffer(&nulls.slice(n, nulls.len() - n));
225-
first_n_null
296+
pub fn build(&mut self, emit_to: EmitTo) -> Option<NullBuffer> {
297+
match emit_to {
298+
EmitTo::All => {
299+
let old_seen = std::mem::take(&mut self.seen_values);
300+
match old_seen {
301+
SeenValues::All { .. } => None,
302+
SeenValues::Some { mut values } => {
303+
Some(NullBuffer::new(values.finish()))
304+
}
305+
}
226306
}
227-
};
228-
NullBuffer::new(nulls)
307+
EmitTo::First(n) => match &mut self.seen_values {
308+
SeenValues::All { num_values } => {
309+
*num_values = num_values.saturating_sub(n);
310+
None
311+
}
312+
SeenValues::Some { .. } => {
313+
let mut old_values = match std::mem::take(&mut self.seen_values) {
314+
SeenValues::Some { values } => values,
315+
_ => unreachable!(),
316+
};
317+
let nulls = old_values.finish();
318+
let first_n_null = nulls.slice(0, n);
319+
let remainder = nulls.slice(n, nulls.len() - n);
320+
let mut new_builder = BooleanBufferBuilder::new(remainder.len());
321+
new_builder.append_buffer(&remainder);
322+
self.seen_values = SeenValues::Some {
323+
values: new_builder,
324+
};
325+
Some(NullBuffer::new(first_n_null))
326+
}
327+
},
328+
}
229329
}
230330
}
231331

@@ -573,27 +673,14 @@ pub fn accumulate_indices<F>(
573673
}
574674
}
575675

576-
/// Ensures that `builder` contains a `BooleanBufferBuilder with at
577-
/// least `total_num_groups`.
578-
///
579-
/// All new entries are initialized to `default_value`
580-
fn initialize_builder(
581-
builder: &mut BooleanBufferBuilder,
582-
total_num_groups: usize,
583-
default_value: bool,
584-
) -> &mut BooleanBufferBuilder {
585-
if builder.len() < total_num_groups {
586-
let new_groups = total_num_groups - builder.len();
587-
builder.append_n(new_groups, default_value);
588-
}
589-
builder
590-
}
591-
592676
#[cfg(test)]
593677
mod test {
594678
use super::*;
595679

596-
use arrow::array::{Int32Array, UInt32Array};
680+
use arrow::{
681+
array::{Int32Array, UInt32Array},
682+
buffer::BooleanBuffer,
683+
};
597684
use rand::{Rng, rngs::ThreadRng};
598685
use std::collections::HashSet;
599686

@@ -834,15 +921,24 @@ mod test {
834921
accumulated_values, expected_values,
835922
"\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
836923
);
837-
let seen_values = null_state.seen_values.finish_cloned();
838-
mock.validate_seen_values(&seen_values);
924+
925+
match &null_state.seen_values {
926+
SeenValues::All { num_values } => {
927+
assert_eq!(*num_values, total_num_groups);
928+
}
929+
SeenValues::Some { values } => {
930+
let seen_values = values.finish_cloned();
931+
mock.validate_seen_values(&seen_values);
932+
}
933+
}
839934

840935
// Validate the final buffer (one value per group)
841936
let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
842937

843938
let null_buffer = null_state.build(EmitTo::All);
844-
845-
assert_eq!(null_buffer, expected_null_buffer);
939+
if let Some(nulls) = &null_buffer {
940+
assert_eq!(*nulls, expected_null_buffer);
941+
}
846942
}
847943

848944
// Calls `accumulate_indices`
@@ -955,15 +1051,25 @@ mod test {
9551051
"\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
9561052
);
9571053

958-
let seen_values = null_state.seen_values.finish_cloned();
959-
mock.validate_seen_values(&seen_values);
1054+
match &null_state.seen_values {
1055+
SeenValues::All { num_values } => {
1056+
assert_eq!(*num_values, total_num_groups);
1057+
}
1058+
SeenValues::Some { values } => {
1059+
let seen_values = values.finish_cloned();
1060+
mock.validate_seen_values(&seen_values);
1061+
}
1062+
}
9601063

9611064
// Validate the final buffer (one value per group)
962-
let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
1065+
let expected_null_buffer = Some(mock.expected_null_buffer(total_num_groups));
9631066

1067+
let is_all_seen = matches!(null_state.seen_values, SeenValues::All { .. });
9641068
let null_buffer = null_state.build(EmitTo::All);
9651069

966-
assert_eq!(null_buffer, expected_null_buffer);
1070+
if !is_all_seen {
1071+
assert_eq!(null_buffer, expected_null_buffer);
1072+
}
9671073
}
9681074
}
9691075

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ where
120120
};
121121

122122
let nulls = self.null_state.build(emit_to);
123-
let values = BooleanArray::new(values, Some(nulls));
123+
let values = BooleanArray::new(values, nulls);
124124
Ok(Arc::new(values))
125125
}
126126

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ where
106106
opt_filter,
107107
total_num_groups,
108108
|group_index, new_value| {
109-
let value = &mut self.values[group_index];
109+
// SAFETY: group_index is guaranteed to be in bounds
110+
let value = unsafe { self.values.get_unchecked_mut(group_index) };
110111
(self.prim_fn)(value, new_value);
111112
},
112113
);
@@ -117,7 +118,7 @@ where
117118
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
118119
let values = emit_to.take_needed(&mut self.values);
119120
let nulls = self.null_state.build(emit_to);
120-
let values = PrimitiveArray::<T>::new(values.into(), Some(nulls)) // no copy
121+
let values = PrimitiveArray::<T>::new(values.into(), nulls) // no copy
121122
.with_data_type(self.data_type.clone());
122123
Ok(Arc::new(values))
123124
}

0 commit comments

Comments
 (0)