Skip to content

Commit f84a390

Browse files
fix: drain CORR state vectors on EmitTo::First in streaming aggregation
1 parent 79767e2 commit f84a390

File tree

2 files changed

+44
-29
lines changed

2 files changed

+44
-29
lines changed

datafusion/functions-aggregate/src/correlation.rs

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -411,11 +411,15 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
411411
}
412412

413413
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
414-
let n = match emit_to {
415-
EmitTo::All => self.count.len(),
416-
EmitTo::First(n) => n,
417-
};
418-
414+
// Drain the state vectors for the groups being emitted
415+
let counts = emit_to.take_needed(&mut self.count);
416+
let sum_xs = emit_to.take_needed(&mut self.sum_x);
417+
let sum_ys = emit_to.take_needed(&mut self.sum_y);
418+
let sum_xys = emit_to.take_needed(&mut self.sum_xy);
419+
let sum_xxs = emit_to.take_needed(&mut self.sum_xx);
420+
let sum_yys = emit_to.take_needed(&mut self.sum_yy);
421+
422+
let n = counts.len();
419423
let mut values = Vec::with_capacity(n);
420424
let mut nulls = NullBufferBuilder::new(n);
421425

@@ -427,14 +431,13 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
427431
// result should be `Null` (according to PostgreSQL's behavior).
428432
// - However, if any of the accumulated values contain NaN, the result should
429433
// be NaN regardless of the count (even for single-row groups).
430-
//
431434
for i in 0..n {
432-
let count = self.count[i];
433-
let sum_x = self.sum_x[i];
434-
let sum_y = self.sum_y[i];
435-
let sum_xy = self.sum_xy[i];
436-
let sum_xx = self.sum_xx[i];
437-
let sum_yy = self.sum_yy[i];
435+
let count = counts[i];
436+
let sum_x = sum_xs[i];
437+
let sum_y = sum_ys[i];
438+
let sum_xy = sum_xys[i];
439+
let sum_xx = sum_xxs[i];
440+
let sum_yy = sum_yys[i];
438441

439442
// If BOTH sum_x AND sum_y are NaN, then both input values are NaN → return NaN
440443
// If only ONE of them is NaN, then only one input value is NaN → return NULL
@@ -470,18 +473,21 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
470473
}
471474

472475
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
473-
let n = match emit_to {
474-
EmitTo::All => self.count.len(),
475-
EmitTo::First(n) => n,
476-
};
476+
// Drain the state vectors for the groups being emitted
477+
let count = emit_to.take_needed(&mut self.count);
478+
let sum_x = emit_to.take_needed(&mut self.sum_x);
479+
let sum_y = emit_to.take_needed(&mut self.sum_y);
480+
let sum_xy = emit_to.take_needed(&mut self.sum_xy);
481+
let sum_xx = emit_to.take_needed(&mut self.sum_xx);
482+
let sum_yy = emit_to.take_needed(&mut self.sum_yy);
477483

478484
Ok(vec![
479-
Arc::new(UInt64Array::from(self.count[0..n].to_vec())),
480-
Arc::new(Float64Array::from(self.sum_x[0..n].to_vec())),
481-
Arc::new(Float64Array::from(self.sum_y[0..n].to_vec())),
482-
Arc::new(Float64Array::from(self.sum_xy[0..n].to_vec())),
483-
Arc::new(Float64Array::from(self.sum_xx[0..n].to_vec())),
484-
Arc::new(Float64Array::from(self.sum_yy[0..n].to_vec())),
485+
Arc::new(UInt64Array::from(count)),
486+
Arc::new(Float64Array::from(sum_x)),
487+
Arc::new(Float64Array::from(sum_y)),
488+
Arc::new(Float64Array::from(sum_xy)),
489+
Arc::new(Float64Array::from(sum_xx)),
490+
Arc::new(Float64Array::from(sum_yy)),
485491
])
486492
}
487493

@@ -537,12 +543,12 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
537543
}
538544

539545
fn size(&self) -> usize {
540-
size_of_val(&self.count)
541-
+ size_of_val(&self.sum_x)
542-
+ size_of_val(&self.sum_y)
543-
+ size_of_val(&self.sum_xy)
544-
+ size_of_val(&self.sum_xx)
545-
+ size_of_val(&self.sum_yy)
546+
self.count.capacity() * size_of::<u64>()
547+
+ self.sum_x.capacity() * size_of::<f64>()
548+
+ self.sum_y.capacity() * size_of::<f64>()
549+
+ self.sum_xy.capacity() * size_of::<f64>()
550+
+ self.sum_xx.capacity() * size_of::<f64>()
551+
+ self.sum_yy.capacity() * size_of::<f64>()
546552
}
547553
}
548554

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3062,7 +3062,7 @@ set datafusion.execution.target_partitions = 1;
30623062

30633063
# Bucket 1: CORR = 1, -1, 1, -1 (y varies)
30643064
# Bucket 2: CORR = NULL (y constant, zero variance)
3065-
statement error DataFusion error: Arrow error: Invalid argument error: all columns in a record batch must have the same length
3065+
query IIR
30663066
SELECT bucket, grp, CORR(x, y) FROM (
30673067
SELECT * FROM (VALUES
30683068
(1, 1, 1.0, 1.0), (1, 1, 2.0, 2.0),
@@ -3079,6 +3079,15 @@ SELECT bucket, grp, CORR(x, y) FROM (
30793079
) AS ordered_data
30803080
GROUP BY bucket, grp
30813081
ORDER BY bucket, grp;
3082+
----
3083+
1 1 1
3084+
1 2 -1
3085+
1 3 1
3086+
1 4 -1
3087+
2 1 NULL
3088+
2 2 NULL
3089+
2 3 NULL
3090+
2 4 NULL
30823091

30833092
statement ok
30843093
set datafusion.execution.target_partitions = 4;

0 commit comments

Comments
 (0)