@@ -22,7 +22,11 @@ use crate::aggregate::stats::StatsType;
22
22
use crate :: aggregate:: stddev:: StddevAccumulator ;
23
23
use crate :: expressions:: format_state_name;
24
24
use crate :: { AggregateExpr , PhysicalExpr } ;
25
- use arrow:: { array:: ArrayRef , datatypes:: DataType , datatypes:: Field } ;
25
+ use arrow:: {
26
+ array:: ArrayRef ,
27
+ compute:: { and, filter, is_not_null} ,
28
+ datatypes:: { DataType , Field } ,
29
+ } ;
26
30
use datafusion_common:: Result ;
27
31
use datafusion_common:: ScalarValue ;
28
32
use datafusion_expr:: Accumulator ;
@@ -145,14 +149,39 @@ impl Accumulator for CorrelationAccumulator {
145
149
}
146
150
147
151
fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
148
- self . covar . update_batch ( values) ?;
152
+ // TODO: null input skipping logic duplicated across Correlation
153
+ // and its children accumulators.
154
+ // This could be simplified by splitting up input filtering and
155
+ // calculation logic in children accumulators, and calling only
156
+ // calculation part from Correlation
157
+ let values = if values[ 0 ] . null_count ( ) != 0 || values[ 1 ] . null_count ( ) != 0 {
158
+ let mask = and ( & is_not_null ( & values[ 0 ] ) ?, & is_not_null ( & values[ 1 ] ) ?) ?;
159
+ let values1 = filter ( & values[ 0 ] , & mask) ?;
160
+ let values2 = filter ( & values[ 1 ] , & mask) ?;
161
+
162
+ vec ! [ values1, values2]
163
+ } else {
164
+ values. to_vec ( )
165
+ } ;
166
+
167
+ self . covar . update_batch ( & values) ?;
149
168
self . stddev1 . update_batch ( & values[ 0 ..1 ] ) ?;
150
169
self . stddev2 . update_batch ( & values[ 1 ..2 ] ) ?;
151
170
Ok ( ( ) )
152
171
}
153
172
154
173
fn retract_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
155
- self . covar . retract_batch ( values) ?;
174
+ let values = if values[ 0 ] . null_count ( ) != 0 || values[ 1 ] . null_count ( ) != 0 {
175
+ let mask = and ( & is_not_null ( & values[ 0 ] ) ?, & is_not_null ( & values[ 1 ] ) ?) ?;
176
+ let values1 = filter ( & values[ 0 ] , & mask) ?;
177
+ let values2 = filter ( & values[ 1 ] , & mask) ?;
178
+
179
+ vec ! [ values1, values2]
180
+ } else {
181
+ values. to_vec ( )
182
+ } ;
183
+
184
+ self . covar . retract_batch ( & values) ?;
156
185
self . stddev1 . retract_batch ( & values[ 0 ..1 ] ) ?;
157
186
self . stddev2 . retract_batch ( & values[ 1 ..2 ] ) ?;
158
187
Ok ( ( ) )
@@ -341,48 +370,44 @@ mod tests {
341
370
342
371
#[ test]
343
372
fn correlation_i32_with_nulls_2 ( ) -> Result < ( ) > {
344
- let a: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ Some ( 1 ) , None , Some ( 3 ) ] ) ) ;
345
- let b: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ Some ( 4 ) , Some ( 5 ) , Some ( 6 ) ] ) ) ;
346
-
347
- let schema = Schema :: new ( vec ! [
348
- Field :: new( "a" , DataType :: Int32 , true ) ,
349
- Field :: new( "b" , DataType :: Int32 , true ) ,
350
- ] ) ;
351
- let batch = RecordBatch :: try_new ( Arc :: new ( schema. clone ( ) ) , vec ! [ a, b] ) ?;
352
-
353
- let agg = Arc :: new ( Correlation :: new (
354
- col ( "a" , & schema) ?,
355
- col ( "b" , & schema) ?,
356
- "bla" . to_string ( ) ,
357
- DataType :: Float64 ,
358
- ) ) ;
359
- let actual = aggregate ( & batch, agg) ;
360
- assert ! ( actual. is_err( ) ) ;
373
+ let a: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [
374
+ Some ( 1 ) ,
375
+ None ,
376
+ Some ( 2 ) ,
377
+ Some ( 9 ) ,
378
+ Some ( 3 ) ,
379
+ ] ) ) ;
380
+ let b: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [
381
+ Some ( 4 ) ,
382
+ Some ( 5 ) ,
383
+ Some ( 5 ) ,
384
+ None ,
385
+ Some ( 6 ) ,
386
+ ] ) ) ;
361
387
362
- Ok ( ( ) )
388
+ generic_test_op2 ! (
389
+ a,
390
+ b,
391
+ DataType :: Int32 ,
392
+ DataType :: Int32 ,
393
+ Correlation ,
394
+ ScalarValue :: from( 1_f64 )
395
+ )
363
396
}
364
397
365
398
#[ test]
366
399
fn correlation_i32_all_nulls ( ) -> Result < ( ) > {
367
400
let a: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ None , None ] ) ) ;
368
401
let b: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ None , None ] ) ) ;
369
402
370
- let schema = Schema :: new ( vec ! [
371
- Field :: new( "a" , DataType :: Int32 , true ) ,
372
- Field :: new( "b" , DataType :: Int32 , true ) ,
373
- ] ) ;
374
- let batch = RecordBatch :: try_new ( Arc :: new ( schema. clone ( ) ) , vec ! [ a, b] ) ?;
375
-
376
- let agg = Arc :: new ( Correlation :: new (
377
- col ( "a" , & schema) ?,
378
- col ( "b" , & schema) ?,
379
- "bla" . to_string ( ) ,
380
- DataType :: Float64 ,
381
- ) ) ;
382
- let actual = aggregate ( & batch, agg) ;
383
- assert ! ( actual. is_err( ) ) ;
384
-
385
- Ok ( ( ) )
403
+ generic_test_op2 ! (
404
+ a,
405
+ b,
406
+ DataType :: Int32 ,
407
+ DataType :: Int32 ,
408
+ Correlation ,
409
+ ScalarValue :: Float64 ( None )
410
+ )
386
411
}
387
412
388
413
#[ test]
0 commit comments