Skip to content

Commit 9ebbf95

Browse files
committed
Incremental group emission in HashAggregate
1 parent a51e3a0 commit 9ebbf95

File tree

19 files changed

+573
-154
lines changed

19 files changed

+573
-154
lines changed
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Benchmark comparing incremental emission vs all-at-once emission
19+
//! for hash aggregation.
20+
//!
21+
//! This benchmark measures the time-to-first-row improvement from
22+
//! the incremental drain optimization.
23+
//!
24+
//! Usage:
25+
//! cargo run --release --bin incremental_emit_bench -- --groups 1000000
26+
27+
use arrow::array::{Int64Array, StringArray};
28+
use arrow::datatypes::{DataType, Field, Schema};
29+
use arrow::record_batch::RecordBatch;
30+
use datafusion::common::Result;
31+
use datafusion::datasource::MemTable;
32+
use datafusion::prelude::*;
33+
use datafusion_common::instant::Instant;
34+
use futures::StreamExt;
35+
use std::sync::Arc;
36+
use std::time::Duration;
37+
38+
/// Measurements collected from one execution of the benchmark query.
#[derive(Clone, Debug)]
struct BenchmarkResult {
    /// Elapsed time from just before the stream is created until the first
    /// poll of the output stream returns.
    time_to_first_row: Duration,
    /// Elapsed time from the same starting point until the stream is exhausted.
    total_time: Duration,
    /// Number of record batches read from the output stream.
    num_output_batches: usize,
    /// Sum of the row counts of all batches read from the output stream.
    total_output_rows: usize,
}
45+
46+
async fn run_aggregation_benchmark(
47+
batch: RecordBatch,
48+
batch_size: usize,
49+
) -> Result<BenchmarkResult> {
50+
let config = SessionConfig::new().with_batch_size(batch_size);
51+
let ctx = SessionContext::new_with_config(config);
52+
53+
let schema = batch.schema();
54+
let table = MemTable::try_new(schema, vec![vec![batch]])?;
55+
ctx.register_table("bench_data", Arc::new(table))?;
56+
57+
let sql = "SELECT group_key, SUM(value) as total, COUNT(*) as cnt \
58+
FROM bench_data \
59+
GROUP BY group_key";
60+
61+
let df = ctx.sql(sql).await?;
62+
63+
let start = Instant::now();
64+
let mut stream = df.execute_stream().await?;
65+
66+
// Measure time to first batch
67+
let first_batch = stream.next().await;
68+
let time_to_first_row = start.elapsed();
69+
70+
let mut num_output_batches = 0;
71+
let mut total_output_rows = 0;
72+
73+
if let Some(Ok(batch)) = first_batch {
74+
num_output_batches += 1;
75+
total_output_rows += batch.num_rows();
76+
}
77+
78+
// Consume remaining batches
79+
while let Some(result) = stream.next().await {
80+
if let Ok(batch) = result {
81+
num_output_batches += 1;
82+
total_output_rows += batch.num_rows();
83+
}
84+
}
85+
86+
let total_time = start.elapsed();
87+
88+
Ok(BenchmarkResult {
89+
time_to_first_row,
90+
total_time,
91+
num_output_batches,
92+
total_output_rows,
93+
})
94+
}
95+
96+
fn create_test_data(num_groups: usize, rows_per_group: usize) -> Result<RecordBatch> {
97+
let total_rows = num_groups * rows_per_group;
98+
99+
// Create group keys: "group_0", "group_1", ..., "group_{num_groups-1}"
100+
// Each group appears rows_per_group times
101+
let group_keys: Vec<String> = (0..total_rows)
102+
.map(|i| format!("group_{}", i % num_groups))
103+
.collect();
104+
105+
// Create values: just use the row index
106+
let values: Vec<i64> = (0..total_rows as i64).collect();
107+
108+
let schema = Arc::new(Schema::new(vec![
109+
Field::new("group_key", DataType::Utf8, false),
110+
Field::new("value", DataType::Int64, false),
111+
]));
112+
113+
let batch = RecordBatch::try_new(
114+
schema,
115+
vec![
116+
Arc::new(StringArray::from(group_keys)),
117+
Arc::new(Int64Array::from(values)),
118+
],
119+
)?;
120+
121+
Ok(batch)
122+
}
123+
124+
#[tokio::main]
125+
async fn main() -> Result<()> {
126+
let args: Vec<String> = std::env::args().collect();
127+
128+
let num_groups = args
129+
.iter()
130+
.position(|s| s == "--groups")
131+
.and_then(|i| args.get(i + 1))
132+
.and_then(|s| s.parse().ok())
133+
.unwrap_or(100_000);
134+
135+
let rows_per_group = args
136+
.iter()
137+
.position(|s| s == "--rows-per-group")
138+
.and_then(|i| args.get(i + 1))
139+
.and_then(|s| s.parse().ok())
140+
.unwrap_or(10);
141+
142+
let iterations = args
143+
.iter()
144+
.position(|s| s == "--iterations")
145+
.and_then(|i| args.get(i + 1))
146+
.and_then(|s| s.parse().ok())
147+
.unwrap_or(3);
148+
149+
println!("=== Incremental Emit Benchmark ===");
150+
println!("Number of groups: {num_groups}");
151+
println!("Rows per group: {rows_per_group}");
152+
println!("Total input rows: {}", num_groups * rows_per_group);
153+
println!("Iterations: {iterations}");
154+
println!();
155+
156+
// Batch sizes to test
157+
// Note: We use num_groups as the "all-at-once" batch size to simulate the EmitTo::All behavior
158+
// of emitting all groups in a single batch.
159+
let batch_sizes = vec![
160+
(8192, "8192 (incremental)".to_string()),
161+
(32768, "32768 (larger batches)".to_string()),
162+
(
163+
num_groups,
164+
format!("{num_groups} (all-at-once, simulates EmitTo::All behavior)"),
165+
),
166+
];
167+
168+
println!("Running benchmarks...");
169+
println!();
170+
171+
for (batch_size, label) in batch_sizes {
172+
println!("--- Batch size: {label} ---");
173+
174+
let mut first_row_times = Vec::new();
175+
let mut total_times = Vec::new();
176+
177+
for i in 0..iterations {
178+
// Create fresh test data for each run
179+
let batch = create_test_data(num_groups, rows_per_group)?;
180+
let result = run_aggregation_benchmark(batch, batch_size).await?;
181+
182+
println!(
183+
" Iteration {}: first_row={:?}, total={:?}, batches={}, rows={}",
184+
i + 1,
185+
result.time_to_first_row,
186+
result.total_time,
187+
result.num_output_batches,
188+
result.total_output_rows
189+
);
190+
191+
first_row_times.push(result.time_to_first_row);
192+
total_times.push(result.total_time);
193+
}
194+
195+
let avg_first_row: Duration =
196+
first_row_times.iter().sum::<Duration>() / iterations as u32;
197+
let avg_total: Duration =
198+
total_times.iter().sum::<Duration>() / iterations as u32;
199+
200+
println!(" Average: first_row={avg_first_row:?}, total={avg_total:?}");
201+
println!();
202+
}
203+
204+
println!("=== Summary ===");
205+
println!("The 'time to first row' metric shows how quickly the first output");
206+
println!("batch is produced. With incremental emission (smaller batch sizes),");
207+
println!("this should be significantly faster than all-at-once emission.");
208+
209+
Ok(())
210+
}
211+
212+
/* Example output:
213+
214+
cargo run --bin incremental_emit -- --groups 1000000 --rows-per-group 5 --iterations 3
215+
216+
=== Incremental Emit Benchmark ===
217+
Number of groups: 1000000
218+
Rows per group: 5
219+
Total input rows: 5000000
220+
Iterations: 3
221+
222+
Running benchmarks...
223+
224+
--- Batch size: 8192 (incremental) ---
225+
Iteration 1: first_row=514.312458ms, total=750.121625ms, batches=128, rows=1000000
226+
Iteration 2: first_row=487.098583ms, total=680.311958ms, batches=128, rows=1000000
227+
Iteration 3: first_row=473.02925ms, total=668.469083ms, batches=128, rows=1000000
228+
Average: first_row=491.480097ms, total=699.634222ms
229+
230+
--- Batch size: 32768 (larger batches) ---
231+
Iteration 1: first_row=481.137417ms, total=497.485917ms, batches=32, rows=1000000
232+
Iteration 2: first_row=478.821ms, total=496.062959ms, batches=32, rows=1000000
233+
Iteration 3: first_row=524.281709ms, total=539.426584ms, batches=32, rows=1000000
234+
Average: first_row=494.746708ms, total=510.99182ms
235+
236+
--- Batch size: 1000000 (all-at-once, simulates EmitTo::All behavior) ---
237+
Iteration 1: first_row=1.22554625s, total=1.303929583s, batches=16, rows=1000000
238+
Iteration 2: first_row=1.237296333s, total=1.2897605s, batches=16, rows=1000000
239+
Iteration 3: first_row=1.235812417s, total=1.303563667s, batches=16, rows=1000000
240+
Average: first_row=1.232885s, total=1.299084583s
241+
242+
=== Summary ===
243+
The 'time to first row' metric shows how quickly the first output
244+
batch is produced. With incremental emission (smaller batch sizes),
245+
this should be significantly faster than all-at-once emission.
246+
*/

