Skip to content

Commit 45fb0b4

Browse files
fix(accumulators): preserve state in evaluate() for window frame queries (#19618)
Part of #19612

Rationale for this change: When aggregate functions are used with window frames like `ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`, DataFusion uses `PlainAggregateWindowExpr`, which calls `evaluate()` multiple times on the same accumulator instance. Accumulators that use `std::mem::take()` in their `evaluate()` method consume their internal state, causing incorrect results on subsequent calls.

What changes are included in this PR?
- percentile_cont: Modified `evaluate()` to use a mutable reference instead of consuming the Vec. Added `retract_batch()` support.
- string_agg: Changed `SimpleStringAggAccumulator::evaluate()` to clone instead of take.
- Added comprehensive test cases in `aggregate.slt`.
- Added documentation about window-compatible accumulators.

Are these changes tested? Yes, added sqllogictest cases that verify:
- median() and percentile_cont(0.5) produce identical results in window frames
- percentile_cont with different percentiles works correctly
- string_agg accumulates correctly across window frame evaluations

Are there any user-facing changes? No breaking changes. This is a bug fix that ensures aggregate functions work correctly in window contexts.
1 parent afc9121 commit 45fb0b4

File tree

4 files changed

+226
-24
lines changed

4 files changed

+226
-24
lines changed

datafusion/expr-common/src/accumulator.rs

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,30 @@ pub trait Accumulator: Send + Sync + Debug {
5858
/// running sum.
5959
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
6060

61-
/// Returns the final aggregate value, consuming the internal state.
61+
/// Returns the final aggregate value.
6262
///
6363
/// For example, the `SUM` accumulator maintains a running sum,
6464
/// and `evaluate` will produce that running sum as its output.
6565
///
66-
/// This function should not be called twice, otherwise it will
67-
/// result in potentially non-deterministic behavior.
68-
///
6966
/// This function gets `&mut self` to allow for the accumulator to build
7067
/// arrow-compatible internal state that can be returned without copying
71-
/// when possible (for example distinct strings)
68+
/// when possible (for example distinct strings).
69+
///
70+
/// ## Correctness
71+
///
72+
/// This function must not consume the internal state, as it is also used in window
73+
/// aggregate functions where it can be executed multiple times depending on the
74+
/// current window frame. Consuming the internal state can cause the next invocation
75+
/// to have incorrect results.
76+
///
77+
/// - Even if this accumulator doesn't implement [`retract_batch`] it may still be used
78+
/// in window aggregate functions where the window frame is
79+
/// `ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`
80+
///
81+
/// It is fine to modify the state (e.g. re-order elements within internal state vec) so long
82+
/// as this doesn't cause an incorrect computation on the next call of evaluate.
83+
///
84+
/// [`retract_batch`]: Self::retract_batch
7285
fn evaluate(&mut self) -> Result<ScalarValue>;
7386

7487
/// Returns the allocated size required for this accumulator, in

datafusion/functions-aggregate/src/percentile_cont.rs

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::collections::HashMap;
1819
use std::fmt::Debug;
1920
use std::mem::{size_of, size_of_val};
2021
use std::sync::Arc;
@@ -52,7 +53,7 @@ use datafusion_expr::{
5253
};
5354
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
5455
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
55-
use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
56+
use datafusion_functions_aggregate_common::utils::{GenericDistinctBuffer, Hashable};
5657
use datafusion_macros::user_doc;
5758

5859
use crate::utils::validate_percentile_expr;
@@ -427,14 +428,48 @@ impl<T: ArrowNumericType + Debug> Accumulator for PercentileContAccumulator<T> {
427428
}
428429

429430
fn evaluate(&mut self) -> Result<ScalarValue> {
430-
let d = std::mem::take(&mut self.all_values);
431-
let value = calculate_percentile::<T>(d, self.percentile);
431+
let value = calculate_percentile::<T>(&mut self.all_values, self.percentile);
432432
ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
433433
}
434434

435435
fn size(&self) -> usize {
436436
size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
437437
}
438+
439+
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
440+
let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
441+
for i in 0..values[0].len() {
442+
let v = ScalarValue::try_from_array(&values[0], i)?;
443+
if !v.is_null() {
444+
*to_remove.entry(v).or_default() += 1;
445+
}
446+
}
447+
448+
let mut i = 0;
449+
while i < self.all_values.len() {
450+
let k =
451+
ScalarValue::new_primitive::<T>(Some(self.all_values[i]), &T::DATA_TYPE)?;
452+
if let Some(count) = to_remove.get_mut(&k)
453+
&& *count > 0
454+
{
455+
self.all_values.swap_remove(i);
456+
*count -= 1;
457+
if *count == 0 {
458+
to_remove.remove(&k);
459+
if to_remove.is_empty() {
460+
break;
461+
}
462+
}
463+
} else {
464+
i += 1;
465+
}
466+
}
467+
Ok(())
468+
}
469+
470+
fn supports_retract_batch(&self) -> bool {
471+
true
472+
}
438473
}
439474

