Skip to content

Commit 03904e1

Browse files
authored
Replace custom merge operator with arrow-rs implementation (#19424)
## Which issue does this PR close?

- Closes #19423.

## Rationale for this change

The functions `arrow_select::merge::merge` and `arrow_select::merge::merge_n` were first implemented for DataFusion in `case.rs`. They have since been generalised and moved to `arrow-rs`. Now that an `arrow-rs` release containing these functions is available, DataFusion should make use of them.

## What changes are included in this PR?

- Remove `merge` and `merge_n` from `case.rs`, along with the unit tests for those functions
- Adapt the calling code to use their equivalents from `arrow-rs`

## Are these changes tested?

Covered by existing unit tests and SLTs.

## Are there any user-facing changes?

No.
1 parent ef2c1a3 commit 03904e1

File tree

1 file changed

+24
-246
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+24
-246
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 24 additions & 246 deletions
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,7 @@ use crate::expressions::{lit, try_cast};
2323
use arrow::array::*;
2424
use arrow::compute::kernels::zip::zip;
2525
use arrow::compute::{
26-
FilterBuilder, FilterPredicate, SlicesIterator, is_not_null, not, nullif,
27-
prep_null_mask_filter,
26+
FilterBuilder, FilterPredicate, is_not_null, not, nullif, prep_null_mask_filter,
2827
};
2928
use arrow::datatypes::{DataType, Schema, UInt32Type, UnionMode};
3029
use arrow::error::ArrowError;
@@ -39,6 +38,7 @@ use std::hash::Hash;
3938
use std::{any::Any, sync::Arc};
4039

4140
use crate::expressions::case::literal_lookup_table::LiteralLookupTable;
41+
use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n};
4242
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
4343
use datafusion_physical_expr_common::datum::compare_with_eq;
4444
use itertools::Itertools;
@@ -336,189 +336,6 @@ fn filter_array(
336336
filter.filter(array)
337337
}
338338

