Skip to content

Commit 6954497

Browse files
fix(accumulators): preserve state in evaluate() for window frame queries
This commit fixes issue #19612 where accumulators that don't implement retract_batch exhibit buggy behavior in window frame queries. ## Problem When aggregate functions are used with window frames like `ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`, DataFusion uses PlainAggregateWindowExpr which calls evaluate() multiple times on the same accumulator instance. Accumulators that use std::mem::take() in their evaluate() method consume their internal state, causing incorrect results on subsequent calls. ## Solution 1. **percentile_cont**: Modified evaluate() to use mutable reference instead of consuming the Vec. Added retract_batch() support for both PercentileContAccumulator and DistinctPercentileContAccumulator. 2. **string_agg**: Changed SimpleStringAggAccumulator::evaluate() to clone the accumulated string instead of taking it. ## Changes - datafusion/functions-aggregate/src/percentile_cont.rs: - Changed calculate_percentile() to take &mut [T::Native] instead of Vec<T::Native> - Updated PercentileContAccumulator::evaluate() to pass reference - Updated DistinctPercentileContAccumulator::evaluate() to clone values - Added retract_batch() implementation using HashMap for efficient removal - Updated PercentileContGroupsAccumulator::evaluate() for consistency - datafusion/functions-aggregate/src/string_agg.rs: - Changed evaluate() to use clone() instead of std::mem::take() - datafusion/sqllogictest/test_files/aggregate.slt: - Added test cases for percentile_cont with window frames - Added test comparing median() vs percentile_cont(0.5) behavior - Added test for string_agg cumulative window frame - docs/source/library-user-guide/functions/adding-udfs.md: - Added documentation about window-compatible accumulators - Explained evaluate() state preservation requirements - Documented retract_batch() implementation guidance Closes #19612
1 parent 8809dae commit 6954497

File tree

4 files changed

+280
-18
lines changed

4 files changed

+280
-18
lines changed

datafusion/functions-aggregate/src/percentile_cont.rs

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::collections::HashMap;
1819
use std::fmt::{Debug, Formatter};
1920
use std::mem::{size_of, size_of_val};
2021
use std::sync::Arc;
@@ -52,7 +53,7 @@ use datafusion_expr::{
5253
};
5354
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
5455
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
55-
use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
56+
use datafusion_functions_aggregate_common::utils::{GenericDistinctBuffer, Hashable};
5657
use datafusion_macros::user_doc;
5758

5859
use crate::utils::validate_percentile_expr;
@@ -533,14 +534,57 @@ impl<T: ArrowNumericType> Accumulator for PercentileContAccumulator<T> {
533534
}
534535

535536
fn evaluate(&mut self) -> Result<ScalarValue> {
536-
let d = std::mem::take(&mut self.all_values);
537-
let value = calculate_percentile::<T>(d, self.percentile);
537+
let value = calculate_percentile::<T>(&mut self.all_values, self.percentile);
538538
ScalarValue::new_primitive::<T>(value, &self.data_type)
539539
}
540540

541541
fn size(&self) -> usize {
542542
size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
543543
}
544+
545+
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
546+
// Cast to target type if needed (e.g., integer to Float64)
547+
let values = if values[0].data_type() != &self.data_type {
548+
arrow::compute::cast(&values[0], &self.data_type)?
549+
} else {
550+
Arc::clone(&values[0])
551+
};
552+
553+
let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
554+
for i in 0..values.len() {
555+
let v = ScalarValue::try_from_array(&values, i)?;
556+
if !v.is_null() {
557+
*to_remove.entry(v).or_default() += 1;
558+
}
559+
}
560+
561+
let mut i = 0;
562+
while i < self.all_values.len() {
563+
let k = ScalarValue::new_primitive::<T>(
564+
Some(self.all_values[i]),
565+
&self.data_type,
566+
)?;
567+
if let Some(count) = to_remove.get_mut(&k)
568+
&& *count > 0
569+
{
570+
self.all_values.swap_remove(i);
571+
*count -= 1;
572+
if *count == 0 {
573+
to_remove.remove(&k);
574+
if to_remove.is_empty() {
575+
break;
576+
}
577+
}
578+
} else {
579+
i += 1;
580+
}
581+
}
582+
Ok(())
583+
}
584+
585+
fn supports_retract_batch(&self) -> bool {
586+
true
587+
}
544588
}
545589