440475
/// The percentile_cont groups accumulator accumulates the raw input values
@@ -549,13 +584,13 @@ impl<T: ArrowNumericType + Send> GroupsAccumulator
549584

550585
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
551586
// Emit values
552-
let emit_group_values = emit_to.take_needed(&mut self.group_values);
587+
let mut emit_group_values = emit_to.take_needed(&mut self.group_values);
553588

554589
// Calculate percentile for each group
555590
let mut evaluate_result_builder =
556591
PrimitiveBuilder::<T>::with_capacity(emit_group_values.len());
557-
for values in emit_group_values {
558-
let value = calculate_percentile::<T>(values, self.percentile);
592+
for values in &mut emit_group_values {
593+
let value = calculate_percentile::<T>(values.as_mut_slice(), self.percentile);
559594
evaluate_result_builder.append_option(value);
560595
}
561596

@@ -652,17 +687,31 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctPercentileContAccumula
652687
}
653688

654689
fn evaluate(&mut self) -> Result<ScalarValue> {
655-
let d = std::mem::take(&mut self.distinct_values.values)
656-
.into_iter()
657-
.map(|v| v.0)
658-
.collect::<Vec<_>>();
659-
let value = calculate_percentile::<T>(d, self.percentile);
690+
let mut values: Vec<T::Native> =
691+
self.distinct_values.values.iter().map(|v| v.0).collect();
692+
let value = calculate_percentile::<T>(&mut values, self.percentile);
660693
ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
661694
}
662695

663696
fn size(&self) -> usize {
664697
size_of_val(self) + self.distinct_values.size()
665698
}
699+
700+
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
701+
if values.is_empty() {
702+
return Ok(());
703+
}
704+
705+
let arr = values[0].as_primitive::<T>();
706+
for value in arr.iter().flatten() {
707+
self.distinct_values.values.remove(&Hashable(value));
708+
}
709+
Ok(())
710+
}
711+
712+
fn supports_retract_batch(&self) -> bool {
713+
true
714+
}
666715
}
667716

668717
/// Calculate the percentile value for a given set of values.
@@ -672,8 +721,12 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctPercentileContAccumula
672721
/// For percentile p and n values:
673722
/// - If p * (n-1) is an integer, return the value at that position
674723
/// - Otherwise, interpolate between the two closest values
724+
///
725+
/// Note: This function takes a mutable slice and sorts it in place, but does not
726+
/// consume the data. This is important for window frame queries where evaluate()
727+
/// may be called multiple times on the same accumulator state.
675728
fn calculate_percentile<T: ArrowNumericType>(
676-
mut values: Vec<T::Native>,
729+
values: &mut [T::Native],
677730
percentile: f64,
678731
) -> Option<T::Native> {
679732
let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);

datafusion/functions-aggregate/src/string_agg.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -384,14 +384,13 @@ impl Accumulator for SimpleStringAggAccumulator {
384384
}
385385

386386
fn evaluate(&mut self) -> Result<ScalarValue> {
387-
let result = if self.has_value {
388-
ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
387+
if self.has_value {
388+
Ok(ScalarValue::LargeUtf8(Some(
389+
self.accumulated_string.clone(),
390+
)))
389391
} else {
390-
ScalarValue::LargeUtf8(None)
391-
};
392-
393-
self.has_value = false;
394-
Ok(result)
392+
Ok(ScalarValue::LargeUtf8(None))
393+
}
395394
}
396395

