Skip to content

Commit 4dd7825

Browse files
authored
Perform type coercion for corr aggregate function (#15776)
* Type coercion for corr aggregate function during planning * perform type coercion during logical planning * remove redundant type coercion test file
1 parent 4bc66c8 commit 4dd7825

File tree

2 files changed

+120
-14
lines changed

2 files changed

+120
-14
lines changed

datafusion/functions-aggregate/src/correlation.rs

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use arrow::array::{
2626
downcast_array, Array, AsArray, BooleanArray, Float64Array, NullBufferBuilder,
2727
UInt64Array,
2828
};
29-
use arrow::compute::{and, filter, is_not_null, kernels::cast};
29+
use arrow::compute::{and, filter, is_not_null};
3030
use arrow::datatypes::{FieldRef, Float64Type, UInt64Type};
3131
use arrow::{
3232
array::ArrayRef,
@@ -38,10 +38,9 @@ use log::debug;
3838

3939
use crate::covariance::CovarianceAccumulator;
4040
use crate::stddev::StddevAccumulator;
41-
use datafusion_common::{plan_err, Result, ScalarValue};
41+
use datafusion_common::{Result, ScalarValue};
4242
use datafusion_expr::{
4343
function::{AccumulatorArgs, StateFieldsArgs},
44-
type_coercion::aggregates::NUMERICS,
4544
utils::format_state_name,
4645
Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
4746
};
@@ -83,10 +82,13 @@ impl Default for Correlation {
8382
}
8483

8584
impl Correlation {
86-
/// Create a new COVAR_POP aggregate function
85+
/// Create a new CORR aggregate function
8786
pub fn new() -> Self {
8887
Self {
89-
signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
88+
signature: Signature::exact(
89+
vec![DataType::Float64, DataType::Float64],
90+
Volatility::Immutable,
91+
),
9092
}
9193
}
9294
}
@@ -105,11 +107,7 @@ impl AggregateUDFImpl for Correlation {
105107
&self.signature
106108
}
107109

108-
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
109-
if !arg_types[0].is_numeric() {
110-
return plan_err!("Correlation requires numeric input types");
111-
}
112-
110+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
113111
Ok(DataType::Float64)
114112
}
115113

@@ -375,10 +373,8 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
375373
self.sum_xx.resize(total_num_groups, 0.0);
376374
self.sum_yy.resize(total_num_groups, 0.0);
377375

378-
let array_x = &cast(&values[0], &DataType::Float64)?;
379-
let array_x = downcast_array::<Float64Array>(array_x);
380-
let array_y = &cast(&values[1], &DataType::Float64)?;
381-
let array_y = downcast_array::<Float64Array>(array_y);
376+
let array_x = downcast_array::<Float64Array>(&values[0]);
377+
let array_y = downcast_array::<Float64Array>(&values[1]);
382378

383379
accumulate_multiple(
384380
group_indices,

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2548,7 +2548,117 @@ select covar_samp(c1, c2), arrow_typeof(covar_samp(c1, c2)) from t;
25482548
statement ok
25492549
drop table t;
25502550

2551+
# correlation_f64_1
2552+
statement ok
2553+
create table t (c1 double, c2 double) as values (1, 4), (2, 5), (3, 6);
2554+
2555+
query RT rowsort
2556+
select corr(c1, c2), arrow_typeof(corr(c1, c2)) from t;
2557+
----
2558+
1 Float64
2559+
2560+
# correlation with different numeric types (create test data)
2561+
statement ok
2562+
CREATE OR REPLACE TABLE corr_test(
2563+
int8_col TINYINT,
2564+
int16_col SMALLINT,
2565+
int32_col INT,
2566+
int64_col BIGINT,
2567+
uint32_col INT UNSIGNED,
2568+
float32_col FLOAT,
2569+
float64_col DOUBLE
2570+
) as VALUES
2571+
(1, 10, 100, 1000, 10000, 1.1, 10.1),
2572+
(2, 20, 200, 2000, 20000, 2.2, 20.2),
2573+
(3, 30, 300, 3000, 30000, 3.3, 30.3),
2574+
(4, 40, 400, 4000, 40000, 4.4, 40.4),
2575+
(5, 50, 500, 5000, 50000, 5.5, 50.5);
2576+
2577+
# correlation using int32 and float64
2578+
query R
2579+
SELECT corr(int32_col, float64_col) FROM corr_test;
2580+
----
2581+
1
2582+
2583+
# correlation using int64 and int32
2584+
query R
2585+
SELECT corr(int64_col, int32_col) FROM corr_test;
2586+
----
2587+
1
2588+
2589+
# correlation using float32 and int8
2590+
query R
2591+
SELECT corr(float32_col, int8_col) FROM corr_test;
2592+
----
2593+
1
2594+
2595+
# correlation using uint32 and int16
2596+
query R
2597+
SELECT corr(uint32_col, int16_col) FROM corr_test;
2598+
----
2599+
1
2600+
2601+
# correlation with nulls
2602+
statement ok
2603+
CREATE OR REPLACE TABLE corr_nulls(
2604+
x INT,
2605+
y DOUBLE
2606+
) as VALUES
2607+
(1, 10.0),
2608+
(2, 20.0),
2609+
(NULL, 30.0),
2610+
(4, NULL),
2611+
(5, 50.0);
2612+
2613+
# correlation with some nulls (should skip null pairs)
2614+
query R
2615+
SELECT corr(x, y) FROM corr_nulls;
2616+
----
2617+
1
2618+
2619+
# correlation with single row (should return NULL)
2620+
statement ok
2621+
CREATE OR REPLACE TABLE corr_single_row(
2622+
x INT,
2623+
y DOUBLE
2624+
) as VALUES
2625+
(1, 10.0);
2626+
2627+
query R
2628+
SELECT corr(x, y) FROM corr_single_row;
2629+
----
2630+
0
2631+
2632+
# correlation with all nulls
2633+
statement ok
2634+
CREATE OR REPLACE TABLE corr_all_nulls(
2635+
x INT,
2636+
y DOUBLE
2637+
) as VALUES
2638+
(NULL, NULL),
2639+
(NULL, NULL);
2640+
2641+
query R
2642+
SELECT corr(x, y) FROM corr_all_nulls;
2643+
----
2644+
NULL
2645+
2646+
statement ok
2647+
drop table corr_test;
2648+
2649+
statement ok
2650+
drop table corr_nulls;
2651+
2652+
statement ok
2653+
drop table corr_single_row;
2654+
2655+
statement ok
2656+
drop table corr_all_nulls;
2657+
25512658
# covariance_f64_4
2659+
statement ok
2660+
drop table if exists t;
2661+
25522662
statement ok
25532663
create table t (c1 double, c2 double) as values (1.1, 4.1), (2.0, 5.0), (3.0, 6.0);
25542664

0 commit comments

Comments
 (0)