339-
fn merge(
340-
mask: &BooleanArray,
341-
truthy: ColumnarValue,
342-
falsy: ColumnarValue,
343-
) -> std::result::Result<ArrayRef, ArrowError> {
344-
let (truthy, truthy_is_scalar) = match truthy {
345-
ColumnarValue::Array(a) => (a, false),
346-
ColumnarValue::Scalar(s) => (s.to_array()?, true),
347-
};
348-
let (falsy, falsy_is_scalar) = match falsy {
349-
ColumnarValue::Array(a) => (a, false),
350-
ColumnarValue::Scalar(s) => (s.to_array()?, true),
351-
};
352-
353-
if truthy_is_scalar && falsy_is_scalar {
354-
return zip(mask, &Scalar::new(truthy), &Scalar::new(falsy));
355-
}
356-
357-
let falsy = falsy.to_data();
358-
let truthy = truthy.to_data();
359-
360-
let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, truthy.len());
361-
362-
// the SlicesIterator slices only the true values. So the gaps left by this iterator we need to
363-
// fill with falsy values
364-
365-
// keep track of how much is filled
366-
let mut filled = 0;
367-
let mut falsy_offset = 0;
368-
let mut truthy_offset = 0;
369-
370-
SlicesIterator::new(mask).for_each(|(start, end)| {
371-
// the gap needs to be filled with falsy values
372-
if start > filled {
373-
if falsy_is_scalar {
374-
for _ in filled..start {
375-
// Copy the first item from the 'falsy' array into the output buffer.
376-
mutable.extend(1, 0, 1);
377-
}
378-
} else {
379-
let falsy_length = start - filled;
380-
let falsy_end = falsy_offset + falsy_length;
381-
mutable.extend(1, falsy_offset, falsy_end);
382-
falsy_offset = falsy_end;
383-
}
384-
}
385-
// fill with truthy values
386-
if truthy_is_scalar {
387-
for _ in start..end {
388-
// Copy the first item from the 'truthy' array into the output buffer.
389-
mutable.extend(0, 0, 1);
390-
}
391-
} else {
392-
let truthy_length = end - start;
393-
let truthy_end = truthy_offset + truthy_length;
394-
mutable.extend(0, truthy_offset, truthy_end);
395-
truthy_offset = truthy_end;
396-
}
397-
filled = end;
398-
});
399-
// the remaining part is falsy
400-
if filled < mask.len() {
401-
if falsy_is_scalar {
402-
for _ in filled..mask.len() {
403-
// Copy the first item from the 'falsy' array into the output buffer.
404-
mutable.extend(1, 0, 1);
405-
}
406-
} else {
407-
let falsy_length = mask.len() - filled;
408-
let falsy_end = falsy_offset + falsy_length;
409-
mutable.extend(1, falsy_offset, falsy_end);
410-
}
411-
}
412-
413-
let data = mutable.freeze();
414-
Ok(make_array(data))
415-
}
416-
417-
/// Merges elements by index from a list of [`ArrayData`], creating a new [`ArrayRef`] from
418-
/// those values.
419-
///
420-
/// Each element in `indices` is the index of an array in `values`. The `indices` array is processed
421-
/// sequentially. The first occurrence of index value `n` will be mapped to the first
422-
/// value of the array at index `n`. The second occurrence to the second value, and so on.
423-
/// An index value where `PartialResultIndex::is_none` is `true` is used to indicate null values.
424-
///
425-
/// # Implementation notes
426-
///
427-
/// This algorithm is similar in nature to both `zip` and `interleave`, but there are some important
428-
/// differences.
429-
///
430-
/// In contrast to `zip`, this function supports multiple input arrays. Instead of a boolean
431-
/// selection vector, an index array is used to take values from the input arrays, and a special marker
432-
/// value is used to indicate null values.
433-
///
434-
/// In contrast to `interleave`, this function does not use pairs of indices. The values in
435-
/// `indices` serve the same purpose as the first value in the pairs passed to `interleave`.
436-
/// The index in the array is implicit and is derived from the number of times a particular array
437-
/// index occurs.
438-
/// The more constrained indexing mechanism used by this algorithm makes it easier to copy values
439-
/// in contiguous slices. In the example below, the two subsequent elements from array `2` can be
440-
/// copied in a single operation from the source array instead of copying them one by one.
441-
/// Long spans of null values are also especially cheap because they do not need to be represented
442-
/// in an input array.
443-
///
444-
/// # Safety
445-
///
446-
/// This function does not check that the number of occurrences of any particular array index matches
447-
/// the length of the corresponding input array. If an array contains more values than required, the
448-
/// spurious values will be ignored. If an array contains fewer values than necessary, this function
449-
/// will panic.
450-
///
451-
/// # Example
452-
///
453-
/// ```text
454-
/// ┌───────────┐ ┌─────────┐ ┌─────────┐
455-
/// │┌─────────┐│ │ None │ │ NULL │
456-
/// ││ A ││ ├─────────┤ ├─────────┤
457-
/// │└─────────┘│ │ 1 │ │ B │
458-
/// │┌─────────┐│ ├─────────┤ ├─────────┤
459-
/// ││ B ││ │ 0 │ merge(values, indices) │ A │
460-
/// │└─────────┘│ ├─────────┤ ─────────────────────────▶ ├─────────┤
461-
/// │┌─────────┐│ │ None │ │ NULL │
462-
/// ││ C ││ ├─────────┤ ├─────────┤
463-
/// │├─────────┤│ │ 2 │ │ C │
464-
/// ││ D ││ ├─────────┤ ├─────────┤
465-
/// │└─────────┘│ │ 2 │ │ D │
466-
/// └───────────┘ └─────────┘ └─────────┘
467-
/// values indices result
468-
/// ```
469-
fn merge_n(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result<ArrayRef> {
470-
#[cfg(debug_assertions)]
471-
for ix in indices {
472-
if let Some(index) = ix.index() {
473-
assert!(
474-
index < values.len(),
475-
"Index out of bounds: {} >= {}",
476-
index,
477-
values.len()
478-
);
479-
}
480-
}
481-
482-
let data_refs = values.iter().collect();
483-
let mut mutable = MutableArrayData::new(data_refs, true, indices.len());
484-
485-
// This loop extends the mutable array by taking slices from the partial results.
486-
//
487-
// take_offsets keeps track of how many values have been taken from each array.
488-
let mut take_offsets = vec![0; values.len() + 1];
489-
let mut start_row_ix = 0;
490-
loop {
491-
let array_ix = indices[start_row_ix];
492-
493-
// Determine the length of the slice to take.
494-
let mut end_row_ix = start_row_ix + 1;
495-
while end_row_ix < indices.len() && indices[end_row_ix] == array_ix {
496-
end_row_ix += 1;
497-
}
498-
let slice_length = end_row_ix - start_row_ix;
499-
500-
// Extend mutable with either nulls or with values from the array.
501-
match array_ix.index() {
502-
None => mutable.extend_nulls(slice_length),
503-
Some(index) => {
504-
let start_offset = take_offsets[index];
505-
let end_offset = start_offset + slice_length;
506-
mutable.extend(index, start_offset, end_offset);
507-
take_offsets[index] = end_offset;
508-
}
509-
}
510-
511-
if end_row_ix == indices.len() {
512-
break;
513-
} else {
514-
// Set the start_row_ix for the next slice.
515-
start_row_ix = end_row_ix;
516-
}
517-
}
518-
519-
Ok(make_array(mutable.freeze()))
520-
}
521-
522339
/// An index into the partial results array that's more compact than `usize`.
523340
///
524341
/// `u32::MAX` is reserved as a special 'none' value. This is used instead of
@@ -561,7 +378,9 @@ impl PartialResultIndex {
561378
fn is_none(&self) -> bool {
562379
self.index == NONE_VALUE
563380
}
381+
}
564382

383+
impl MergeIndex for PartialResultIndex {
565384
/// Returns `Some(index)` if this value is not the 'none' placeholder, `None` otherwise.
566385
fn index(&self) -> Option<usize> {
567386
if self.is_none() {
@@ -589,7 +408,7 @@ enum ResultState {
589408
Partial {
590409
// A `Vec` of partial results that should be merged.
591410
// `indices` contains indexes into this vec.
592-
arrays: Vec<ArrayData>,
411+
arrays: Vec<ArrayRef>,
593412
// Indicates per result row from which array in `arrays` a value should be taken.
594413
indices: Vec<PartialResultIndex>,
595414
},
@@ -670,7 +489,7 @@ impl ResultBuilder {
670489
} else if row_indices.len() == self.row_count {
671490
self.set_complete_result(ColumnarValue::Array(a))
672491
} else {
673-
self.add_partial_result(row_indices, a.to_data())
492+
self.add_partial_result(row_indices, a)
674493
}
675494
}
676495
ColumnarValue::Scalar(s) => {
@@ -679,7 +498,7 @@ impl ResultBuilder {
679498
} else {
680499
self.add_partial_result(
681500
row_indices,
682-
s.to_array_of_size(row_indices.len())?.to_data(),
501+
s.to_array_of_size(row_indices.len())?,
683502
)
684503
}
685504
}
@@ -694,7 +513,7 @@ impl ResultBuilder {
694513
fn add_partial_result(
695514
&mut self,
696515
row_indices: &ArrayRef,
697-
row_values: ArrayData,
516+
row_values: ArrayRef,
698517
) -> Result<()> {
699518
assert_or_internal_err!(
700519
row_indices.null_count() == 0,
@@ -775,7 +594,8 @@ impl ResultBuilder {
775594
}
776595
ResultState::Partial { arrays, indices } => {
777596
// Merge partial results into a single array.
778-
Ok(ColumnarValue::Array(merge_n(&arrays, &indices)?))
597+
let array_refs = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
598+
Ok(ColumnarValue::Array(merge_n(&array_refs, &indices)?))
779599
}
780600
ResultState::Complete(v) => {
781601
// If we have a complete result, we can just return it.
@@ -1152,11 +972,20 @@ impl CaseBody {
1152972

1153973
let else_value = else_expr.evaluate(&else_batch)?;
1154974

1155-
Ok(ColumnarValue::Array(merge(
1156-
&when_value,
1157-
then_value,
1158-
else_value,
1159-
)?))
975+
Ok(ColumnarValue::Array(match (then_value, else_value) {
976+
(ColumnarValue::Array(t), ColumnarValue::Array(e)) => {
977+
merge(&when_value, &t, &e)
978+
}
979+
(ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => {
980+
merge(&when_value, &t.to_scalar()?, &e)
981+
}
982+
(ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => {
983+
merge(&when_value, &t, &e.to_scalar()?)
984+
}
985+
(ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => {
986+
merge(&when_value, &t.to_scalar()?, &e.to_scalar()?)
987+
}
988+
}?))
1160989
}
1161990
}
1162991

@@ -2567,57 +2396,6 @@ mod tests {
25672396
Ok(())
25682397
}
25692398

2570-
#[test]
2571-
fn test_merge_n() {
2572-
let a1 = StringArray::from(vec![Some("A")]).to_data();
2573-
let a2 = StringArray::from(vec![Some("B")]).to_data();
2574-
let a3 = StringArray::from(vec![Some("C"), Some("D")]).to_data();
2575-
2576-
let indices = vec![
2577-
PartialResultIndex::none(),
2578-
PartialResultIndex::try_new(1).unwrap(),
2579-
PartialResultIndex::try_new(0).unwrap(),
2580-
PartialResultIndex::none(),
2581-
PartialResultIndex::try_new(2).unwrap(),
2582-
PartialResultIndex::try_new(2).unwrap(),
2583-
];
2584-
2585-
let merged = merge_n(&[a1, a2, a3], &indices).unwrap();
2586-
let merged = merged.as_string::<i32>();
2587-
2588-
assert_eq!(merged.len(), indices.len());
2589-
assert!(!merged.is_valid(0));
2590-
assert!(merged.is_valid(1));
2591-
assert_eq!(merged.value(1), "B");
2592-
assert!(merged.is_valid(2));
2593-
assert_eq!(merged.value(2), "A");
2594-
assert!(!merged.is_valid(3));
2595-
assert!(merged.is_valid(4));
2596-
assert_eq!(merged.value(4), "C");
2597-
assert!(merged.is_valid(5));
2598-
assert_eq!(merged.value(5), "D");
2599-
}
2600-
2601-
#[test]
2602-
fn test_merge() {
2603-
let a1 = Arc::new(StringArray::from(vec![Some("A"), Some("C")]));
2604-
let a2 = Arc::new(StringArray::from(vec![Some("B")]));
2605-
2606-
let mask = BooleanArray::from(vec![true, false, true]);
2607-
2608-
let merged =
2609-
merge(&mask, ColumnarValue::Array(a1), ColumnarValue::Array(a2)).unwrap();
2610-
let merged = merged.as_string::<i32>();
2611-
2612-
assert_eq!(merged.len(), mask.len());
2613-
assert!(merged.is_valid(0));
2614-
assert_eq!(merged.value(0), "A");
2615-
assert!(merged.is_valid(1));
2616-
assert_eq!(merged.value(1), "B");
2617-
assert!(merged.is_valid(2));
2618-
assert_eq!(merged.value(2), "C");
2619-
}
2620-
26212399
fn when_then_else(
26222400
when: &Arc<dyn PhysicalExpr>,
26232401
then: &Arc<dyn PhysicalExpr>,

0 commit comments

Comments
 (0)