Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 34 additions & 28 deletions datafusion/functions-aggregate/src/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,11 +411,15 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
}

fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
let n = match emit_to {
EmitTo::All => self.count.len(),
EmitTo::First(n) => n,
};

// Drain the state vectors for the groups being emitted
let counts = emit_to.take_needed(&mut self.count);
let sum_xs = emit_to.take_needed(&mut self.sum_x);
let sum_ys = emit_to.take_needed(&mut self.sum_y);
let sum_xys = emit_to.take_needed(&mut self.sum_xy);
let sum_xxs = emit_to.take_needed(&mut self.sum_xx);
let sum_yys = emit_to.take_needed(&mut self.sum_yy);

let n = counts.len();
let mut values = Vec::with_capacity(n);
let mut nulls = NullBufferBuilder::new(n);

Expand All @@ -427,14 +431,13 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
// result should be `Null` (according to PostgreSQL's behavior).
// - However, if any of the accumulated values contain NaN, the result should
// be NaN regardless of the count (even for single-row groups).
//
for i in 0..n {
let count = self.count[i];
let sum_x = self.sum_x[i];
let sum_y = self.sum_y[i];
let sum_xy = self.sum_xy[i];
let sum_xx = self.sum_xx[i];
let sum_yy = self.sum_yy[i];
let count = counts[i];
let sum_x = sum_xs[i];
let sum_y = sum_ys[i];
let sum_xy = sum_xys[i];
let sum_xx = sum_xxs[i];
let sum_yy = sum_yys[i];

// If BOTH sum_x AND sum_y are NaN, then both input values are NaN → return NaN
// If only ONE of them is NaN, then only one input value is NaN → return NULL
Expand Down Expand Up @@ -470,18 +473,21 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
}

fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
let n = match emit_to {
EmitTo::All => self.count.len(),
EmitTo::First(n) => n,
};
// Drain the state vectors for the groups being emitted
let count = emit_to.take_needed(&mut self.count);
let sum_x = emit_to.take_needed(&mut self.sum_x);
let sum_y = emit_to.take_needed(&mut self.sum_y);
let sum_xy = emit_to.take_needed(&mut self.sum_xy);
let sum_xx = emit_to.take_needed(&mut self.sum_xx);
let sum_yy = emit_to.take_needed(&mut self.sum_yy);

Ok(vec![
Arc::new(UInt64Array::from(self.count[0..n].to_vec())),
Arc::new(Float64Array::from(self.sum_x[0..n].to_vec())),
Arc::new(Float64Array::from(self.sum_y[0..n].to_vec())),
Arc::new(Float64Array::from(self.sum_xy[0..n].to_vec())),
Arc::new(Float64Array::from(self.sum_xx[0..n].to_vec())),
Arc::new(Float64Array::from(self.sum_yy[0..n].to_vec())),
Arc::new(UInt64Array::from(count)),
Arc::new(Float64Array::from(sum_x)),
Arc::new(Float64Array::from(sum_y)),
Arc::new(Float64Array::from(sum_xy)),
Arc::new(Float64Array::from(sum_xx)),
Arc::new(Float64Array::from(sum_yy)),
])
}

Expand Down Expand Up @@ -537,12 +543,12 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
}

fn size(&self) -> usize {
size_of_val(&self.count)
+ size_of_val(&self.sum_x)
+ size_of_val(&self.sum_y)
+ size_of_val(&self.sum_xy)
+ size_of_val(&self.sum_xx)
+ size_of_val(&self.sum_yy)
self.count.capacity() * size_of::<u64>()
+ self.sum_x.capacity() * size_of::<f64>()
+ self.sum_y.capacity() * size_of::<f64>()
+ self.sum_xy.capacity() * size_of::<f64>()
+ self.sum_xx.capacity() * size_of::<f64>()
+ self.sum_yy.capacity() * size_of::<f64>()
}
}

Expand Down
Loading