diff --git a/datafusion/expr-common/src/accumulator.rs b/datafusion/expr-common/src/accumulator.rs
index fc4e90114beea..3acf110a0bfc7 100644
--- a/datafusion/expr-common/src/accumulator.rs
+++ b/datafusion/expr-common/src/accumulator.rs
@@ -58,17 +58,30 @@ pub trait Accumulator: Send + Sync + Debug {
     /// running sum.
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
 
-    /// Returns the final aggregate value, consuming the internal state.
+    /// Returns the final aggregate value.
     ///
     /// For example, the `SUM` accumulator maintains a running sum,
     /// and `evaluate` will produce that running sum as its output.
     ///
-    /// This function should not be called twice, otherwise it will
-    /// result in potentially non-deterministic behavior.
-    ///
     /// This function gets `&mut self` to allow for the accumulator to build
     /// arrow-compatible internal state that can be returned without copying
-    /// when possible (for example distinct strings)
+    /// when possible (for example distinct strings).
+    ///
+    /// ## Correctness
+    ///
+    /// This function must not consume the internal state, as it is also used in window
+    /// aggregate functions where it can be executed multiple times depending on the
+    /// current window frame. Consuming the internal state can cause the next invocation
+    /// to have incorrect results.
+    ///
+    /// - Even if this accumulator doesn't implement [`retract_batch`] it may still be used
+    ///   in window aggregate functions where the window frame is
+    ///   `ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`
+    ///
+    /// It is fine to modify the state (e.g. re-order elements within internal state vec) so long
+    /// as this doesn't cause an incorrect computation on the next call of evaluate.
+    ///
+    /// [`retract_batch`]: Self::retract_batch
     fn evaluate(&mut self) -> Result<ScalarValue>;
 
     /// Returns the allocated size required for this accumulator, in
diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs
--- a/datafusion/functions-aggregate/src/percentile_cont.rs
+++ b/datafusion/functions-aggregate/src/percentile_cont.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::collections::HashMap;
 use std::fmt::Debug;
 use std::mem::{size_of, size_of_val};
 use std::sync::Arc;
@@ -427,14 +428,51 @@ impl<T: ArrowNumericType + Debug> Accumulator for PercentileContAccumulator<T> {
     }
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
-        let d = std::mem::take(&mut self.all_values);
-        let value = calculate_percentile::<T>(d, self.percentile);
+        let value = calculate_percentile::<T>(&mut self.all_values, self.percentile);
         ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
     }
 
     fn size(&self) -> usize {
         size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
     }
+
+    /// Removes one stored occurrence of each non-null value in `values[0]`, so
+    /// sliding window frames can drop the rows that have left the frame.
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
+        for i in 0..values[0].len() {
+            let v = ScalarValue::try_from_array(&values[0], i)?;
+            if !v.is_null() {
+                *to_remove.entry(v).or_default() += 1;
+            }
+        }
+
+        let mut i = 0;
+        while i < self.all_values.len() {
+            let k =
+                ScalarValue::new_primitive::<T>(Some(self.all_values[i]), &T::DATA_TYPE)?;
+            if let Some(count) = to_remove.get_mut(&k)
+                && *count > 0
+            {
+                // Element order does not matter here, so the O(1) `swap_remove` is used.
+                self.all_values.swap_remove(i);
+                *count -= 1;
+                if *count == 0 {
+                    to_remove.remove(&k);
+                    if to_remove.is_empty() {
+                        break;
+                    }
+                }
+            } else {
+                i += 1;
+            }
+        }
+        Ok(())
+    }
+
+    fn supports_retract_batch(&self) -> bool {
+        true
+    }
 }
 
 /// The percentile_cont groups accumulator accumulates the raw input values
@@ -549,13 +587,13 @@ impl<T: ArrowNumericType + Send> GroupsAccumulator
 
     fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
         // Emit values
-        let emit_group_values = emit_to.take_needed(&mut self.group_values);
+        let mut emit_group_values = emit_to.take_needed(&mut self.group_values);
 
         // Calculate percentile for each group
         let mut evaluate_result_builder =
             PrimitiveBuilder::<T>::with_capacity(emit_group_values.len());
-        for values in emit_group_values {
-            let value = calculate_percentile::<T>(values, self.percentile);
+        for values in &mut emit_group_values {
+            let value = calculate_percentile::<T>(values.as_mut_slice(), self.percentile);
             evaluate_result_builder.append_option(value);
         }
 
@@ -652,17 +690,20 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctPercentileContAccumula
     }
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
-        let d = std::mem::take(&mut self.distinct_values.values)
-            .into_iter()
-            .map(|v| v.0)
-            .collect::<Vec<_>>();
-        let value = calculate_percentile::<T>(d, self.percentile);
+        let mut values: Vec<T::Native> =
+            self.distinct_values.values.iter().map(|v| v.0).collect();
+        let value = calculate_percentile::<T>(&mut values, self.percentile);
         ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
     }
 
     fn size(&self) -> usize {
         size_of_val(self) + self.distinct_values.size()
     }
+
+    // `retract_batch` is intentionally not implemented: the distinct buffer keeps
+    // each value once with no occurrence count, so retracting a value when one row
+    // leaves the frame would also drop it for duplicate rows still inside the frame.
+    // Supporting sliding frames here requires per-value reference counts.
 }
 
 /// Calculate the percentile value for a given set of values.
@@ -672,8 +713,12 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctPercentileContAccumula
 /// For percentile p and n values:
 /// - If p * (n-1) is an integer, return the value at that position
 /// - Otherwise, interpolate between the two closest values
