Skip to content

Commit 1c397a9

Browse files
authored
Add specialized coalesce path for PrimitiveArrays (#7772)
# Which issue does this PR close? - Closes #7763 # Rationale for this change I want the `coalesce` operation to be as fast as possible # What changes are included in this PR? Add specialized `InProgressPrimitiveArray` that avoids keeping a second copy of the primitive arrays that are concat'ed together I don't expect this will make a huge performance difference -- but it is needed to implement #7762 which I do expect to make a difference Update: it turns out this seems to improve performance quite a bit (25%) for highly selective kernels. I speculate this is due to not having to keep around many small allocations to hold intermediate `ArrayRef`s # Are these changes tested? There are already existing tests for u32s which cover this code path. I also added a test for StringArray which ensures the generic in progress array is still covered. I also checked coverage using ```shell cargo llvm-cov --html -p arrow-select -p arrow-data ``` # Are there any user-facing changes? No, this is an internal optimization only
1 parent 0de463e commit 1c397a9

File tree

2 files changed

+168
-15
lines changed

2 files changed

+168
-15
lines changed

arrow-select/src/coalesce.rs

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
//! [`take`]: crate::take::take
2323
use crate::filter::filter_record_batch;
2424
use arrow_array::types::{BinaryViewType, StringViewType};
25-
use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
25+
use arrow_array::{downcast_primitive, Array, ArrayRef, BooleanArray, RecordBatch};
2626
use arrow_schema::{ArrowError, DataType, SchemaRef};
2727
use std::collections::VecDeque;
2828
use std::sync::Arc;
@@ -31,9 +31,11 @@ use std::sync::Arc;
3131

3232
mod byte_view;
3333
mod generic;
34+
mod primitive;
3435

3536
use byte_view::InProgressByteViewArray;
3637
use generic::GenericInProgressArray;
38+
use primitive::InProgressPrimitiveArray;
3739

3840
/// Concatenate multiple [`RecordBatch`]es
3941
///
@@ -322,7 +324,15 @@ impl BatchCoalescer {
322324

323325
/// Return a new `InProgressArray` for the given data type
324326
fn create_in_progress_array(data_type: &DataType, batch_size: usize) -> Box<dyn InProgressArray> {
325-
match data_type {
327+
macro_rules! instantiate_primitive {
328+
($t:ty) => {
329+
Box::new(InProgressPrimitiveArray::<$t>::new(batch_size))
330+
};
331+
}
332+
333+
downcast_primitive! {
334+
// Instantiate InProgressPrimitiveArray for each primitive type
335+
data_type => (instantiate_primitive),
326336
DataType::Utf8View => Box::new(InProgressByteViewArray::<StringViewType>::new(batch_size)),
327337
DataType::BinaryView => {
328338
Box::new(InProgressByteViewArray::<BinaryViewType>::new(batch_size))
@@ -364,7 +374,9 @@ mod tests {
364374
use crate::concat::concat_batches;
365375
use arrow_array::builder::StringViewBuilder;
366376
use arrow_array::cast::AsArray;
367-
use arrow_array::{BinaryViewArray, RecordBatchOptions, StringViewArray, UInt32Array};
377+
use arrow_array::{
378+
BinaryViewArray, RecordBatchOptions, StringArray, StringViewArray, UInt32Array,
379+
};
368380
use arrow_schema::{DataType, Field, Schema};
369381
use std::ops::Range;
370382

@@ -456,6 +468,27 @@ mod tests {
456468
.run();
457469
}
458470

471+
#[test]
472+
fn test_coalesce_non_null() {
473+
Test::new()
474+
// 4040 rows of uint32
475+
.with_batch(uint32_batch_non_null(0..3000))
476+
.with_batch(uint32_batch_non_null(0..1040))
477+
.with_batch_size(1024)
478+
.with_expected_output_sizes(vec![1024, 1024, 1024, 968])
479+
.run();
480+
}
481+
#[test]
482+
fn test_utf8_split() {
483+
Test::new()
484+
// 4040 rows of utf8 strings in total, split into batches of 1024
485+
.with_batch(utf8_batch(0..3000))
486+
.with_batch(utf8_batch(0..1040))
487+
.with_batch_size(1024)
488+
.with_expected_output_sizes(vec![1024, 1024, 1024, 968])
489+
.run();
490+
}
491+
459492
#[test]
460493
fn test_string_view_no_views() {
461494
let output_batches = Test::new()
@@ -941,15 +974,37 @@ mod tests {
941974
}
942975
}
943976

944-
/// Return a RecordBatch with a UInt32Array with the specified range
977+
/// Return a RecordBatch with a UInt32Array with the specified range and
978+
/// every third value is null.
945979
fn uint32_batch(range: Range<u32>) -> RecordBatch {
980+
let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, true)]));
981+
982+
let array = UInt32Array::from_iter(range.map(|i| if i % 3 == 0 { None } else { Some(i) }));
983+
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
984+
}
985+
986+
/// Return a RecordBatch with a UInt32Array over the specified range, with no nulls
987+
fn uint32_batch_non_null(range: Range<u32>) -> RecordBatch {
946988
let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));
947989

948-
RecordBatch::try_new(
949-
Arc::clone(&schema),
950-
vec![Arc::new(UInt32Array::from_iter_values(range))],
951-
)
952-
.unwrap()
990+
let array = UInt32Array::from_iter_values(range);
991+
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
992+
}
993+
994+
/// Return a RecordBatch with a StringArray with values `value0`, `value1`, ...
995+
/// and every third value is `None`.
996+
fn utf8_batch(range: Range<u32>) -> RecordBatch {
997+
let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Utf8, true)]));
998+
999+
let array = StringArray::from_iter(range.map(|i| {
1000+
if i % 3 == 0 {
1001+
None
1002+
} else {
1003+
Some(format!("value{}", i))
1004+
}
1005+
}));
1006+
1007+
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
9531008
}
9541009