397396
fn size(&self) -> usize {

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,6 +1130,102 @@ ORDER BY tags, timestamp;
11301130
4 tag2 90 75 80 95
11311131
5 tag2 100 80 80 100
11321132

1133+
###########
1134+
# Issue #19612: Test that percentile_cont produces correct results
1135+
# in window frame queries. Previously percentile_cont consumed its internal state
1136+
# during evaluate(), causing incorrect results when called multiple times.
1137+
###########
1138+
1139+
# Test percentile_cont sliding window (same as median)
1140+
query ITRR
1141+
SELECT
1142+
timestamp,
1143+
tags,
1144+
value,
1145+
percentile_cont(value, 0.5) OVER (
1146+
PARTITION BY tags
1147+
ORDER BY timestamp
1148+
ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
1149+
) AS value_percentile_50
1150+
FROM median_window_test
1151+
ORDER BY tags, timestamp;
1152+
----
1153+
1 tag1 10 15
1154+
2 tag1 20 20
1155+
3 tag1 30 30
1156+
4 tag1 40 40
1157+
5 tag1 50 45
1158+
1 tag2 60 65
1159+
2 tag2 70 70
1160+
3 tag2 80 80
1161+
4 tag2 90 90
1162+
5 tag2 100 95
1163+
1164+
# Test percentile_cont non-sliding window
1165+
query ITRRRR
1166+
SELECT
1167+
timestamp,
1168+
tags,
1169+
value,
1170+
percentile_cont(value, 0.5) OVER (
1171+
PARTITION BY tags
1172+
ORDER BY timestamp
1173+
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
1174+
) AS value_percentile_unbounded_preceding,
1175+
percentile_cont(value, 0.5) OVER (
1176+
PARTITION BY tags
1177+
ORDER BY timestamp
1178+
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
1179+
) AS value_percentile_unbounded_both,
1180+
percentile_cont(value, 0.5) OVER (
1181+
PARTITION BY tags
1182+
ORDER BY timestamp
1183+
ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING
1184+
) AS value_percentile_unbounded_following
1185+
FROM median_window_test
1186+
ORDER BY tags, timestamp;
1187+
----
1188+
1 tag1 10 10 30 30
1189+
2 tag1 20 15 30 35
1190+
3 tag1 30 20 30 40
1191+
4 tag1 40 25 30 45
1192+
5 tag1 50 30 30 50
1193+
1 tag2 60 60 80 80
1194+
2 tag2 70 65 80 85
1195+
3 tag2 80 70 80 90
1196+
4 tag2 90 75 80 95
1197+
5 tag2 100 80 80 100
1198+
1199+
# Test percentile_cont with different percentile values
1200+
query ITRRR
1201+
SELECT
1202+
timestamp,
1203+
tags,
1204+
value,
1205+
percentile_cont(value, 0.25) OVER (
1206+
PARTITION BY tags
1207+
ORDER BY timestamp
1208+
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
1209+
) AS p25,
1210+
percentile_cont(value, 0.75) OVER (
1211+
PARTITION BY tags
1212+
ORDER BY timestamp
1213+
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
1214+
) AS p75
1215+
FROM median_window_test
1216+
ORDER BY tags, timestamp;
1217+
----
1218+
1 tag1 10 10 10
1219+
2 tag1 20 12.5 17.5
1220+
3 tag1 30 15 25
1221+
4 tag1 40 17.5 32.5
1222+
5 tag1 50 20 40
1223+
1 tag2 60 60 60
1224+
2 tag2 70 62.5 67.5
1225+
3 tag2 80 65 75
1226+
4 tag2 90 67.5 82.5
1227+
5 tag2 100 70 90
1228+
11331229
statement ok
11341230
DROP TABLE median_window_test;
11351231

@@ -8250,3 +8346,44 @@ query R
82508346
select percentile_cont(null, 0.5);
82518347
----
82528348
NULL
8349+
8350+
# Test string_agg window frame behavior (fix for issue #19612)
8351+
statement ok
8352+
CREATE TABLE string_agg_window_test (
8353+
id INT,
8354+
grp VARCHAR,
8355+
val VARCHAR
8356+
);
8357+
8358+
statement ok
8359+
INSERT INTO string_agg_window_test (id, grp, val) VALUES
8360+
(1, 'A', 'a'),
8361+
(2, 'A', 'b'),
8362+
(3, 'A', 'c'),
8363+
(1, 'B', 'x'),
8364+
(2, 'B', 'y'),
8365+
(3, 'B', 'z');
8366+
8367+
# Test string_agg with ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
8368+
# The function should maintain state correctly across multiple evaluate() calls
8369+
query ITT
8370+
SELECT
8371+
id,
8372+
grp,
8373+
string_agg(val, ',') OVER (
8374+
PARTITION BY grp
8375+
ORDER BY id
8376+
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
8377+
) AS cumulative_string
8378+
FROM string_agg_window_test
8379+
ORDER BY grp, id;
8380+
----
8381+
1 A a
8382+
2 A a,b
8383+
3 A a,b,c
8384+
1 B x
8385+
2 B x,y
8386+
3 B x,y,z
8387+
8388+
statement ok
8389+
DROP TABLE string_agg_window_test;

0 commit comments

Comments
 (0)