Skip to content

Commit 317052e

Browse files
authored
perf: Optimize approx_distinct() for string, binary inputs (apache#21037)
## Which issue does this PR close?

- Closes apache#21035.

## Rationale for this change

We were making a defensive copy of every string and binary value before calling HyperLogLog, but that was unnecessary: HyperLogLog doesn't need owned types, and indeed never stores the input value. This improves the performance of `approx_distinct` on string and binary values by 6x-7x.

## What changes are included in this PR?

* Add benchmark for `approx_distinct`
* Optimize `approx_distinct` to avoid unnecessary copies
* Cleanup: remove spurious type parameter from `StringViewHLLAccumulator` (unused)

## Are these changes tested?

Yes.

## Are there any user-facing changes?

No.
1 parent 7e4818d commit 317052e

File tree

3 files changed

+145
-30
lines changed

3 files changed

+145
-30
lines changed

datafusion/functions-aggregate/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,7 @@ harness = false
7575
[[bench]]
7676
harness = false
7777
name = "min_max_bytes"
78+
79+
[[bench]]
80+
name = "approx_distinct"
81+
harness = false
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use arrow::array::{ArrayRef, Int64Array, StringArray, StringViewArray};
21+
use arrow::datatypes::{DataType, Field, Schema};
22+
use criterion::{Criterion, criterion_group, criterion_main};
23+
use datafusion_expr::function::AccumulatorArgs;
24+
use datafusion_expr::{Accumulator, AggregateUDFImpl};
25+
use datafusion_functions_aggregate::approx_distinct::ApproxDistinct;
26+
use datafusion_physical_expr::expressions::col;
27+
use rand::rngs::StdRng;
28+
use rand::{Rng, SeedableRng};
29+
30+
const BATCH_SIZE: usize = 8192;
31+
const STRING_LENGTH: usize = 20;
32+
33+
fn prepare_accumulator(data_type: DataType) -> Box<dyn Accumulator> {
34+
let schema = Arc::new(Schema::new(vec![Field::new("f", data_type, true)]));
35+
let expr = col("f", &schema).unwrap();
36+
let accumulator_args = AccumulatorArgs {
37+
return_field: Field::new("f", DataType::UInt64, true).into(),
38+
schema: &schema,
39+
expr_fields: &[expr.return_field(&schema).unwrap()],
40+
ignore_nulls: false,
41+
order_bys: &[],
42+
is_reversed: false,
43+
name: "approx_distinct(f)",
44+
is_distinct: false,
45+
exprs: &[expr],
46+
};
47+
ApproxDistinct::new().accumulator(accumulator_args).unwrap()
48+
}
49+
50+
/// Creates an Int64Array where values are drawn from `0..n_distinct`.
51+
fn create_i64_array(n_distinct: usize) -> Int64Array {
52+
let mut rng = StdRng::seed_from_u64(42);
53+
(0..BATCH_SIZE)
54+
.map(|_| Some(rng.random_range(0..n_distinct as i64)))
55+
.collect()
56+
}
57+
58+
/// Creates a pool of `n_distinct` random strings.
59+
fn create_string_pool(n_distinct: usize) -> Vec<String> {
60+
let mut rng = StdRng::seed_from_u64(42);
61+
(0..n_distinct)
62+
.map(|_| {
63+
(0..STRING_LENGTH)
64+
.map(|_| rng.random_range(b'a'..=b'z') as char)
65+
.collect()
66+
})
67+
.collect()
68+
}
69+
70+
/// Creates a StringArray where values are drawn from the given pool.
71+
fn create_string_array(pool: &[String]) -> StringArray {
72+
let mut rng = StdRng::seed_from_u64(99);
73+
(0..BATCH_SIZE)
74+
.map(|_| Some(pool[rng.random_range(0..pool.len())].as_str()))
75+
.collect()
76+
}
77+
78+
/// Creates a StringViewArray where values are drawn from the given pool.
79+
fn create_string_view_array(pool: &[String]) -> StringViewArray {
80+
let mut rng = StdRng::seed_from_u64(99);
81+
(0..BATCH_SIZE)
82+
.map(|_| Some(pool[rng.random_range(0..pool.len())].as_str()))
83+
.collect()
84+
}
85+
86+
fn approx_distinct_benchmark(c: &mut Criterion) {
87+
for pct in [80, 99] {
88+
let n_distinct = BATCH_SIZE * pct / 100;
89+
90+
// --- Int64 benchmarks ---
91+
let values = Arc::new(create_i64_array(n_distinct)) as ArrayRef;
92+
c.bench_function(&format!("approx_distinct i64 {pct}% distinct"), |b| {
93+
b.iter(|| {
94+
let mut accumulator = prepare_accumulator(DataType::Int64);
95+
accumulator
96+
.update_batch(std::slice::from_ref(&values))
97+
.unwrap()
98+
})
99+
});
100+
101+
let string_pool = create_string_pool(n_distinct);
102+
103+
// --- Utf8 benchmarks ---
104+
let values = Arc::new(create_string_array(&string_pool)) as ArrayRef;
105+
c.bench_function(&format!("approx_distinct utf8 {pct}% distinct"), |b| {
106+
b.iter(|| {
107+
let mut accumulator = prepare_accumulator(DataType::Utf8);
108+
accumulator
109+
.update_batch(std::slice::from_ref(&values))
110+
.unwrap()
111+
})
112+
});
113+
114+
// --- Utf8View benchmarks ---
115+
let values = Arc::new(create_string_view_array(&string_pool)) as ArrayRef;
116+
c.bench_function(&format!("approx_distinct utf8view {pct}% distinct"), |b| {
117+
b.iter(|| {
118+
let mut accumulator = prepare_accumulator(DataType::Utf8View);
119+
accumulator
120+
.update_batch(std::slice::from_ref(&values))
121+
.unwrap()
122+
})
123+
});
124+
}
125+
}
126+
127+
criterion_group!(benches, approx_distinct_benchmark);
128+
criterion_main!(benches);

datafusion/functions-aggregate/src/approx_distinct.rs

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,14 @@ make_udaf_expr_and_func!(
5555
approx_distinct_udaf
5656
);
5757

58-
impl<T: Hash> From<&HyperLogLog<T>> for ScalarValue {
58+
impl<T: Hash + ?Sized> From<&HyperLogLog<T>> for ScalarValue {
5959
fn from(v: &HyperLogLog<T>) -> ScalarValue {
6060
let values = v.as_ref().to_vec();
6161
ScalarValue::Binary(Some(values))
6262
}
6363
}
6464

65-
impl<T: Hash> TryFrom<&[u8]> for HyperLogLog<T> {
65+
impl<T: Hash + ?Sized> TryFrom<&[u8]> for HyperLogLog<T> {
6666
type Error = DataFusionError;
6767
fn try_from(v: &[u8]) -> Result<HyperLogLog<T>> {
6868
let arr: [u8; 16384] = v.try_into().map_err(|_| {
@@ -72,7 +72,7 @@ impl<T: Hash> TryFrom<&[u8]> for HyperLogLog<T> {
7272
}
7373
}
7474

75-
impl<T: Hash> TryFrom<&ScalarValue> for HyperLogLog<T> {
75+
impl<T: Hash + ?Sized> TryFrom<&ScalarValue> for HyperLogLog<T> {
7676
type Error = DataFusionError;
7777
fn try_from(v: &ScalarValue) -> Result<HyperLogLog<T>> {
7878
if let ScalarValue::Binary(Some(slice)) = v {
@@ -99,7 +99,6 @@ where
9999
T: ArrowPrimitiveType,
100100
T::Native: Hash,
101101
{
102-
/// new approx_distinct accumulator
103102
pub fn new() -> Self {
104103
Self {
105104
hll: HyperLogLog::new(),
@@ -112,15 +111,14 @@ struct StringHLLAccumulator<T>
112111
where
113112
T: OffsetSizeTrait,
114113
{
115-
hll: HyperLogLog<String>,
114+
hll: HyperLogLog<str>,
116115
phantom_data: PhantomData<T>,
117116
}
118117

119118
impl<T> StringHLLAccumulator<T>
120119
where
121120
T: OffsetSizeTrait,
122121
{
123-
/// new approx_distinct accumulator
124122
pub fn new() -> Self {
125123
Self {
126124
hll: HyperLogLog::new(),
@@ -130,22 +128,14 @@ where
130128
}
131129

132130
#[derive(Debug)]
133-
struct StringViewHLLAccumulator<T>
134-
where
135-
T: OffsetSizeTrait,
136-
{
137-
hll: HyperLogLog<String>,
138-
phantom_data: PhantomData<T>,
131+
struct StringViewHLLAccumulator {
132+
hll: HyperLogLog<str>,
139133
}
140134

141-
impl<T> StringViewHLLAccumulator<T>
142-
where
143-
T: OffsetSizeTrait,
144-
{
135+
impl StringViewHLLAccumulator {
145136
pub fn new() -> Self {
146137
Self {
147138
hll: HyperLogLog::new(),
148-
phantom_data: PhantomData,
149139
}
150140
}
151141
}
@@ -155,15 +145,14 @@ struct BinaryHLLAccumulator<T>
155145
where
156146
T: OffsetSizeTrait,
157147
{
158-
hll: HyperLogLog<Vec<u8>>,
148+
hll: HyperLogLog<[u8]>,
159149
phantom_data: PhantomData<T>,
160150
}
161151

162152
impl<T> BinaryHLLAccumulator<T>
163153
where
164154
T: OffsetSizeTrait,
165155
{
166-
/// new approx_distinct accumulator
167156
pub fn new() -> Self {
168157
Self {
169158
hll: HyperLogLog::new(),
@@ -213,23 +202,18 @@ where
213202
let array: &GenericBinaryArray<T> =
214203
downcast_value!(values[0], GenericBinaryArray, T);
215204
// flatten because we would skip nulls
216-
self.hll
217-
.extend(array.into_iter().flatten().map(|v| v.to_vec()));
205+
self.hll.extend(array.into_iter().flatten());
218206
Ok(())
219207
}
220208

221209
default_accumulator_impl!();
222210
}
223211

224-
impl<T> Accumulator for StringViewHLLAccumulator<T>
225-
where
226-
T: OffsetSizeTrait,
227-
{
212+
impl Accumulator for StringViewHLLAccumulator {
228213
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
229214
let array: &StringViewArray = downcast_value!(values[0], StringViewArray);
230215
// flatten because we would skip nulls
231-
self.hll
232-
.extend(array.iter().flatten().map(|s| s.to_string()));
216+
self.hll.extend(array.iter().flatten());
233217
Ok(())
234218
}
235219

@@ -244,8 +228,7 @@ where
244228
let array: &GenericStringArray<T> =
245229
downcast_value!(values[0], GenericStringArray, T);
246230
// flatten because we would skip nulls
247-
self.hll
248-
.extend(array.into_iter().flatten().map(|i| i.to_string()));
231+
self.hll.extend(array.into_iter().flatten());
249232
Ok(())
250233
}
251234

@@ -391,7 +374,7 @@ impl AggregateUDFImpl for ApproxDistinct {
391374
}
392375
DataType::Utf8 => Box::new(StringHLLAccumulator::<i32>::new()),
393376
DataType::LargeUtf8 => Box::new(StringHLLAccumulator::<i64>::new()),
394-
DataType::Utf8View => Box::new(StringViewHLLAccumulator::<i32>::new()),
377+
DataType::Utf8View => Box::new(StringViewHLLAccumulator::new()),
395378
DataType::Binary => Box::new(BinaryHLLAccumulator::<i32>::new()),
396379
DataType::LargeBinary => Box::new(BinaryHLLAccumulator::<i64>::new()),
397380
DataType::Null => {

0 commit comments

Comments
 (0)