Skip to content

Commit 1eb46df

Browse files
korowaozankabak
andauthored
Covariance single row input & null skipping (apache#4852)
* covariance & correlation single row & null skipping * Apply suggestions from code review Co-authored-by: Mehmet Ozan Kabak <[email protected]> * unwrap_or_internal_err macro instead of unwrap Co-authored-by: Mehmet Ozan Kabak <[email protected]>
1 parent 292eb95 commit 1eb46df

File tree

3 files changed

+292
-110
lines changed

3 files changed

+292
-110
lines changed

datafusion/core/tests/sqllogictests/test_files/aggregate.slt

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,94 @@ SELECT covar(c2, c12) FROM aggregate_test_100
3636
----
3737
-0.07996901247859442
3838

39+
# single_row_query_covar_1
40+
query R
41+
select covar_samp(sq.column1, sq.column2) from (values (1.1, 2.2)) as sq
42+
----
43+
NULL
44+
45+
# single_row_query_covar_2
46+
query R
47+
select covar_pop(sq.column1, sq.column2) from (values (1.1, 2.2)) as sq
48+
----
49+
0
50+
51+
# all_nulls_query_covar
52+
query R
53+
with data as (
54+
select null::int as f, null::int as b
55+
union all
56+
select null::int as f, null::int as b
57+
)
58+
select covar_samp(f, b), covar_pop(f, b)
59+
from data
60+
----
61+
NULL NULL
62+
63+
# covar_query_with_nulls
64+
query R
65+
with data as (
66+
select 1 as f, 4 as b
67+
union all
68+
select null as f, 99 as b
69+
union all
70+
select 2 as f, 5 as b
71+
union all
72+
select 98 as f, null as b
73+
union all
74+
select 3 as f, 6 as b
75+
union all
76+
select null as f, null as b
77+
)
78+
select covar_samp(f, b), covar_pop(f, b)
79+
from data
80+
----
81+
1 0.6666666666666666
82+
3983
# csv_query_correlation
4084
query R
4185
SELECT corr(c2, c12) FROM aggregate_test_100
4286
----
4387
-0.19064544190576607
4488

89+
# single_row_query_correlation
90+
query R
91+
select corr(sq.column1, sq.column2) from (values (1.1, 2.2)) as sq
92+
----
93+
0
94+
95+
# all_nulls_query_correlation
96+
query R
97+
with data as (
98+
select null::int as f, null::int as b
99+
union all
100+
select null::int as f, null::int as b
101+
)
102+
select corr(f, b)
103+
from data
104+
----
105+
NULL
106+
107+
# correlation_query_with_nulls
108+
query R
109+
with data as (
110+
select 1 as f, 4 as b
111+
union all
112+
select null as f, 99 as b
113+
union all
114+
select 2 as f, 5 as b
115+
union all
116+
select 98 as f, null as b
117+
union all
118+
select 3 as f, 6 as b
119+
union all
120+
select null as f, null as b
121+
)
122+
select corr(f, b)
123+
from data
124+
----
125+
1
126+
45127
# csv_query_variance_1
46128
query R
47129
SELECT var_pop(c2) FROM aggregate_test_100

datafusion/physical-expr/src/aggregate/correlation.rs

Lines changed: 62 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ use crate::aggregate::stats::StatsType;
2222
use crate::aggregate::stddev::StddevAccumulator;
2323
use crate::expressions::format_state_name;
2424
use crate::{AggregateExpr, PhysicalExpr};
25-
use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
25+
use arrow::{
26+
array::ArrayRef,
27+
compute::{and, filter, is_not_null},
28+
datatypes::{DataType, Field},
29+
};
2630
use datafusion_common::Result;
2731
use datafusion_common::ScalarValue;
2832
use datafusion_expr::Accumulator;
@@ -145,14 +149,39 @@ impl Accumulator for CorrelationAccumulator {
145149
}
146150

147151
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
148-
self.covar.update_batch(values)?;
152+
// TODO: null input skipping logic duplicated across Correlation
153+
// and its children accumulators.
154+
// This could be simplified by splitting up input filtering and
155+
// calculation logic in children accumulators, and calling only
156+
// calculation part from Correlation
157+
let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
158+
let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
159+
let values1 = filter(&values[0], &mask)?;
160+
let values2 = filter(&values[1], &mask)?;
161+
162+
vec![values1, values2]
163+
} else {
164+
values.to_vec()
165+
};
166+
167+
self.covar.update_batch(&values)?;
149168
self.stddev1.update_batch(&values[0..1])?;
150169
self.stddev2.update_batch(&values[1..2])?;
151170
Ok(())
152171
}
153172

154173
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
155-
self.covar.retract_batch(values)?;
174+
let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
175+
let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
176+
let values1 = filter(&values[0], &mask)?;
177+
let values2 = filter(&values[1], &mask)?;
178+
179+
vec![values1, values2]
180+
} else {
181+
values.to_vec()
182+
};
183+
184+
self.covar.retract_batch(&values)?;
156185
self.stddev1.retract_batch(&values[0..1])?;
157186
self.stddev2.retract_batch(&values[1..2])?;
158187
Ok(())
@@ -341,48 +370,44 @@ mod tests {
341370

342371
#[test]
343372
fn correlation_i32_with_nulls_2() -> Result<()> {
344-
let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
345-
let b: ArrayRef = Arc::new(Int32Array::from(vec![Some(4), Some(5), Some(6)]));
346-
347-
let schema = Schema::new(vec![
348-
Field::new("a", DataType::Int32, true),
349-
Field::new("b", DataType::Int32, true),
350-
]);
351-
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?;
352-
353-
let agg = Arc::new(Correlation::new(
354-
col("a", &schema)?,
355-
col("b", &schema)?,
356-
"bla".to_string(),
357-
DataType::Float64,
358-
));
359-
let actual = aggregate(&batch, agg);
360-
assert!(actual.is_err());
373+
let a: ArrayRef = Arc::new(Int32Array::from(vec![
374+
Some(1),
375+
None,
376+
Some(2),
377+
Some(9),
378+
Some(3),
379+
]));
380+
let b: ArrayRef = Arc::new(Int32Array::from(vec![
381+
Some(4),
382+
Some(5),
383+
Some(5),
384+
None,
385+
Some(6),
386+
]));
361387

362-
Ok(())
388+
generic_test_op2!(
389+
a,
390+
b,
391+
DataType::Int32,
392+
DataType::Int32,
393+
Correlation,
394+
ScalarValue::from(1_f64)
395+
)
363396
}
364397

365398
#[test]
366399
fn correlation_i32_all_nulls() -> Result<()> {
367400
let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
368401
let b: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
369402

370-
let schema = Schema::new(vec![
371-
Field::new("a", DataType::Int32, true),
372-
Field::new("b", DataType::Int32, true),
373-
]);
374-
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?;
375-
376-
let agg = Arc::new(Correlation::new(
377-
col("a", &schema)?,
378-
col("b", &schema)?,
379-
"bla".to_string(),
380-
DataType::Float64,
381-
));
382-
let actual = aggregate(&batch, agg);
383-
assert!(actual.is_err());
384-
385-
Ok(())
403+
generic_test_op2!(
404+
a,
405+
b,
406+
DataType::Int32,
407+
DataType::Int32,
408+
Correlation,
409+
ScalarValue::Float64(None)
410+
)
386411
}
387412

388413
#[test]

0 commit comments

Comments
 (0)