datafusion/expr-common/src/groups_accumulator.rs

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,35 +23,46 @@ use datafusion_common::{Result, not_impl_err};
2323
/// Describes how many rows should be emitted during grouping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmitTo {
    /// Emit only the first `n` groups and shift all existing group
    /// indexes down by `n`.
    ///
    /// For example, if `n=10`, group_index `0, 1, ... 9` are emitted
    /// and group indexes `10, 11, 12, ...` become `0, 1, 2, ...`.
    First(usize),

    /// Emit the next `n` groups without shifting indices.
    ///
    /// Used during final drain after all input is processed. More efficient than
    /// `First` for draining because it doesn't update hash table indices.
    /// Returns empty arrays when all groups have been emitted.
    Next(usize),
}

impl EmitTo {
    /// Removes the leading elements of `v` that this emit operation asks
    /// for and returns them; the elements that were not taken stay in `v`
    /// in their original order.
    ///
    /// When the requested count is at least `v.len()`, every element is
    /// taken and `v` is left empty. This is what lets a final `Next(n)`
    /// drain fewer than `n` remaining elements.
    pub fn take_needed<T>(&self, v: &mut Vec<T>) -> Vec<T> {
        let requested = self.batch_size();
        if requested >= v.len() {
            return std::mem::take(v);
        }

        // Move the tail into its own vector, then swap it into `v` so the
        // caller receives the head.
        let tail = v.split_off(requested);
        std::mem::replace(v, tail)
    }

