Skip to content

Commit e42a0b6

Browse files
authored
Refactor distinct aggregate implementations to use common buffer (#18348)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Relates to #2406 ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> Make it easier to write distinct variations of aggregate functions by refactoring some of the common code together; specifically how they handle maintaining the complete set of distinct primitive values, as this code was duplicated across different functions. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> Introduce new `GenericDistinctBuffer` which has methods similar to `Accumulator` to manage an internal `HashSet` of values, so implementations like `percentile_cont` and `sum` can use it internally and only implement their own evaluate functions. ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Existing tests. ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> No. <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent e29009f commit e42a0b6

File tree

5 files changed

+150
-229
lines changed

5 files changed

+150
-229
lines changed

datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs

Lines changed: 12 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use datafusion_common::utils::SingleRowListArrayBuilder;
3838
use datafusion_common::ScalarValue;
3939
use datafusion_expr_common::accumulator::Accumulator;
4040

41-
use crate::utils::Hashable;
41+
use crate::utils::GenericDistinctBuffer;
4242

4343
#[derive(Debug)]
4444
pub struct PrimitiveDistinctCountAccumulator<T>
@@ -124,88 +124,42 @@ where
124124
}
125125

126126
#[derive(Debug)]
127-
pub struct FloatDistinctCountAccumulator<T>
128-
where
129-
T: ArrowPrimitiveType + Send,
130-
{
131-
values: HashSet<Hashable<T::Native>, RandomState>,
127+
pub struct FloatDistinctCountAccumulator<T: ArrowPrimitiveType> {
128+
values: GenericDistinctBuffer<T>,
132129
}
133130

134-
impl<T> FloatDistinctCountAccumulator<T>
135-
where
136-
T: ArrowPrimitiveType + Send,
137-
{
131+
impl<T: ArrowPrimitiveType> FloatDistinctCountAccumulator<T> {
138132
pub fn new() -> Self {
139133
Self {
140-
values: HashSet::default(),
134+
values: GenericDistinctBuffer::new(T::DATA_TYPE),
141135
}
142136
}
143137
}
144138

145-
impl<T> Default for FloatDistinctCountAccumulator<T>
146-
where
147-
T: ArrowPrimitiveType + Send,
148-
{
139+
impl<T: ArrowPrimitiveType> Default for FloatDistinctCountAccumulator<T> {
149140
fn default() -> Self {
150141
Self::new()
151142
}
152143
}
153144

154-
impl<T> Accumulator for FloatDistinctCountAccumulator<T>
155-
where
156-
T: ArrowPrimitiveType + Send + Debug,
157-
{
145+
impl<T: ArrowPrimitiveType + Debug> Accumulator for FloatDistinctCountAccumulator<T> {
158146
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
159-
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
160-
self.values.iter().map(|v| v.0),
161-
)) as ArrayRef;
162-
Ok(vec![SingleRowListArrayBuilder::new(arr).build_list_scalar()])
147+
self.values.state()
163148
}
164149

165150
fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
166-
if values.is_empty() {
167-
return Ok(());
168-
}
169-
170-
let arr = as_primitive_array::<T>(&values[0])?;
171-
arr.iter().for_each(|value| {
172-
if let Some(value) = value {
173-
self.values.insert(Hashable(value));
174-
}
175-
});
176-
177-
Ok(())
151+
self.values.update_batch(values)
178152
}
179153

180154
fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
181-
if states.is_empty() {
182-
return Ok(());
183-
}
184-
assert_eq!(
185-
states.len(),
186-
1,
187-
"count_distinct states must be single array"
188-
);
189-
190-
let arr = as_list_array(&states[0])?;
191-
arr.iter().try_for_each(|maybe_list| {
192-
if let Some(list) = maybe_list {
193-
let list = as_primitive_array::<T>(&list)?;
194-
self.values
195-
.extend(list.values().iter().map(|v| Hashable(*v)));
196-
};
197-
Ok(())
198-
})
155+
self.values.merge_batch(states)
199156
}
200157

201158
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
202-
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
159+
Ok(ScalarValue::Int64(Some(self.values.values.len() as i64)))
203160
}
204161

205162
fn size(&self) -> usize {
206-
let num_elements = self.values.len();
207-
let fixed_size = size_of_val(self) + size_of_val(&self.values);
208-
209-
estimate_memory_size::<T::Native>(num_elements, fixed_size).unwrap()
163+
size_of_val(self) + self.values.size()
210164
}
211165
}

datafusion/functions-aggregate-common/src/aggregate/sum_distinct/numeric.rs

