@@ -23,8 +23,7 @@ use crate::expressions::{lit, try_cast};
2323use arrow:: array:: * ;
2424use arrow:: compute:: kernels:: zip:: zip;
2525use arrow:: compute:: {
26- FilterBuilder , FilterPredicate , SlicesIterator , is_not_null, not, nullif,
27- prep_null_mask_filter,
26+ FilterBuilder , FilterPredicate , is_not_null, not, nullif, prep_null_mask_filter,
2827} ;
2928use arrow:: datatypes:: { DataType , Schema , UInt32Type , UnionMode } ;
3029use arrow:: error:: ArrowError ;
@@ -39,6 +38,7 @@ use std::hash::Hash;
3938use std:: { any:: Any , sync:: Arc } ;
4039
4140use crate :: expressions:: case:: literal_lookup_table:: LiteralLookupTable ;
41+ use arrow:: compute:: kernels:: merge:: { MergeIndex , merge, merge_n} ;
4242use datafusion_common:: tree_node:: { Transformed , TreeNode , TreeNodeRecursion } ;
4343use datafusion_physical_expr_common:: datum:: compare_with_eq;
4444use itertools:: Itertools ;
@@ -336,189 +336,6 @@ fn filter_array(
336336 filter. filter ( array)
337337}
338338
339- fn merge (
340- mask : & BooleanArray ,
341- truthy : ColumnarValue ,
342- falsy : ColumnarValue ,
343- ) -> std:: result:: Result < ArrayRef , ArrowError > {
344- let ( truthy, truthy_is_scalar) = match truthy {
345- ColumnarValue :: Array ( a) => ( a, false ) ,
346- ColumnarValue :: Scalar ( s) => ( s. to_array ( ) ?, true ) ,
347- } ;
348- let ( falsy, falsy_is_scalar) = match falsy {
349- ColumnarValue :: Array ( a) => ( a, false ) ,
350- ColumnarValue :: Scalar ( s) => ( s. to_array ( ) ?, true ) ,
351- } ;
352-
353- if truthy_is_scalar && falsy_is_scalar {
354- return zip ( mask, & Scalar :: new ( truthy) , & Scalar :: new ( falsy) ) ;
355- }
356-
357- let falsy = falsy. to_data ( ) ;
358- let truthy = truthy. to_data ( ) ;
359-
360- let mut mutable = MutableArrayData :: new ( vec ! [ & truthy, & falsy] , false , truthy. len ( ) ) ;
361-
362- // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to
363- // fill with falsy values
364-
365- // keep track of how much is filled
366- let mut filled = 0 ;
367- let mut falsy_offset = 0 ;
368- let mut truthy_offset = 0 ;
369-
370- SlicesIterator :: new ( mask) . for_each ( |( start, end) | {
371- // the gap needs to be filled with falsy values
372- if start > filled {
373- if falsy_is_scalar {
374- for _ in filled..start {
375- // Copy the first item from the 'falsy' array into the output buffer.
376- mutable. extend ( 1 , 0 , 1 ) ;
377- }
378- } else {
379- let falsy_length = start - filled;
380- let falsy_end = falsy_offset + falsy_length;
381- mutable. extend ( 1 , falsy_offset, falsy_end) ;
382- falsy_offset = falsy_end;
383- }
384- }
385- // fill with truthy values
386- if truthy_is_scalar {
387- for _ in start..end {
388- // Copy the first item from the 'truthy' array into the output buffer.
389- mutable. extend ( 0 , 0 , 1 ) ;
390- }
391- } else {
392- let truthy_length = end - start;
393- let truthy_end = truthy_offset + truthy_length;
394- mutable. extend ( 0 , truthy_offset, truthy_end) ;
395- truthy_offset = truthy_end;
396- }
397- filled = end;
398- } ) ;
399- // the remaining part is falsy
400- if filled < mask. len ( ) {
401- if falsy_is_scalar {
402- for _ in filled..mask. len ( ) {
403- // Copy the first item from the 'falsy' array into the output buffer.
404- mutable. extend ( 1 , 0 , 1 ) ;
405- }
406- } else {
407- let falsy_length = mask. len ( ) - filled;
408- let falsy_end = falsy_offset + falsy_length;
409- mutable. extend ( 1 , falsy_offset, falsy_end) ;
410- }
411- }
412-
413- let data = mutable. freeze ( ) ;
414- Ok ( make_array ( data) )
415- }
416-
417- /// Merges elements by index from a list of [`ArrayData`], creating a new [`ColumnarValue`] from
418- /// those values.
419- ///
420- /// Each element in `indices` is the index of an array in `values`. The `indices` array is processed
421- /// sequentially. The first occurrence of index value `n` will be mapped to the first
422- /// value of the array at index `n`. The second occurrence to the second value, and so on.
423- /// An index value where `PartialResultIndex::is_none` is `true` is used to indicate null values.
424- ///
425- /// # Implementation notes
426- ///
427- /// This algorithm is similar in nature to both `zip` and `interleave`, but there are some important
428- /// differences.
429- ///
430- /// In contrast to `zip`, this function supports multiple input arrays. Instead of a boolean
431- /// selection vector, an index array is to take values from the input arrays, and a special marker
432- /// value is used to indicate null values.
433- ///
434- /// In contrast to `interleave`, this function does not use pairs of indices. The values in
435- /// `indices` serve the same purpose as the first value in the pairs passed to `interleave`.
436- /// The index in the array is implicit and is derived from the number of times a particular array
437- /// index occurs.
438- /// The more constrained indexing mechanism used by this algorithm makes it easier to copy values
439- /// in contiguous slices. In the example below, the two subsequent elements from array `2` can be
440- /// copied in a single operation from the source array instead of copying them one by one.
441- /// Long spans of null values are also especially cheap because they do not need to be represented
442- /// in an input array.
443- ///
444- /// # Safety
445- ///
446- /// This function does not check that the number of occurrences of any particular array index matches
447- /// the length of the corresponding input array. If an array contains more values than required, the
448- /// spurious values will be ignored. If an array contains fewer values than necessary, this function
449- /// will panic.
450- ///
451- /// # Example
452- ///
453- /// ```text
454- /// ┌───────────┐ ┌─────────┐ ┌─────────┐
455- /// │┌─────────┐│ │ None │ │ NULL │
456- /// ││ A ││ ├─────────┤ ├─────────┤
457- /// │└─────────┘│ │ 1 │ │ B │
458- /// │┌─────────┐│ ├─────────┤ ├─────────┤
459- /// ││ B ││ │ 0 │ merge(values, indices) │ A │
460- /// │└─────────┘│ ├─────────┤ ─────────────────────────▶ ├─────────┤
461- /// │┌─────────┐│ │ None │ │ NULL │
462- /// ││ C ││ ├─────────┤ ├─────────┤
463- /// │├─────────┤│ │ 2 │ │ C │
464- /// ││ D ││ ├─────────┤ ├─────────┤
465- /// │└─────────┘│ │ 2 │ │ D │
466- /// └───────────┘ └─────────┘ └─────────┘
467- /// values indices result
468- /// ```
469- fn merge_n ( values : & [ ArrayData ] , indices : & [ PartialResultIndex ] ) -> Result < ArrayRef > {
470- #[ cfg( debug_assertions) ]
471- for ix in indices {
472- if let Some ( index) = ix. index ( ) {
473- assert ! (
474- index < values. len( ) ,
475- "Index out of bounds: {} >= {}" ,
476- index,
477- values. len( )
478- ) ;
479- }
480- }
481-
482- let data_refs = values. iter ( ) . collect ( ) ;
483- let mut mutable = MutableArrayData :: new ( data_refs, true , indices. len ( ) ) ;
484-
485- // This loop extends the mutable array by taking slices from the partial results.
486- //
487- // take_offsets keeps track of how many values have been taken from each array.
488- let mut take_offsets = vec ! [ 0 ; values. len( ) + 1 ] ;
489- let mut start_row_ix = 0 ;
490- loop {
491- let array_ix = indices[ start_row_ix] ;
492-
493- // Determine the length of the slice to take.
494- let mut end_row_ix = start_row_ix + 1 ;
495- while end_row_ix < indices. len ( ) && indices[ end_row_ix] == array_ix {
496- end_row_ix += 1 ;
497- }
498- let slice_length = end_row_ix - start_row_ix;
499-
500- // Extend mutable with either nulls or with values from the array.
501- match array_ix. index ( ) {
502- None => mutable. extend_nulls ( slice_length) ,
503- Some ( index) => {
504- let start_offset = take_offsets[ index] ;
505- let end_offset = start_offset + slice_length;
506- mutable. extend ( index, start_offset, end_offset) ;
507- take_offsets[ index] = end_offset;
508- }
509- }
510-
511- if end_row_ix == indices. len ( ) {
512- break ;
513- } else {
514- // Set the start_row_ix for the next slice.
515- start_row_ix = end_row_ix;
516- }
517- }
518-
519- Ok ( make_array ( mutable. freeze ( ) ) )
520- }
521-
522339/// An index into the partial results array that's more compact than `usize`.
523340///
524341/// `u32::MAX` is reserved as a special 'none' value. This is used instead of
@@ -561,7 +378,9 @@ impl PartialResultIndex {
561378 fn is_none ( & self ) -> bool {
562379 self . index == NONE_VALUE
563380 }
381+ }
564382
383+ impl MergeIndex for PartialResultIndex {
565384 /// Returns `Some(index)` if this value is not the 'none' placeholder, `None` otherwise.
566385 fn index ( & self ) -> Option < usize > {
567386 if self . is_none ( ) {
@@ -589,7 +408,7 @@ enum ResultState {
589408 Partial {
590409 // A `Vec` of partial results that should be merged.
591410 // `partial_result_indices` contains indexes into this vec.
592- arrays : Vec < ArrayData > ,
411+ arrays : Vec < ArrayRef > ,
593412 // Indicates per result row from which array in `partial_results` a value should be taken.
594413 indices : Vec < PartialResultIndex > ,
595414 } ,
@@ -670,7 +489,7 @@ impl ResultBuilder {
670489 } else if row_indices. len ( ) == self . row_count {
671490 self . set_complete_result ( ColumnarValue :: Array ( a) )
672491 } else {
673- self . add_partial_result ( row_indices, a. to_data ( ) )
492+ self . add_partial_result ( row_indices, a)
674493 }
675494 }
676495 ColumnarValue :: Scalar ( s) => {
@@ -679,7 +498,7 @@ impl ResultBuilder {
679498 } else {
680499 self . add_partial_result (
681500 row_indices,
682- s. to_array_of_size ( row_indices. len ( ) ) ?. to_data ( ) ,
501+ s. to_array_of_size ( row_indices. len ( ) ) ?,
683502 )
684503 }
685504 }
@@ -694,7 +513,7 @@ impl ResultBuilder {
694513 fn add_partial_result (
695514 & mut self ,
696515 row_indices : & ArrayRef ,
697- row_values : ArrayData ,
516+ row_values : ArrayRef ,
698517 ) -> Result < ( ) > {
699518 assert_or_internal_err ! (
700519 row_indices. null_count( ) == 0 ,
@@ -775,7 +594,8 @@ impl ResultBuilder {
775594 }
776595 ResultState :: Partial { arrays, indices } => {
777596 // Merge partial results into a single array.
778- Ok ( ColumnarValue :: Array ( merge_n ( & arrays, & indices) ?) )
597+ let array_refs = arrays. iter ( ) . map ( |a| a. as_ref ( ) ) . collect :: < Vec < _ > > ( ) ;
598+ Ok ( ColumnarValue :: Array ( merge_n ( & array_refs, & indices) ?) )
779599 }
780600 ResultState :: Complete ( v) => {
781601 // If we have a complete result, we can just return it.
@@ -1152,11 +972,20 @@ impl CaseBody {
1152972
1153973 let else_value = else_expr. evaluate ( & else_batch) ?;
1154974
1155- Ok ( ColumnarValue :: Array ( merge (
1156- & when_value,
1157- then_value,
1158- else_value,
1159- ) ?) )
975+ Ok ( ColumnarValue :: Array ( match ( then_value, else_value) {
976+ ( ColumnarValue :: Array ( t) , ColumnarValue :: Array ( e) ) => {
977+ merge ( & when_value, & t, & e)
978+ }
979+ ( ColumnarValue :: Scalar ( t) , ColumnarValue :: Array ( e) ) => {
980+ merge ( & when_value, & t. to_scalar ( ) ?, & e)
981+ }
982+ ( ColumnarValue :: Array ( t) , ColumnarValue :: Scalar ( e) ) => {
983+ merge ( & when_value, & t, & e. to_scalar ( ) ?)
984+ }
985+ ( ColumnarValue :: Scalar ( t) , ColumnarValue :: Scalar ( e) ) => {
986+ merge ( & when_value, & t. to_scalar ( ) ?, & e. to_scalar ( ) ?)
987+ }
988+ } ?) )
1160989 }
1161990}
1162991
@@ -2567,57 +2396,6 @@ mod tests {
25672396 Ok ( ( ) )
25682397 }
25692398
2570- #[ test]
2571- fn test_merge_n ( ) {
2572- let a1 = StringArray :: from ( vec ! [ Some ( "A" ) ] ) . to_data ( ) ;
2573- let a2 = StringArray :: from ( vec ! [ Some ( "B" ) ] ) . to_data ( ) ;
2574- let a3 = StringArray :: from ( vec ! [ Some ( "C" ) , Some ( "D" ) ] ) . to_data ( ) ;
2575-
2576- let indices = vec ! [
2577- PartialResultIndex :: none( ) ,
2578- PartialResultIndex :: try_new( 1 ) . unwrap( ) ,
2579- PartialResultIndex :: try_new( 0 ) . unwrap( ) ,
2580- PartialResultIndex :: none( ) ,
2581- PartialResultIndex :: try_new( 2 ) . unwrap( ) ,
2582- PartialResultIndex :: try_new( 2 ) . unwrap( ) ,
2583- ] ;
2584-
2585- let merged = merge_n ( & [ a1, a2, a3] , & indices) . unwrap ( ) ;
2586- let merged = merged. as_string :: < i32 > ( ) ;
2587-
2588- assert_eq ! ( merged. len( ) , indices. len( ) ) ;
2589- assert ! ( !merged. is_valid( 0 ) ) ;
2590- assert ! ( merged. is_valid( 1 ) ) ;
2591- assert_eq ! ( merged. value( 1 ) , "B" ) ;
2592- assert ! ( merged. is_valid( 2 ) ) ;
2593- assert_eq ! ( merged. value( 2 ) , "A" ) ;
2594- assert ! ( !merged. is_valid( 3 ) ) ;
2595- assert ! ( merged. is_valid( 4 ) ) ;
2596- assert_eq ! ( merged. value( 4 ) , "C" ) ;
2597- assert ! ( merged. is_valid( 5 ) ) ;
2598- assert_eq ! ( merged. value( 5 ) , "D" ) ;
2599- }
2600-
2601- #[ test]
2602- fn test_merge ( ) {
2603- let a1 = Arc :: new ( StringArray :: from ( vec ! [ Some ( "A" ) , Some ( "C" ) ] ) ) ;
2604- let a2 = Arc :: new ( StringArray :: from ( vec ! [ Some ( "B" ) ] ) ) ;
2605-
2606- let mask = BooleanArray :: from ( vec ! [ true , false , true ] ) ;
2607-
2608- let merged =
2609- merge ( & mask, ColumnarValue :: Array ( a1) , ColumnarValue :: Array ( a2) ) . unwrap ( ) ;
2610- let merged = merged. as_string :: < i32 > ( ) ;
2611-
2612- assert_eq ! ( merged. len( ) , mask. len( ) ) ;
2613- assert ! ( merged. is_valid( 0 ) ) ;
2614- assert_eq ! ( merged. value( 0 ) , "A" ) ;
2615- assert ! ( merged. is_valid( 1 ) ) ;
2616- assert_eq ! ( merged. value( 1 ) , "B" ) ;
2617- assert ! ( merged. is_valid( 2 ) ) ;
2618- assert_eq ! ( merged. value( 2 ) , "C" ) ;
2619- }
2620-
26212399 fn when_then_else (
26222400 when : & Arc < dyn PhysicalExpr > ,
26232401 then : & Arc < dyn PhysicalExpr > ,
0 commit comments