546590
/// The percentile_cont groups accumulator accumulates the raw input values
@@ -665,13 +709,13 @@ impl<T: ArrowNumericType + Send> GroupsAccumulator
665709

666710
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
667711
// Emit values
668-
let emit_group_values = emit_to.take_needed(&mut self.group_values);
712+
let mut emit_group_values = emit_to.take_needed(&mut self.group_values);
669713

670714
// Calculate percentile for each group
671715
let mut evaluate_result_builder =
672716
PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
673-
for values in emit_group_values {
674-
let value = calculate_percentile::<T>(values, self.percentile);
717+
for values in &mut emit_group_values {
718+
let value = calculate_percentile::<T>(values.as_mut_slice(), self.percentile);
675719
evaluate_result_builder.append_option(value);
676720
}
677721

@@ -768,17 +812,35 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctPercentileContAccumula
768812
}
769813

770814
fn evaluate(&mut self) -> Result<ScalarValue> {
771-
let d = std::mem::take(&mut self.distinct_values.values)
772-
.into_iter()
815+
let mut values: Vec<T::Native> = self
816+
.distinct_values
817+
.values
818+
.iter()
773819
.map(|v| v.0)
774-
.collect::<Vec<_>>();
775-
let value = calculate_percentile::<T>(d, self.percentile);
820+
.collect();
821+
let value = calculate_percentile::<T>(&mut values, self.percentile);
776822
ScalarValue::new_primitive::<T>(value, &self.data_type)
777823
}
778824

779825
fn size(&self) -> usize {
780826
size_of_val(self) + self.distinct_values.size()
781827
}
828+
829+
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
830+
if values.is_empty() {
831+
return Ok(());
832+
}
833+
834+
let arr = values[0].as_primitive::<T>();
835+
for value in arr.iter().flatten() {
836+
self.distinct_values.values.remove(&Hashable(value));
837+
}
838+
Ok(())
839+
}
840+
841+
fn supports_retract_batch(&self) -> bool {
842+
true
843+
}
782844
}
783845

784846
/// Calculate the percentile value for a given set of values.
@@ -788,8 +850,12 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctPercentileContAccumula
788850
/// For percentile p and n values:
789851
/// - If p * (n-1) is an integer, return the value at that position
790852
/// - Otherwise, interpolate between the two closest values
853+
///
854+
/// Note: This function takes a mutable slice and sorts it in place, but does not
855+
/// consume the data. This is important for window frame queries where evaluate()
856+
/// may be called multiple times on the same accumulator state.
791857
fn calculate_percentile<T: ArrowNumericType>(
792-
mut values: Vec<T::Native>,
858+
values: &mut [T::Native],
793859
percentile: f64,
794860
) -> Option<T::Native> {
795861
let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);

datafusion/functions-aggregate/src/string_agg.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -384,14 +384,11 @@ impl Accumulator for SimpleStringAggAccumulator {
384384
}
385385

386386
fn evaluate(&mut self) -> Result<ScalarValue> {
387-
let result = if self.has_value {
388-
ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
387+
if self.has_value {
388+
Ok(ScalarValue::LargeUtf8(Some(self.accumulated_string.clone())))
389389
} else {
390-
ScalarValue::LargeUtf8(None)
391-
};
392-
393-
self.has_value = false;
394-
Ok(result)
390+
Ok(ScalarValue::LargeUtf8(None))
391+
}
395392
}
396393