Lines changed: 19 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -17,107 +17,67 @@
1717

1818
//! Defines the accumulator for `SUM DISTINCT` for primitive numeric types
1919
20-
use std::collections::HashSet;
2120
use std::fmt::Debug;
22-
use std::mem::{size_of, size_of_val};
21+
use std::mem::size_of_val;
2322

24-
use ahash::RandomState;
25-
use arrow::array::Array;
2623
use arrow::array::ArrayRef;
2724
use arrow::array::ArrowNativeTypeOp;
2825
use arrow::array::ArrowPrimitiveType;
29-
use arrow::array::AsArray;
3026
use arrow::datatypes::ArrowNativeType;
3127
use arrow::datatypes::DataType;
3228

3329
use datafusion_common::Result;
3430
use datafusion_common::ScalarValue;
3531
use datafusion_expr_common::accumulator::Accumulator;
3632

37-
use crate::utils::Hashable;
33+
use crate::utils::GenericDistinctBuffer;
3834

3935
/// Accumulator for computing SUM(DISTINCT expr)
36+
#[derive(Debug)]
4037
pub struct DistinctSumAccumulator<T: ArrowPrimitiveType> {
41-
values: HashSet<Hashable<T::Native>, RandomState>,
38+
values: GenericDistinctBuffer<T>,
4239
data_type: DataType,
4340
}
4441

45-
impl<T: ArrowPrimitiveType> Debug for DistinctSumAccumulator<T> {
46-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47-
write!(f, "DistinctSumAccumulator({})", self.data_type)
48-
}
49-
}
50-
5142
impl<T: ArrowPrimitiveType> DistinctSumAccumulator<T> {
5243
pub fn new(data_type: &DataType) -> Self {
5344
Self {
54-
values: HashSet::default(),
45+
values: GenericDistinctBuffer::new(data_type.clone()),
5546
data_type: data_type.clone(),
5647
}
5748
}
5849

5950
pub fn distinct_count(&self) -> usize {
60-
self.values.len()
51+
self.values.values.len()
6152
}
6253
}
6354

64-
impl<T: ArrowPrimitiveType> Accumulator for DistinctSumAccumulator<T> {
55+
impl<T: ArrowPrimitiveType + Debug> Accumulator for DistinctSumAccumulator<T> {
6556
fn state(&mut self) -> Result<Vec<ScalarValue>> {
66-
// 1. Stores aggregate state in `ScalarValue::List`
67-
// 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
68-
let state_out = {
69-
let distinct_values = self
70-
.values
71-
.iter()
72-
.map(|value| {
73-
ScalarValue::new_primitive::<T>(Some(value.0), &self.data_type)
74-
})
75-
.collect::<Result<Vec<_>>>()?;
76-
77-
vec![ScalarValue::List(ScalarValue::new_list_nullable(
78-
&distinct_values,
79-
&self.data_type,
80-
))]
81-
};
82-
Ok(state_out)
57+
self.values.state()
8358
}
8459

8560
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
86-
if values.is_empty() {
87-
return Ok(());
88-
}
89-
90-
let array = values[0].as_primitive::<T>();
91-
match array.nulls().filter(|x| x.null_count() > 0) {
92-
Some(n) => {
93-
for idx in n.valid_indices() {
94-
self.values.insert(Hashable(array.value(idx)));
95-
}
96-
}
97-
None => array.values().iter().for_each(|x| {
98-
self.values.insert(Hashable(*x));
99-
}),
100-
}
101-
Ok(())
61+
self.values.update_batch(values)
10262
}
10363

10464
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
105-
for x in states[0].as_list::<i32>().iter().flatten() {
106-
self.update_batch(&[x])?
107-
}
108-
Ok(())
65+
self.values.merge_batch(states)
10966
}
11067

11168
fn evaluate(&mut self) -> Result<ScalarValue> {
112-
let mut acc = T::Native::usize_as(0);
113-
for distinct_value in self.values.iter() {
114-
acc = acc.add_wrapping(distinct_value.0)
69+
if self.distinct_count() == 0 {
70+
ScalarValue::new_primitive::<T>(None, &self.data_type)
71+
} else {
72+
let mut acc = T::Native::usize_as(0);
73+
for distinct_value in self.values.values.iter() {
74+
acc = acc.add_wrapping(distinct_value.0)
75+
}
76+
ScalarValue::new_primitive::<T>(Some(acc), &self.data_type)
11577
}
116-
let v = (!self.values.is_empty()).then_some(acc);
117-
ScalarValue::new_primitive::<T>(v, &self.data_type)
11878
}
11979

12080
fn size(&self) -> usize {
121-
size_of_val(self) + self.values.capacity() * size_of::<T::Native>()
81+
size_of_val(self) + self.values.size()
12282
}
12383
}