    /// Returns the batch size for this emit operation, i.e. the `n` carried
    /// by either variant.
    pub fn batch_size(&self) -> usize {
        match *self {
            Self::First(count) | Self::Next(count) => count,
        }
    }
}
@@ -145,16 +156,17 @@ pub trait GroupsAccumulator: Send {
145156
/// each group, and `evaluate` will produce that running sum as
146157
/// its output for all groups, in group_index order
147158
///
148-
/// If `emit_to` is [`EmitTo::All`], the accumulator should
149-
/// return all groups and release / reset its internal state
150-
/// equivalent to when it was first created.
151-
///
152159
/// If `emit_to` is [`EmitTo::First`], only the first `n` groups
153160
/// should be emitted and the state for those first groups
154161
/// removed. State for the remaining groups must be retained for
155162
/// future use. The group_indices on subsequent calls to
156163
/// `update_batch` or `merge_batch` will be shifted down by
157164
/// `n`. See [`EmitTo::First`] for more details.
165+
///
166+
/// If `emit_to` is [`EmitTo::Next`], the first `n` groups should
167+
/// be emitted and removed. Unlike `First`, subsequent group indices
168+
/// are not shifted (hash table updates are skipped). Used for final
169+
/// drain when no more lookups are needed.
158170
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef>;
159171

160172
/// Returns the intermediate aggregate state for this accumulator,

datafusion/ffi/src/udaf/groups_accumulator.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -438,24 +438,24 @@ impl GroupsAccumulator for ForeignGroupsAccumulator {
438438
#[repr(C)]
439439
#[derive(Debug, StableAbi)]
440440
pub enum FFI_EmitTo {
441-
All,
442441
First(usize),
442+
Next(usize),
443443
}
444444

445445
impl From<EmitTo> for FFI_EmitTo {
446446
fn from(value: EmitTo) -> Self {
447447
match value {
448-
EmitTo::All => Self::All,
449448
EmitTo::First(v) => Self::First(v),
449+
EmitTo::Next(v) => Self::Next(v),
450450
}
451451
}
452452
}
453453

454454
impl From<FFI_EmitTo> for EmitTo {
455455
fn from(value: FFI_EmitTo) -> Self {
456456
match value {
457-
FFI_EmitTo::All => Self::All,
458457
FFI_EmitTo::First(v) => Self::First(v),
458+
FFI_EmitTo::Next(v) => Self::Next(v),
459459
}
460460
}
461461
}
@@ -491,15 +491,15 @@ mod tests {
491491
3,
492492
)?;
493493

494-
let groups_bool = foreign_accum.evaluate(EmitTo::All)?;
494+
let groups_bool = foreign_accum.evaluate(EmitTo::Next(usize::MAX))?;
495495
let groups_bool = groups_bool.as_any().downcast_ref::<BooleanArray>().unwrap();
496496

497497
assert_eq!(
498498
groups_bool,
499499
create_array!(Boolean, vec![Some(true), Some(false), None]).as_ref()
500500
);
501501

502-
let state = foreign_accum.state(EmitTo::All)?;
502+
let state = foreign_accum.state(EmitTo::Next(usize::MAX))?;
503503
assert_eq!(state.len(), 1);
504504

505505
// To verify merging batches works, create a second state to add in
@@ -509,7 +509,7 @@ mod tests {
509509

510510
let opt_filter = create_array!(Boolean, vec![true]);
511511
foreign_accum.merge_batch(&second_states, &[0], Some(opt_filter.as_ref()), 1)?;
512-
let groups_bool = foreign_accum.evaluate(EmitTo::All)?;
512+
let groups_bool = foreign_accum.evaluate(EmitTo::Next(usize::MAX))?;
513513
assert_eq!(groups_bool.len(), 1);
514514
assert_eq!(
515515
groups_bool.as_ref(),
@@ -540,7 +540,7 @@ mod tests {
540540
/// This test ensures all enum values are properly translated
541541
#[test]
542542
fn test_all_emit_to_round_trip() -> Result<()> {
543-
test_emit_to_round_trip(EmitTo::All)?;
543+
test_emit_to_round_trip(EmitTo::Next(usize::MAX))?;
544544
test_emit_to_round_trip(EmitTo::First(10))?;
545545

546546
Ok(())

0 commit comments

Comments
 (0)