Skip to content
11 changes: 7 additions & 4 deletions datafusion-examples/examples/udf/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,16 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
let prods = emit_to.take_needed(&mut self.prods);
let nulls = self.null_state.build(emit_to);

assert_eq!(nulls.len(), prods.len());
if let Some(nulls) = &nulls {
assert_eq!(nulls.len(), counts.len());
}
assert_eq!(counts.len(), prods.len());

// don't evaluate geometric mean with null inputs to avoid errors on null values

let array: PrimitiveArray<Float64Type> = if nulls.null_count() > 0 {
let array: PrimitiveArray<Float64Type> = if let Some(nulls) = &nulls
&& nulls.null_count() > 0
{
let mut builder = PrimitiveBuilder::<Float64Type>::with_capacity(nulls.len());
let iter = prods.into_iter().zip(counts).zip(nulls.iter());

Expand All @@ -337,7 +341,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
.zip(counts)
.map(|(prod, count)| prod.powf(1.0 / count as f64))
.collect::<Vec<_>>();
PrimitiveArray::new(geo_mean.into(), Some(nulls)) // no copy
PrimitiveArray::new(geo_mean.into(), nulls) // no copy
.with_data_type(self.return_data_type.clone())
};

Expand All @@ -347,7 +351,6 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
// return arrays for counts and prods
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
let nulls = self.null_state.build(emit_to);
let nulls = Some(nulls);

let counts = emit_to.take_needed(&mut self.counts);
let counts = UInt32Array::new(counts.into(), nulls.clone()); // zero copy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,64 @@
//! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator

use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::buffer::NullBuffer;
use arrow::datatypes::ArrowPrimitiveType;

use datafusion_expr_common::groups_accumulator::EmitTo;

/// Records which groups have seen at least one value.
///
/// If the input has nulls, then the accumulator must potentially
/// handle each input null value specially (e.g. for `SUM` to mark the
/// corresponding sum as null)
///
/// If there are filters present, `NullState` tracks if it has seen
/// *any* value for that group (as some values may be filtered
/// out). Without a filter, the accumulator is only passed groups that
/// had at least one value to accumulate so they do not need to track
/// if they have seen values for a particular group.
///
/// `SeenValues::All` stores only a count and no per-group bitmap, while
/// `SeenValues::Some` stores one bit per group.
#[derive(Debug)]
pub enum SeenValues {
/// All groups seen so far have seen at least one non-null value
All {
/// Number of groups covered by this state. When the state is switched
/// to `SeenValues::Some`, the first `num_values` bits start out true.
num_values: usize,
},
/// Some groups have not yet seen a non-null value
Some {
/// One bit per group: bit `i` is set to true once a value has been
/// accumulated for group `i`, and is false otherwise.
values: BooleanBufferBuilder,
},
}

impl SeenValues {
    /// Returns the per-group bitmap builder, switching `self` to
    /// `SeenValues::Some` first when necessary.
    ///
    /// When `self` is `SeenValues::All`, a new `BooleanBufferBuilder` is
    /// created whose first `num_values` bits are true, and `self` is replaced
    /// by `SeenValues::Some` holding that builder.
    ///
    /// In either case the builder is then padded with false bits until it
    /// holds at least `total_num_groups` entries. A builder that is already
    /// long enough is returned unchanged.
    fn get_builder(&mut self, total_num_groups: usize) -> &mut BooleanBufferBuilder {
        // Promote the count-only representation to an explicit bitmap.
        if let SeenValues::All { num_values } = *self {
            let mut promoted = BooleanBufferBuilder::new(total_num_groups);
            promoted.append_n(num_values, true);
            *self = SeenValues::Some { values: promoted };
        }

        // After the promotion above only the `Some` variant can remain.
        let SeenValues::Some { values: bitmap } = self else {
            unreachable!()
        };

        // Newly added groups start out as "not seen".
        let current_len = bitmap.len();
        if current_len < total_num_groups {
            bitmap.append_n(total_num_groups - current_len, false);
        }
        bitmap
    }
}

/// Track the accumulator null state per row: if any values for that
/// group were null and if any values have been seen at all for that group.
///
Expand Down Expand Up @@ -53,12 +107,14 @@ use datafusion_expr_common::groups_accumulator::EmitTo;
pub struct NullState {
/// Have we seen any non-filtered input values for `group_index`?
///
/// If `seen_values[i]` is true, have seen at least one non null
/// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is true, have seen at least one non null
/// value for group `i`
///
/// If `seen_values[i]` is false, have not seen any values that
/// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is false, have not seen any values that
/// pass the filter yet for group `i`
seen_values: BooleanBufferBuilder,
///
/// If `seen_values` is `SeenValues::All`, all groups have seen at least one non null value
seen_values: SeenValues,
}

impl Default for NullState {
Expand All @@ -70,14 +126,16 @@ impl Default for NullState {
impl NullState {
pub fn new() -> Self {
Self {
seen_values: BooleanBufferBuilder::new(0),
seen_values: SeenValues::All { num_values: 0 },
}
}

/// return the size of all buffers allocated by this null state, not including self
pub fn size(&self) -> usize {
// capacity is in bits, so convert to bytes
self.seen_values.capacity() / 8
match &self.seen_values {
SeenValues::All { .. } => 0,
SeenValues::Some { values } => values.capacity() / 8,
}
}

/// Invokes `value_fn(group_index, value)` for each non null, non
Expand Down Expand Up @@ -107,10 +165,16 @@ impl NullState {
T: ArrowPrimitiveType + Send,
F: FnMut(usize, T::Native) + Send,
{
// ensure the seen_values is big enough (start everything at
// "not seen" valid)
let seen_values =
initialize_builder(&mut self.seen_values, total_num_groups, false);
if let SeenValues::All { num_values } = &mut self.seen_values
&& opt_filter.is_none()
&& values.null_count() == 0
{
accumulate(group_indices, values, None, value_fn);
*num_values = total_num_groups;
return;
}

let seen_values = self.seen_values.get_builder(total_num_groups);
accumulate(group_indices, values, opt_filter, |group_index, value| {
seen_values.set_bit(group_index, true);
value_fn(group_index, value);
Expand Down Expand Up @@ -140,10 +204,20 @@ impl NullState {
let data = values.values();
assert_eq!(data.len(), group_indices.len());

// ensure the seen_values is big enough (start everything at
// "not seen" valid)
let seen_values =
initialize_builder(&mut self.seen_values, total_num_groups, false);
if let SeenValues::All { num_values } = &mut self.seen_values
&& opt_filter.is_none()
&& values.null_count() == 0
{
group_indices
.iter()
.zip(data.iter())
.for_each(|(&group_index, new_value)| value_fn(group_index, new_value));
*num_values = total_num_groups;

return;
}

let seen_values = self.seen_values.get_builder(total_num_groups);

// These could be made more performant by iterating in chunks of 64 bits at a time
match (values.null_count() > 0, opt_filter) {
Expand Down Expand Up @@ -211,21 +285,45 @@ impl NullState {
/// for the `emit_to` rows.
///
/// resets the internal state appropriately
pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer {
let nulls: BooleanBuffer = self.seen_values.finish();

let nulls = match emit_to {
EmitTo::All => nulls,
EmitTo::First(n) => {
// split off the first N values in seen_values
let first_n_null: BooleanBuffer = nulls.slice(0, n);
// reset the existing seen buffer
self.seen_values
.append_buffer(&nulls.slice(n, nulls.len() - n));
first_n_null
pub fn build(&mut self, emit_to: EmitTo) -> Option<NullBuffer> {
match emit_to {
EmitTo::All => {
let old_seen = std::mem::replace(
&mut self.seen_values,
SeenValues::All { num_values: 0 },
);
match old_seen {
SeenValues::All { .. } => None,
SeenValues::Some { mut values } => {
Some(NullBuffer::new(values.finish()))
}
}
}
};
NullBuffer::new(nulls)
EmitTo::First(n) => match &mut self.seen_values {
SeenValues::All { num_values } => {
*num_values = num_values.saturating_sub(n);
None
}
SeenValues::Some { .. } => {
let mut old_values = match std::mem::replace(
&mut self.seen_values,
SeenValues::All { num_values: 0 },
) {
SeenValues::Some { values } => values,
_ => unreachable!(),
};
let nulls = old_values.finish();
let first_n_null = nulls.slice(0, n);
let remainder = nulls.slice(n, nulls.len() - n);
let mut new_builder = BooleanBufferBuilder::new(remainder.len());
new_builder.append_buffer(&remainder);
self.seen_values = SeenValues::Some {
values: new_builder,
};
Some(NullBuffer::new(first_n_null))
}
},
}
}
}

Expand Down Expand Up @@ -573,27 +671,14 @@ pub fn accumulate_indices<F>(
}
}

/// Grows `builder` until it holds at least `total_num_groups` entries.
///
/// Entries appended by this call are set to `default_value`; entries that
/// are already present are left untouched, and a builder that is already
/// long enough is not modified. The same builder is handed back to the caller.
fn initialize_builder(
    builder: &mut BooleanBufferBuilder,
    total_num_groups: usize,
    default_value: bool,
) -> &mut BooleanBufferBuilder {
    // Number of entries still missing; zero when the builder is long enough.
    let missing = total_num_groups.saturating_sub(builder.len());
    if missing > 0 {
        builder.append_n(missing, default_value);
    }
    builder
}

#[cfg(test)]
mod test {
use super::*;

use arrow::array::{Int32Array, UInt32Array};
use arrow::{
array::{Int32Array, UInt32Array},
buffer::BooleanBuffer,
};
use rand::{Rng, rngs::ThreadRng};
use std::collections::HashSet;

Expand Down Expand Up @@ -834,15 +919,24 @@ mod test {
accumulated_values, expected_values,
"\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
);
let seen_values = null_state.seen_values.finish_cloned();
mock.validate_seen_values(&seen_values);

match &null_state.seen_values {
SeenValues::All { num_values } => {
assert_eq!(*num_values, total_num_groups);
}
SeenValues::Some { values } => {
let seen_values = values.finish_cloned();
mock.validate_seen_values(&seen_values);
}
}

// Validate the final buffer (one value per group)
let expected_null_buffer = mock.expected_null_buffer(total_num_groups);

let null_buffer = null_state.build(EmitTo::All);

assert_eq!(null_buffer, expected_null_buffer);
if let Some(nulls) = &null_buffer {
assert_eq!(*nulls, expected_null_buffer);
}
}

// Calls `accumulate_indices`
Expand Down Expand Up @@ -955,15 +1049,25 @@ mod test {
"\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
);

let seen_values = null_state.seen_values.finish_cloned();
mock.validate_seen_values(&seen_values);
match &null_state.seen_values {
SeenValues::All { num_values } => {
assert_eq!(*num_values, total_num_groups);
}
SeenValues::Some { values } => {
let seen_values = values.finish_cloned();
mock.validate_seen_values(&seen_values);
}
}

// Validate the final buffer (one value per group)
let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
let expected_null_buffer = Some(mock.expected_null_buffer(total_num_groups));

let is_all_seen = matches!(null_state.seen_values, SeenValues::All { .. });
let null_buffer = null_state.build(EmitTo::All);

assert_eq!(null_buffer, expected_null_buffer);
if !is_all_seen {
assert_eq!(null_buffer, expected_null_buffer);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ where
};

let nulls = self.null_state.build(emit_to);
let values = BooleanArray::new(values, Some(nulls));
let values = BooleanArray::new(values, nulls);
Ok(Arc::new(values))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ where
opt_filter,
total_num_groups,
|group_index, new_value| {
let value = &mut self.values[group_index];
// SAFETY: group_index is guaranteed to be in bounds
let value = unsafe { self.values.get_unchecked_mut(group_index) };
(self.prim_fn)(value, new_value);
},
);
Expand All @@ -117,7 +118,7 @@ where
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
let values = emit_to.take_needed(&mut self.values);
let nulls = self.null_state.build(emit_to);
let values = PrimitiveArray::<T>::new(values.into(), Some(nulls)) // no copy
let values = PrimitiveArray::<T>::new(values.into(), nulls) // no copy
.with_data_type(self.data_type.clone());
Ok(Arc::new(values))
}
Expand Down
Loading