Skip to content

Commit b05edf4

Browse files
authored
Implement RecordBatch::concat (apache#537)
* Implements `RecordBatch::concat`. * Updates according to code review.
1 parent fde79a2 commit b05edf4

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed

arrow/src/record_batch.rs

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
use std::sync::Arc;
2222

2323
use crate::array::*;
24+
use crate::compute::kernels::concat::concat;
2425
use crate::datatypes::*;
2526
use crate::error::{ArrowError, Result};
2627

@@ -353,6 +354,35 @@ impl RecordBatch {
353354
let schema = Arc::new(Schema::new(fields));
354355
RecordBatch::try_new(schema, columns)
355356
}
357+
358+
/// Concatenates `batches` together into a single record batch.
359+
pub fn concat(schema: &SchemaRef, batches: &[Self]) -> Result<Self> {
360+
if batches.is_empty() {
361+
return Ok(RecordBatch::new_empty(schema.clone()));
362+
}
363+
if let Some((i, _)) = batches
364+
.iter()
365+
.enumerate()
366+
.find(|&(_, batch)| batch.schema() != *schema)
367+
{
368+
return Err(ArrowError::InvalidArgumentError(format!(
369+
"batches[{}] schema is different with argument schema.",
370+
i
371+
)));
372+
}
373+
let field_num = schema.fields().len();
374+
let mut arrays = Vec::with_capacity(field_num);
375+
for i in 0..field_num {
376+
let array = concat(
377+
&batches
378+
.iter()
379+
.map(|batch| batch.column(i).as_ref())
380+
.collect::<Vec<_>>(),
381+
)?;
382+
arrays.push(array);
383+
}
384+
Self::try_new(schema.clone(), arrays)
385+
}
356386
}
357387

358388
/// Options that control the behaviour used when creating a [`RecordBatch`].
@@ -639,4 +669,76 @@ mod tests {
639669
assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
640670
assert_eq!(batch.column(1).as_ref(), int.as_ref());
641671
}
672+
673+
#[test]
fn concat_record_batches() {
    // Two-field schema shared by both input batches.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));

    let first = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["a", "b"])),
        ],
    )
    .unwrap();
    let second = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![3, 4])),
            Arc::new(StringArray::from(vec!["c", "d"])),
        ],
    )
    .unwrap();

    // Concatenation keeps the schema, preserves the column count, and
    // sums the row counts of the inputs.
    let combined = RecordBatch::concat(&schema, &[first, second]).unwrap();
    assert_eq!(combined.schema().as_ref(), schema.as_ref());
    assert_eq!(combined.num_columns(), 2);
    assert_eq!(combined.num_rows(), 4);
}
700+
701+
#[test]
fn concat_empty_record_batch() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));

    // Concatenating zero batches must succeed and yield an empty batch
    // carrying the requested schema.
    let result = RecordBatch::concat(&schema, &[]).unwrap();
    assert_eq!(result.schema().as_ref(), schema.as_ref());
    assert_eq!(result.num_rows(), 0);
}
711+
712+
#[test]
fn concat_record_batches_of_different_schemas() {
    // Two schemas whose field names differ, so they compare unequal.
    let expected_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));
    let other_schema = Arc::new(Schema::new(vec![
        Field::new("c", DataType::Int32, false),
        Field::new("d", DataType::Utf8, false),
    ]));

    let matching = RecordBatch::try_new(
        expected_schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["a", "b"])),
        ],
    )
    .unwrap();
    let mismatched = RecordBatch::try_new(
        other_schema,
        vec![
            Arc::new(Int32Array::from(vec![3, 4])),
            Arc::new(StringArray::from(vec!["c", "d"])),
        ],
    )
    .unwrap();

    // The second batch does not match `expected_schema`, so concat must
    // fail and name the offending batch index in the error message.
    let error =
        RecordBatch::concat(&expected_schema, &[matching, mismatched]).unwrap_err();
    assert_eq!(
        error.to_string(),
        "Invalid argument error: batches[1] schema is different with argument schema.",
    );
}
642744
}

0 commit comments

Comments
 (0)