diff --git a/datafusion-examples/examples/udf/advanced_udaf.rs b/datafusion-examples/examples/udf/advanced_udaf.rs index fbb9e652486ce..03ba6f05bee19 100644 --- a/datafusion-examples/examples/udf/advanced_udaf.rs +++ b/datafusion-examples/examples/udf/advanced_udaf.rs @@ -314,12 +314,16 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { let prods = emit_to.take_needed(&mut self.prods); let nulls = self.null_state.build(emit_to); - assert_eq!(nulls.len(), prods.len()); + if let Some(nulls) = &nulls { + assert_eq!(nulls.len(), counts.len()); + } assert_eq!(counts.len(), prods.len()); // don't evaluate geometric mean with null inputs to avoid errors on null values - let array: PrimitiveArray = if nulls.null_count() > 0 { + let array: PrimitiveArray = if let Some(nulls) = &nulls + && nulls.null_count() > 0 + { let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()); let iter = prods.into_iter().zip(counts).zip(nulls.iter()); @@ -337,7 +341,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { .zip(counts) .map(|(prod, count)| prod.powf(1.0 / count as f64)) .collect::>(); - PrimitiveArray::new(geo_mean.into(), Some(nulls)) // no copy + PrimitiveArray::new(geo_mean.into(), nulls) // no copy .with_data_type(self.return_data_type.clone()) }; @@ -347,7 +351,6 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { // return arrays for counts and prods fn state(&mut self, emit_to: EmitTo) -> Result> { let nulls = self.null_state.build(emit_to); - let nulls = Some(nulls); let counts = emit_to.take_needed(&mut self.counts); let counts = UInt32Array::new(counts.into(), nulls.clone()); // zero copy diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index 29b8752048c3e..25f52df61136f 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -20,10 +20,70 @@ //! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray}; -use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::buffer::NullBuffer; use arrow::datatypes::ArrowPrimitiveType; use datafusion_expr_common::groups_accumulator::EmitTo; + +/// If the input has nulls, then the accumulator must potentially +/// handle each input null value specially (e.g. for `SUM` to mark the +/// corresponding sum as null) +/// +/// If there are filters present, `NullState` tracks if it has seen +/// *any* value for that group (as some values may be filtered +/// out). Without a filter, the accumulator is only passed groups that +/// had at least one value to accumulate so they do not need to track +/// if they have seen values for a particular group. +#[derive(Debug)] +pub enum SeenValues { + /// All groups seen so far have seen at least one non-null value + All { + num_values: usize, + }, + // Some groups have not yet seen a non-null value + Some { + values: BooleanBufferBuilder, + }, +} + +impl Default for SeenValues { + fn default() -> Self { + SeenValues::All { num_values: 0 } + } +} + +impl SeenValues { + /// Return a mutable reference to the `BooleanBufferBuilder` in `SeenValues::Some`. + /// + /// If `self` is `SeenValues::All`, it is transitioned to `SeenValues::Some` + /// by creating a new `BooleanBufferBuilder` where the first `num_values` are true. + /// + /// The builder is then ensured to have at least `total_num_groups` length, + /// with any new entries initialized to false. + fn get_builder(&mut self, total_num_groups: usize) -> &mut BooleanBufferBuilder { + match self { + SeenValues::All { num_values } => { + let mut builder = BooleanBufferBuilder::new(total_num_groups); + builder.append_n(*num_values, true); + if total_num_groups > *num_values { + builder.append_n(total_num_groups - *num_values, false); + } + *self = SeenValues::Some { values: builder }; + match self { + SeenValues::Some { values } => values, + _ => unreachable!(), + } + } + SeenValues::Some { values } => { + if values.len() < total_num_groups { + values.append_n(total_num_groups - values.len(), false); + } + values + } + } + } +} + /// Track the accumulator null state per row: if any values for that /// group were null and if any values have been seen at all for that group. /// @@ -53,12 +113,14 @@ use datafusion_expr_common::groups_accumulator::EmitTo; pub struct NullState { /// Have we seen any non-filtered input values for `group_index`? /// - /// If `seen_values[i]` is true, have seen at least one non null + /// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is true, have seen at least one non null /// value for group `i` /// - /// If `seen_values[i]` is false, have not seen any values that + /// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is false, have not seen any values that /// pass the filter yet for group `i` - seen_values: BooleanBufferBuilder, + /// + /// If `seen_values` is `SeenValues::All`, all groups have seen at least one non null value + seen_values: SeenValues, } impl Default for NullState { @@ -70,14 +132,16 @@ impl Default for NullState { impl NullState { pub fn new() -> Self { Self { - seen_values: BooleanBufferBuilder::new(0), + seen_values: SeenValues::All { num_values: 0 }, } } /// return the size of all buffers allocated by this null state, not including self pub fn size(&self) -> usize { - // capacity is in bits, so convert to bytes - self.seen_values.capacity() / 8 + match &self.seen_values { + SeenValues::All { .. } => 0, + SeenValues::Some { values } => values.capacity() / 8, + } } /// Invokes `value_fn(group_index, value)` for each non null, non @@ -107,10 +171,17 @@ impl NullState { T: ArrowPrimitiveType + Send, F: FnMut(usize, T::Native) + Send, { - // ensure the seen_values is big enough (start everything at - // "not seen" valid) - let seen_values = - initialize_builder(&mut self.seen_values, total_num_groups, false); + // skip null handling if no nulls in input or accumulator + if let SeenValues::All { num_values } = &mut self.seen_values + && opt_filter.is_none() + && values.null_count() == 0 + { + accumulate(group_indices, values, None, value_fn); + *num_values = total_num_groups; + return; + } + + let seen_values = self.seen_values.get_builder(total_num_groups); accumulate(group_indices, values, opt_filter, |group_index, value| { seen_values.set_bit(group_index, true); value_fn(group_index, value); @@ -140,10 +211,21 @@ impl NullState { let data = values.values(); assert_eq!(data.len(), group_indices.len()); - // ensure the seen_values is big enough (start everything at - // "not seen" valid) - let seen_values = - initialize_builder(&mut self.seen_values, total_num_groups, false); + // skip null handling if no nulls in input or accumulator + if let SeenValues::All { num_values } = &mut self.seen_values + && opt_filter.is_none() + && values.null_count() == 0 + { + group_indices + .iter() + .zip(data.iter()) + .for_each(|(&group_index, new_value)| value_fn(group_index, new_value)); + *num_values = total_num_groups; + + return; + } + + let seen_values = self.seen_values.get_builder(total_num_groups); // These could be made more performant by iterating in chunks of 64 bits at a time match (values.null_count() > 0, opt_filter) { @@ -211,21 +293,39 @@ impl NullState { /// for the `emit_to` rows. /// /// resets the internal state appropriately - pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer { - let nulls: BooleanBuffer = self.seen_values.finish(); - - let nulls = match emit_to { - EmitTo::All => nulls, - EmitTo::First(n) => { - // split off the first N values in seen_values - let first_n_null: BooleanBuffer = nulls.slice(0, n); - // reset the existing seen buffer - self.seen_values - .append_buffer(&nulls.slice(n, nulls.len() - n)); - first_n_null + pub fn build(&mut self, emit_to: EmitTo) -> Option { + match emit_to { + EmitTo::All => { + let old_seen = std::mem::take(&mut self.seen_values); + match old_seen { + SeenValues::All { .. } => None, + SeenValues::Some { mut values } => { + Some(NullBuffer::new(values.finish())) + } + } } - }; - NullBuffer::new(nulls) + EmitTo::First(n) => match &mut self.seen_values { + SeenValues::All { num_values } => { + *num_values = num_values.saturating_sub(n); + None + } + SeenValues::Some { .. } => { + let mut old_values = match std::mem::take(&mut self.seen_values) { + SeenValues::Some { values } => values, + _ => unreachable!(), + }; + let nulls = old_values.finish(); + let first_n_null = nulls.slice(0, n); + let remainder = nulls.slice(n, nulls.len() - n); + let mut new_builder = BooleanBufferBuilder::new(remainder.len()); + new_builder.append_buffer(&remainder); + self.seen_values = SeenValues::Some { + values: new_builder, + }; + Some(NullBuffer::new(first_n_null)) + } + }, + } } } @@ -573,27 +673,14 @@ pub fn accumulate_indices( } } -/// Ensures that `builder` contains a `BooleanBufferBuilder with at -/// least `total_num_groups`. -/// -/// All new entries are initialized to `default_value` -fn initialize_builder( - builder: &mut BooleanBufferBuilder, - total_num_groups: usize, - default_value: bool, -) -> &mut BooleanBufferBuilder { - if builder.len() < total_num_groups { - let new_groups = total_num_groups - builder.len(); - builder.append_n(new_groups, default_value); - } - builder -} - #[cfg(test)] mod test { use super::*; - use arrow::array::{Int32Array, UInt32Array}; + use arrow::{ + array::{Int32Array, UInt32Array}, + buffer::BooleanBuffer, + }; use rand::{Rng, rngs::ThreadRng}; use std::collections::HashSet; @@ -834,15 +921,24 @@ mod test { accumulated_values, expected_values, "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}" ); - let seen_values = null_state.seen_values.finish_cloned(); - mock.validate_seen_values(&seen_values); + + match &null_state.seen_values { + SeenValues::All { num_values } => { + assert_eq!(*num_values, total_num_groups); + } + SeenValues::Some { values } => { + let seen_values = values.finish_cloned(); + mock.validate_seen_values(&seen_values); + } + } // Validate the final buffer (one value per group) let expected_null_buffer = mock.expected_null_buffer(total_num_groups); let null_buffer = null_state.build(EmitTo::All); - - assert_eq!(null_buffer, expected_null_buffer); + if let Some(nulls) = &null_buffer { + assert_eq!(*nulls, expected_null_buffer); + } } // Calls `accumulate_indices` @@ -955,15 +1051,25 @@ mod test { "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}" ); - let seen_values = null_state.seen_values.finish_cloned(); - mock.validate_seen_values(&seen_values); + match &null_state.seen_values { + SeenValues::All { num_values } => { + assert_eq!(*num_values, total_num_groups); + } + SeenValues::Some { values } => { + let seen_values = values.finish_cloned(); + mock.validate_seen_values(&seen_values); + } + } // Validate the final buffer (one value per group) - let expected_null_buffer = mock.expected_null_buffer(total_num_groups); + let expected_null_buffer = Some(mock.expected_null_buffer(total_num_groups)); + let is_all_seen = matches!(null_state.seen_values, SeenValues::All { .. }); let null_buffer = null_state.build(EmitTo::All); - assert_eq!(null_buffer, expected_null_buffer); + if !is_all_seen { + assert_eq!(null_buffer, expected_null_buffer); + } } } diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs index 149312e5a9c0f..f716b48f0cccc 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs @@ -120,7 +120,7 @@ where }; let nulls = self.null_state.build(emit_to); - let values = BooleanArray::new(values, Some(nulls)); + let values = BooleanArray::new(values, nulls); Ok(Arc::new(values)) } diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs index 656b95d140dde..acf875b686139 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs @@ -106,7 +106,8 @@ where opt_filter, total_num_groups, |group_index, new_value| { - let value = &mut self.values[group_index]; + // SAFETY: group_index is guaranteed to be in bounds + let value = unsafe { self.values.get_unchecked_mut(group_index) }; (self.prim_fn)(value, new_value); }, ); @@ -117,7 +118,7 @@ where fn evaluate(&mut self, emit_to: EmitTo) -> Result { let values = emit_to.take_needed(&mut self.values); let nulls = self.null_state.build(emit_to); - let values = PrimitiveArray::::new(values.into(), Some(nulls)) // no copy + let values = PrimitiveArray::::new(values.into(), nulls) // no copy .with_data_type(self.data_type.clone()); Ok(Arc::new(values)) } diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 46a8dbf9540b6..543116db1ddb6 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -821,7 +821,8 @@ where opt_filter, total_num_groups, |group_index, new_value| { - let sum = &mut self.sums[group_index]; + // SAFETY: group_index is guaranteed to be in bounds + let sum = unsafe { self.sums.get_unchecked_mut(group_index) }; *sum = sum.add_wrapping(new_value); self.counts[group_index] += 1; @@ -836,12 +837,16 @@ where let sums = emit_to.take_needed(&mut self.sums); let nulls = self.null_state.build(emit_to); - assert_eq!(nulls.len(), sums.len()); + if let Some(nulls) = &nulls { + assert_eq!(nulls.len(), sums.len()); + } assert_eq!(counts.len(), sums.len()); // don't evaluate averages with null inputs to avoid errors on null values - let array: PrimitiveArray = if nulls.null_count() > 0 { + let array: PrimitiveArray = if let Some(nulls) = &nulls + && nulls.null_count() > 0 + { let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()) .with_data_type(self.return_data_type.clone()); let iter = sums.into_iter().zip(counts).zip(nulls.iter()); @@ -860,7 +865,7 @@ where .zip(counts.into_iter()) .map(|(sum, count)| (self.avg_fn)(sum, count)) .collect::>>()?; - PrimitiveArray::new(averages.into(), Some(nulls)) // no copy + PrimitiveArray::new(averages.into(), nulls) // no copy .with_data_type(self.return_data_type.clone()) }; @@ -870,7 +875,6 @@ where // return arrays for sums and counts fn state(&mut self, emit_to: EmitTo) -> Result> { let nulls = self.null_state.build(emit_to); - let nulls = Some(nulls); let counts = emit_to.take_needed(&mut self.counts); let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy @@ -904,7 +908,9 @@ where opt_filter, total_num_groups, |group_index, partial_count| { - self.counts[group_index] += partial_count; + // SAFETY: group_index is guaranteed to be in bounds + let count = unsafe { self.counts.get_unchecked_mut(group_index) }; + *count += partial_count; }, ); @@ -916,7 +922,8 @@ where opt_filter, total_num_groups, |group_index, new_value: ::Native| { - let sum = &mut self.sums[group_index]; + // SAFETY: group_index is guaranteed to be in bounds + let sum = unsafe { self.sums.get_unchecked_mut(group_index) }; *sum = sum.add_wrapping(new_value); }, ); diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index a7c819acafea8..10cc2ad33f563 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -598,7 +598,9 @@ impl GroupsAccumulator for CountGroupsAccumulator { values.logical_nulls().as_ref(), opt_filter, |group_index| { - self.counts[group_index] += 1; + // SAFETY: group_index is guaranteed to be in bounds + let count = unsafe { self.counts.get_unchecked_mut(group_index) }; + *count += 1; }, );