datafusion/functions-aggregate-common/src/utils.rs

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,20 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{ArrayRef, ArrowNativeTypeOp};
18+
use ahash::RandomState;
19+
use arrow::array::{
20+
Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray,
21+
};
1922
use arrow::compute::SortOptions;
2023
use arrow::datatypes::{
2124
ArrowNativeType, DataType, DecimalType, Field, FieldRef, ToByteSlice,
2225
};
23-
use datafusion_common::{exec_err, internal_datafusion_err, Result};
26+
use datafusion_common::cast::{as_list_array, as_primitive_array};
27+
use datafusion_common::utils::memory::estimate_memory_size;
28+
use datafusion_common::utils::SingleRowListArrayBuilder;
29+
use datafusion_common::{
30+
exec_err, internal_datafusion_err, HashSet, Result, ScalarValue,
31+
};
2432
use datafusion_expr_common::accumulator::Accumulator;
2533
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
2634
use std::sync::Arc;
@@ -167,3 +175,90 @@ impl<T: DecimalType> DecimalAverager<T> {
167175
}
168176
}
169177
}
178+
179+
/// Generic way to collect distinct values for accumulators.
180+
///
181+
/// Distinct values are accumulated eagerly into an internal `HashSet`
182+
/// during `update_batch`; `state` serializes them as a single-row `List`
183+
/// scalar, and `merge_batch` folds the list states from other partitions
184+
/// back into the set.
185+
pub struct GenericDistinctBuffer<T: ArrowPrimitiveType> {
186+
pub values: HashSet<Hashable<T::Native>, RandomState>,
187+
data_type: DataType,
188+
}
189+
190+
impl<T: ArrowPrimitiveType> std::fmt::Debug for GenericDistinctBuffer<T> {
191+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192+
write!(
193+
f,
194+
"GenericDistinctBuffer({}, values={})",
195+
self.data_type,
196+
self.values.len()
197+
)
198+
}
199+
}
200+
201+
impl<T: ArrowPrimitiveType> GenericDistinctBuffer<T> {
202+
pub fn new(data_type: DataType) -> Self {
203+
Self {
204+
values: HashSet::default(),
205+
data_type,
206+
}
207+
}
208+
209+
/// Mirrors [`Accumulator::state`].
210+
pub fn state(&self) -> Result<Vec<ScalarValue>> {
211+
let arr = Arc::new(
212+
PrimitiveArray::<T>::from_iter_values(self.values.iter().map(|v| v.0))
213+
// Ideally we'd just use T::DATA_TYPE but this misses things like
214+
// decimal scale/precision and timestamp timezones, which need to
215+
// match up with Accumulator::state_fields
216+
.with_data_type(self.data_type.clone()),
217+
);
218+
Ok(vec![SingleRowListArrayBuilder::new(arr).build_list_scalar()])
219+
}
220+
221+
/// Mirrors [`Accumulator::update_batch`].
222+
pub fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
223+
if values.is_empty() {
224+
return Ok(());
225+
}
226+
227+
debug_assert_eq!(
228+
values.len(),
229+
1,
230+
"GenericDistinctBuffer::update_batch expects only a single input array"
231+
);
232+
233+
let arr = as_primitive_array::<T>(&values[0])?;
234+
if arr.null_count() > 0 {
235+
self.values.extend(arr.iter().flatten().map(Hashable));
236+
} else {
237+
self.values
238+
.extend(arr.values().iter().cloned().map(Hashable));
239+
}
240+
241+
Ok(())
242+
}
243+
244+
/// Mirrors [`Accumulator::merge_batch`].
245+
pub fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
246+
if states.is_empty() {
247+
return Ok(());
248+
}
249+
250+
let array = as_list_array(&states[0])?;
251+
for list in array.iter().flatten() {
252+
self.update_batch(&[list])?;
253+
}
254+
255+
Ok(())
256+
}
257+
258+
/// Mirrors [`Accumulator::size`].
259+
pub fn size(&self) -> usize {
260+
let num_elements = self.values.len();
261+
let fixed_size = size_of_val(self) + size_of_val(&self.values);
262+
estimate_memory_size::<T::Native>(num_elements, fixed_size).unwrap()
263+
}
264+
}

0 commit comments

Comments
 (0)