9551010
/// Return a RecordBatch with a StringViewArray with (only) the specified values
@@ -960,14 +1015,11 @@ mod tests {
9601015
false,
9611016
)]));
9621017

963-
RecordBatch::try_new(
964-
Arc::clone(&schema),
965-
vec![Arc::new(StringViewArray::from_iter(values))],
966-
)
967-
.unwrap()
1018+
let array = StringViewArray::from_iter(values);
1019+
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
9681020
}
9691021

970-
/// Return a RecordBatch with a StringViewArray with num_rows by repating
1022+
/// Return a RecordBatch with a StringViewArray with num_rows by repeating
9711023
/// values over and over.
9721024
fn stringview_batch_repeated<'a>(
9731025
num_rows: usize,
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::coalesce::InProgressArray;
19+
use arrow_array::cast::AsArray;
20+
use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray};
21+
use arrow_buffer::{NullBufferBuilder, ScalarBuffer};
22+
use arrow_schema::ArrowError;
23+
use std::fmt::Debug;
24+
use std::sync::Arc;
25+
26+
/// InProgressArray for [`PrimitiveArray`]
27+
#[derive(Debug)]
28+
pub(crate) struct InProgressPrimitiveArray<T: ArrowPrimitiveType> {
29+
/// The current source, if any
30+
source: Option<ArrayRef>,
31+
/// the target batch size (and thus the capacity reserved for output values and nulls)
32+
batch_size: usize,
33+
/// In progress nulls
34+
nulls: NullBufferBuilder,
35+
/// The currently in progress array
36+
current: Vec<T::Native>,
37+
}
38+
39+
impl<T: ArrowPrimitiveType> InProgressPrimitiveArray<T> {
40+
/// Create a new `InProgressPrimitiveArray`
41+
pub(crate) fn new(batch_size: usize) -> Self {
42+
Self {
43+
batch_size,
44+
source: None,
45+
nulls: NullBufferBuilder::new(batch_size),
46+
current: vec![],
47+
}
48+
}
49+
50+
/// Allocate space for output values if necessary.
51+
///
52+
/// This is done on write (when we know it is necessary) rather than
53+
/// eagerly to avoid allocations that are not used.
54+
fn ensure_capacity(&mut self) {
55+
self.current.reserve(self.batch_size);
56+
}
57+
}
58+
59+
impl<T: ArrowPrimitiveType + Debug> InProgressArray for InProgressPrimitiveArray<T> {
60+
fn set_source(&mut self, source: Option<ArrayRef>) {
61+
self.source = source;
62+
}
63+
64+
fn copy_rows(&mut self, offset: usize, len: usize) -> Result<(), ArrowError> {
65+
self.ensure_capacity();
66+
67+
let s = self
68+
.source
69+
.as_ref()
70+
.ok_or_else(|| {
71+
ArrowError::InvalidArgumentError(
72+
"Internal Error: InProgressPrimitiveArray: source not set".to_string(),
73+
)
74+
})?
75+
.as_primitive::<T>();
76+
77+
// add nulls if necessary
78+
if let Some(nulls) = s.nulls().as_ref() {
79+
let nulls = nulls.slice(offset, len);
80+
self.nulls.append_buffer(&nulls);
81+
} else {
82+
self.nulls.append_n_non_nulls(len);
83+
};
84+
85+
// Copy the values
86+
self.current
87+
.extend_from_slice(&s.values()[offset..offset + len]);
88+
89+
Ok(())
90+
}
91+
92+
fn finish(&mut self) -> Result<ArrayRef, ArrowError> {
93+
// take and reset the current values and nulls
94+
let values = std::mem::take(&mut self.current);
95+
let nulls = self.nulls.finish();
96+
self.nulls = NullBufferBuilder::new(self.batch_size);
97+
98+
let array = PrimitiveArray::<T>::try_new(ScalarBuffer::from(values), nulls)?;
99+
Ok(Arc::new(array))
100+
}
101+
}

0 commit comments

Comments
 (0)