397394
fn size(&self) -> usize {

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8241,3 +8241,137 @@ NULL NULL NULL NULL
82418241

82428242
statement ok
82438243
drop table distinct_avg;
8244+
8245+
###########
8246+
# Issue #19612: Test that percentile_cont and median produce identical results
8247+
# in window frame queries with ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW.
8248+
# Previously percentile_cont consumed its internal state during evaluate(),
8249+
# causing incorrect results when called multiple times in window queries.
8250+
###########
8251+
8252+
# Test percentile_cont window frame behavior (fix for issue #19612)
8253+
statement ok
8254+
CREATE TABLE percentile_window_test (
8255+
timestamp INT,
8256+
tags VARCHAR,
8257+
value DOUBLE
8258+
);
8259+
8260+
statement ok
8261+
INSERT INTO percentile_window_test (timestamp, tags, value) VALUES
8262+
(1, 'tag1', 10.0),
8263+
(2, 'tag1', 20.0),
8264+
(3, 'tag1', 30.0),
8265+
(4, 'tag1', 40.0),
8266+
(5, 'tag1', 50.0),
8267+
(1, 'tag2', 60.0),
8268+
(2, 'tag2', 70.0),
8269+
(3, 'tag2', 80.0),
8270+
(4, 'tag2', 90.0),
8271+
(5, 'tag2', 100.0);
8272+
8273+
# Test that median and percentile_cont(0.5) produce the same results
8274+
# with ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW frame.
8275+
# Both functions should maintain state correctly across multiple evaluate() calls.
8276+
query ITRRR
8277+
SELECT
8278+
timestamp,
8279+
tags,
8280+
value,
8281+
median(value) OVER (
8282+
PARTITION BY tags
8283+
ORDER BY timestamp
8284+
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
8285+
) AS value_median,
8286+
percentile_cont(value, 0.5) OVER (
8287+
PARTITION BY tags
8288+
ORDER BY timestamp
8289+
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
8290+
) AS value_percentile_50
8291+
FROM percentile_window_test
8292+
ORDER BY tags, timestamp;
8293+
----
8294+
1 tag1 10 10 10
8295+
2 tag1 20 15 15
8296+
3 tag1 30 20 20
8297+
4 tag1 40 25 25
8298+
5 tag1 50 30 30
8299+
1 tag2 60 60 60
8300+
2 tag2 70 65 65
8301+
3 tag2 80 70 70
8302+
4 tag2 90 75 75
8303+
5 tag2 100 80 80
8304+
8305+
# Test percentile_cont with different percentile values
8306+
query ITRRR
8307+
SELECT
8308+
timestamp,
8309+
tags,
8310+
value,
8311+
percentile_cont(value, 0.25) OVER (
8312+
PARTITION BY tags
8313+
ORDER BY timestamp
8314+
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
8315+
) AS p25,
8316+
percentile_cont(value, 0.75) OVER (
8317+
PARTITION BY tags
8318+
ORDER BY timestamp
8319+
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
8320+
) AS p75
8321+
FROM percentile_window_test
8322+
ORDER BY tags, timestamp;
8323+
----
8324+
1 tag1 10 10 10
8325+
2 tag1 20 12.5 17.5
8326+
3 tag1 30 15 25
8327+
4 tag1 40 17.5 32.5
8328+
5 tag1 50 20 40
8329+
1 tag2 60 60 60
8330+
2 tag2 70 62.5 67.5
8331+
3 tag2 80 65 75
8332+
4 tag2 90 67.5 82.5
8333+
5 tag2 100 70 90
8334+
8335+
statement ok
8336+
DROP TABLE percentile_window_test;
8337+
8338+
# Test string_agg window frame behavior (fix for issue #19612)
8339+
statement ok
8340+
CREATE TABLE string_agg_window_test (
8341+
id INT,
8342+
grp VARCHAR,
8343+
val VARCHAR
8344+
);
8345+
8346+
statement ok
8347+
INSERT INTO string_agg_window_test (id, grp, val) VALUES
8348+
(1, 'A', 'a'),
8349+
(2, 'A', 'b'),
8350+
(3, 'A', 'c'),
8351+
(1, 'B', 'x'),
8352+
(2, 'B', 'y'),
8353+
(3, 'B', 'z');
8354+
8355+
# Test string_agg with ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
8356+
# The function should maintain state correctly across multiple evaluate() calls
8357+
query ITT
8358+
SELECT
8359+
id,
8360+
grp,
8361+
string_agg(val, ',') OVER (
8362+
PARTITION BY grp
8363+
ORDER BY id
8364+
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
8365+
) AS cumulative_string
8366+
FROM string_agg_window_test
8367+
ORDER BY grp, id;
8368+
----
8369+
1 A a
8370+
2 A a,b
8371+
3 A a,b,c
8372+
1 B x
8373+
2 B x,y
8374+
3 B x,y,z
8375+
8376+
statement ok
8377+
DROP TABLE string_agg_window_test;

docs/source/library-user-guide/functions/adding-udfs.md

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,6 +1350,71 @@ async fn main() -> Result<()> {
13501350
[`create_udaf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udaf.html
13511351
[`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udaf.rs
13521352

1353+
### Window Frame Compatible Accumulators
1354+
1355+
When an aggregate function is used in a window context with a sliding frame (e.g., `ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`),
1356+
DataFusion may call `evaluate()` multiple times on the same accumulator instance to compute results for each row in the window.
1357+
This has important implications for how you implement your accumulator:
1358+
1359+
#### The `evaluate()` Method Must Not Consume State
1360+
1361+
The `evaluate()` method should return the current aggregate value **without modifying or consuming the accumulator's internal state**.
1362+
This is critical because:
1363+
1364+
1. **Multiple evaluations**: For window queries, `evaluate()` is called once per row in the partition
1365+
2. **State preservation**: The internal state must remain intact for subsequent `evaluate()` calls
1366+
1367+
**Incorrect implementation** (consumes state):
1368+
1369+
```rust
1370+
fn evaluate(&mut self) -> Result<ScalarValue> {
1371+
// BAD: std::mem::take() consumes the values, leaving an empty Vec
1372+
let values = std::mem::take(&mut self.values);
1373+
// After this call, self.values is empty and subsequent
1374+
// evaluate() calls will return incorrect results
1375+
calculate_result(values)
1376+
}
1377+
```
1378+
1379+
**Correct implementation** (preserves state):
1380+
1381+
```rust
1382+
fn evaluate(&mut self) -> Result<ScalarValue> {
1383+
// GOOD: Use a reference or clone to preserve state
1384+
calculate_result(&mut self.values)
1385+
// Or: calculate_result(self.values.clone())
1386+
}
1387+
```
1388+
1389+
#### Implementing `retract_batch` for Sliding Windows
1390+
1391+
For more efficient sliding window calculations, you can implement the `retract_batch` method.
1392+
This allows DataFusion to remove values that have "left" the window frame instead of recalculating from scratch:
1393+
1394+
```rust
1395+
impl Accumulator for MyAccumulator {
1396+
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1397+
// Remove the given values from the accumulator state
1398+
// This is the inverse of update_batch
1399+
for value in values[0].iter().flatten() {
1400+
self.remove_value(value);
1401+
}
1402+
Ok(())
1403+
}
1404+
1405+
fn supports_retract_batch(&self) -> bool {
1406+
true // Enable this optimization
1407+
}
1408+
}
1409+
```
1410+
1411+
If your accumulator does not support `retract_batch` (returns `false` from `supports_retract_batch()`),
1412+
DataFusion will use `PlainAggregateWindowExpr` which calls `evaluate()` multiple times on the same
1413+
accumulator. In this case, it is **essential** that your `evaluate()` method does not consume the
1414+
accumulator's state.
1415+
1416+
See [issue #19612](https://github.com/apache/datafusion/issues/19612) for more details on this behavior.
1417+
13531418
## Adding a Table UDF
13541419

13551420
A User-Defined Table Function (UDTF) is a function that takes parameters and returns a `TableProvider`.

0 commit comments

Comments
 (0)