Skip to content
Merged
23 changes: 18 additions & 5 deletions datafusion/expr-common/src/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,30 @@ pub trait Accumulator: Send + Sync + Debug {
/// running sum.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;

/// Returns the final aggregate value, consuming the internal state.
/// Returns the final aggregate value.
///
/// For example, the `SUM` accumulator maintains a running sum,
/// and `evaluate` will produce that running sum as its output.
///
/// This function should not be called twice, otherwise it will
/// result in potentially non-deterministic behavior.
///
/// This function gets `&mut self` to allow for the accumulator to build
/// arrow-compatible internal state that can be returned without copying
/// when possible (for example distinct strings)
/// when possible (for example distinct strings).
///
/// ## Correctness
///
/// This function must not consume the internal state, as it is also used in window
/// aggregate functions where it can be executed multiple times depending on the
/// current window frame. Consuming the internal state can cause the next invocation
/// to have incorrect results.
///
/// - Even if this accumulator doesn't implement [`retract_batch`], it may still be used
/// in window aggregate functions where the window frame is
/// `ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`.
///
/// It is fine to modify the state (e.g. re-order elements within internal state vec) so long
/// as this doesn't cause an incorrect computation on the next call of evaluate.
///
/// [`retract_batch`]: Self::retract_batch
fn evaluate(&mut self) -> Result<ScalarValue>;

/// Returns the allocated size required for this accumulator, in
Expand Down
77 changes: 65 additions & 12 deletions datafusion/functions-aggregate/src/percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::collections::HashMap;
use std::fmt::Debug;
use std::mem::{size_of, size_of_val};
use std::sync::Arc;
Expand Down Expand Up @@ -52,7 +53,7 @@ use datafusion_expr::{
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
use datafusion_functions_aggregate_common::utils::{GenericDistinctBuffer, Hashable};
use datafusion_macros::user_doc;

use crate::utils::validate_percentile_expr;
Expand Down Expand Up @@ -427,14 +428,48 @@ impl<T: ArrowNumericType + Debug> Accumulator for PercentileContAccumulator<T> {
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let d = std::mem::take(&mut self.all_values);
let value = calculate_percentile::<T>(d, self.percentile);
let value = calculate_percentile::<T>(&mut self.all_values, self.percentile);
ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
}

/// Estimates this accumulator's memory footprint in bytes.
///
/// The estimate is the size of the accumulator value itself plus the buffer
/// reserved by `all_values`.
fn size(&self) -> usize {
    // Use capacity rather than length so reserved-but-unused slots are counted too.
    let buffered_bytes = size_of::<T::Native>() * self.all_values.capacity();
    size_of_val(self) + buffered_bytes
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
for i in 0..values[0].len() {
let v = ScalarValue::try_from_array(&values[0], i)?;
if !v.is_null() {
*to_remove.entry(v).or_default() += 1;
}
}

let mut i = 0;
while i < self.all_values.len() {
let k =
ScalarValue::new_primitive::<T>(Some(self.all_values[i]), &T::DATA_TYPE)?;
if let Some(count) = to_remove.get_mut(&k)
&& *count > 0
{
self.all_values.swap_remove(i);
*count -= 1;
if *count == 0 {
to_remove.remove(&k);
if to_remove.is_empty() {
break;
}
}
} else {
i += 1;
}
}
Ok(())
}

/// Always returns `true`: this accumulator provides a `retract_batch`
/// implementation.
fn supports_retract_batch(&self) -> bool {
true
}
}

/// The percentile_cont groups accumulator accumulates the raw input values
Expand Down Expand Up @@ -549,13 +584,13 @@ impl<T: ArrowNumericType + Send> GroupsAccumulator

fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
// Emit values
let emit_group_values = emit_to.take_needed(&mut self.group_values);
let mut emit_group_values = emit_to.take_needed(&mut self.group_values);

// Calculate percentile for each group
let mut evaluate_result_builder =
PrimitiveBuilder::<T>::with_capacity(emit_group_values.len());
for values in emit_group_values {
let value = calculate_percentile::<T>(values, self.percentile);
for values in &mut emit_group_values {
let value = calculate_percentile::<T>(values.as_mut_slice(), self.percentile);
evaluate_result_builder.append_option(value);
}

Expand Down Expand Up @@ -652,17 +687,31 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctPercentileContAccumula
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let d = std::mem::take(&mut self.distinct_values.values)
.into_iter()
.map(|v| v.0)
.collect::<Vec<_>>();
let value = calculate_percentile::<T>(d, self.percentile);
let mut values: Vec<T::Native> =
self.distinct_values.values.iter().map(|v| v.0).collect();
let value = calculate_percentile::<T>(&mut values, self.percentile);
ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
}

/// Estimates this accumulator's memory footprint in bytes.
///
/// The estimate is the size of the accumulator value itself plus whatever
/// `distinct_values.size()` reports for the distinct-value buffer.
fn size(&self) -> usize {
    let own_bytes = size_of_val(self);
    let distinct_bytes = self.distinct_values.size();
    own_bytes + distinct_bytes
}

/// Removes every non-null value of `values[0]` from the distinct-value buffer.
///
/// Returns `Ok(())` without touching the state when `values` is empty. Null
/// entries in the input are skipped.
///
/// NOTE(review): `distinct_values` appears to keep one entry per distinct value
/// with no occurrence count, so retracting a value that still occurs elsewhere
/// in the current window frame would drop it entirely — confirm whether
/// duplicate inputs can reach this path.
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
    let Some(column) = values.first() else {
        return Ok(());
    };

    column.as_primitive::<T>().iter().flatten().for_each(|v| {
        // The returned "was present" flag is intentionally ignored.
        self.distinct_values.values.remove(&Hashable(v));
    });
    Ok(())
}