+///
+/// Note: This function takes a mutable slice and may reorder it in place, but does
+/// not consume the data. This is important for window frame queries where `evaluate`
+/// may be called multiple times on the same accumulator state.
 fn calculate_percentile<T: ArrowNumericType>(
-    mut values: Vec<T::Native>,
+    values: &mut [T::Native],
     percentile: f64,
 ) -> Option<T::Native> {
     let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs
index 77e9f60afd3cf..1c10818c091db 100644
--- a/datafusion/functions-aggregate/src/string_agg.rs
+++ b/datafusion/functions-aggregate/src/string_agg.rs
@@ -384,14 +384,13 @@ impl Accumulator for SimpleStringAggAccumulator {
     }
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
-        let result = if self.has_value {
-            ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
+        if self.has_value {
+            Ok(ScalarValue::LargeUtf8(Some(
+                self.accumulated_string.clone(),
+            )))
         } else {
-            ScalarValue::LargeUtf8(None)
-        };
-
-        self.has_value = false;
-        Ok(result)
+            Ok(ScalarValue::LargeUtf8(None))
+        }
     }
 
     fn size(&self) -> usize {
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index df980ab863362..3c962a0f87f36 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -1130,6 +1130,102 @@ ORDER BY tags, timestamp;
 4 tag2 90 75 80 95
 5 tag2 100 80 80 100
 
+###########
+# Issue #19612: Test that percentile_cont produces correct results
+# in window frame queries. Previously percentile_cont consumed its internal state
+# during evaluate(), causing incorrect results when called multiple times.
+###########
+
+# Test percentile_cont sliding window (same as median)
+query ITRR
+SELECT
+    timestamp,
+    tags,
+    value,
+    percentile_cont(value, 0.5) OVER (
+        PARTITION BY tags
+        ORDER BY timestamp
+        ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
+    ) AS value_percentile_50
+FROM median_window_test
+ORDER BY tags, timestamp;
+----
+1 tag1 10 15
+2 tag1 20 20
+3 tag1 30 30
+4 tag1 40 40
+5 tag1 50 45
+1 tag2 60 65
+2 tag2 70 70
+3 tag2 80 80
+4 tag2 90 90
+5 tag2 100 95
+
+# Test percentile_cont non-sliding window
+query ITRRRR
+SELECT
+    timestamp,
+    tags,
+    value,
+    percentile_cont(value, 0.5) OVER (
+        PARTITION BY tags
+        ORDER BY timestamp
+        ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
+    ) AS value_percentile_unbounded_preceding,
+    percentile_cont(value, 0.5) OVER (
+        PARTITION BY tags
+        ORDER BY timestamp
+        ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+    ) AS value_percentile_unbounded_both,
+    percentile_cont(value, 0.5) OVER (
+        PARTITION BY tags
+        ORDER BY timestamp
+        ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING
+    ) AS value_percentile_unbounded_following
+FROM median_window_test
+ORDER BY tags, timestamp;
+----
+1 tag1 10 10 30 30
+2 tag1 20 15 30 35
+3 tag1 30 20 30 40
+4 tag1 40 25 30 45
+5 tag1 50 30 30 50
+1 tag2 60 60 80 80
+2 tag2 70 65 80 85
+3 tag2 80 70 80 90
+4 tag2 90 75 80 95
+5 tag2 100 80 80 100
+
+# Test percentile_cont with different percentile values
+query ITRRR
+SELECT
+    timestamp,
+    tags,
+    value,
+    percentile_cont(value, 0.25) OVER (
+        PARTITION BY tags
+        ORDER BY timestamp
+        ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
+    ) AS p25,
+    percentile_cont(value, 0.75) OVER (
+        PARTITION BY tags
+        ORDER BY timestamp
+        ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
+    ) AS p75
+FROM median_window_test
+ORDER BY tags, timestamp;
+----
+1 tag1 10 10 10
+2 tag1 20 12.5 17.5
+3 tag1 30 15 25
+4 tag1 40 17.5 32.5
+5 tag1 50 20 40
+1 tag2 60 60 60
+2 tag2 70 62.5 67.5
+3 tag2 80 65 75
+4 tag2 90 67.5 82.5
+5 tag2 100 70 90
+
 statement ok
 DROP TABLE median_window_test;
 
@@ -8250,3 +8346,44 @@ query R
 select percentile_cont(null, 0.5);
 ----
 NULL
+
+# Test string_agg window frame behavior (fix for issue #19612)
+statement ok
+CREATE TABLE string_agg_window_test (
+    id INT,
+    grp VARCHAR,
+    val VARCHAR
+);
+
+statement ok
+INSERT INTO string_agg_window_test (id, grp, val) VALUES
+(1, 'A', 'a'),
+(2, 'A', 'b'),
+(3, 'A', 'c'),
+(1, 'B', 'x'),
+(2, 'B', 'y'),
+(3, 'B', 'z');
+
+# Test string_agg with ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
+# The function should maintain state correctly across multiple evaluate() calls
+query ITT
+SELECT
+    id,
+    grp,
+    string_agg(val, ',') OVER (
+        PARTITION BY grp
+        ORDER BY id
+        ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
+    ) AS cumulative_string
+FROM string_agg_window_test
+ORDER BY grp, id;
+----
+1 A a
+2 A a,b
+3 A a,b,c
+1 B x
+2 B x,y
+3 B x,y,z
+
+statement ok
+DROP TABLE string_agg_window_test;