/// Always returns `true`: this accumulator provides a `retract_batch`
/// implementation.
fn supports_retract_batch(&self) -> bool {
true
}
}

/// Calculate the percentile value for a given set of values.
Expand All @@ -672,8 +721,12 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctPercentileContAccumula
/// For percentile p and n values:
/// - If p * (n-1) is an integer, return the value at that position
/// - Otherwise, interpolate between the two closest values
///
/// Note: This function takes a mutable slice and sorts it in place, but does not
/// consume the data. This is important for window frame queries where evaluate()
/// may be called multiple times on the same accumulator state.
fn calculate_percentile<T: ArrowNumericType>(
mut values: Vec<T::Native>,
values: &mut [T::Native],
percentile: f64,
) -> Option<T::Native> {
let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
Expand Down
13 changes: 6 additions & 7 deletions datafusion/functions-aggregate/src/string_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,14 +384,13 @@ impl Accumulator for SimpleStringAggAccumulator {
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let result = if self.has_value {
ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
if self.has_value {
Ok(ScalarValue::LargeUtf8(Some(
self.accumulated_string.clone(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is unavoidable 🙁

I might need to think on this a bit to see if there are ways around requiring this clone 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this approach is fine for now

)))
} else {
ScalarValue::LargeUtf8(None)
};

self.has_value = false;
Ok(result)
Ok(ScalarValue::LargeUtf8(None))
}
}

fn size(&self) -> usize {
Expand Down
137 changes: 137 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,102 @@ ORDER BY tags, timestamp;
4 tag2 90 75 80 95
5 tag2 100 80 80 100

###########
# Issue #19612: Test that percentile_cont produces correct results
# in window frame queries. Previously percentile_cont consumed its internal state
# during evaluate(), causing incorrect results when called multiple times.
###########

# Test percentile_cont sliding window (same as median)
query ITRR
SELECT
timestamp,
tags,
value,
percentile_cont(value, 0.5) OVER (
PARTITION BY tags
ORDER BY timestamp
ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
) AS value_percentile_50
FROM median_window_test
ORDER BY tags, timestamp;
----
1 tag1 10 15
2 tag1 20 20
3 tag1 30 30
4 tag1 40 40
5 tag1 50 45
1 tag2 60 65
2 tag2 70 70
3 tag2 80 80
4 tag2 90 90
5 tag2 100 95

# Test percentile_cont non-sliding window
query ITRRRR
SELECT
timestamp,
tags,
value,
percentile_cont(value, 0.5) OVER (
PARTITION BY tags
ORDER BY timestamp
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
) AS value_percentile_unbounded_preceding,
percentile_cont(value, 0.5) OVER (
PARTITION BY tags
ORDER BY timestamp
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS value_percentile_unbounded_both,
percentile_cont(value, 0.5) OVER (
PARTITION BY tags
ORDER BY timestamp
ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING
) AS value_percentile_unbounded_following
FROM median_window_test
ORDER BY tags, timestamp;
----
1 tag1 10 10 30 30
2 tag1 20 15 30 35
3 tag1 30 20 30 40
4 tag1 40 25 30 45
5 tag1 50 30 30 50
1 tag2 60 60 80 80
2 tag2 70 65 80 85
3 tag2 80 70 80 90
4 tag2 90 75 80 95
5 tag2 100 80 80 100

# Test percentile_cont with different percentile values
query ITRRR
SELECT
timestamp,
tags,
value,
percentile_cont(value, 0.25) OVER (
PARTITION BY tags
ORDER BY timestamp
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
) AS p25,
percentile_cont(value, 0.75) OVER (
PARTITION BY tags
ORDER BY timestamp
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
) AS p75
FROM median_window_test
ORDER BY tags, timestamp;
----
1 tag1 10 10 10
2 tag1 20 12.5 17.5
3 tag1 30 15 25
4 tag1 40 17.5 32.5
5 tag1 50 20 40
1 tag2 60 60 60
2 tag2 70 62.5 67.5
3 tag2 80 65 75
4 tag2 90 67.5 82.5
5 tag2 100 70 90

statement ok
DROP TABLE median_window_test;

Expand Down Expand Up @@ -8250,3 +8346,44 @@ query R
select percentile_cont(null, 0.5);
----
NULL

# Test string_agg window frame behavior (fix for issue #19612)
statement ok
CREATE TABLE string_agg_window_test (
id INT,
grp VARCHAR,
val VARCHAR
);

statement ok
INSERT INTO string_agg_window_test (id, grp, val) VALUES
(1, 'A', 'a'),
(2, 'A', 'b'),
(3, 'A', 'c'),
(1, 'B', 'x'),
(2, 'B', 'y'),
(3, 'B', 'z');

# Test string_agg with ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
# The function should maintain state correctly across multiple evaluate() calls
query ITT
SELECT
id,
grp,
string_agg(val, ',') OVER (
PARTITION BY grp
ORDER BY id
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
) AS cumulative_string
FROM string_agg_window_test
ORDER BY grp, id;
----
1 A a
2 A a,b
3 A a,b,c
1 B x
2 B x,y
3 B x,y,z

statement ok
DROP TABLE string_agg_window_test;