|
21 | 21 | use std::sync::Arc;
|
22 | 22 |
|
23 | 23 | use crate::array::*;
|
| 24 | +use crate::compute::kernels::concat::concat; |
24 | 25 | use crate::datatypes::*;
|
25 | 26 | use crate::error::{ArrowError, Result};
|
26 | 27 |
|
@@ -353,6 +354,35 @@ impl RecordBatch {
|
353 | 354 | let schema = Arc::new(Schema::new(fields));
|
354 | 355 | RecordBatch::try_new(schema, columns)
|
355 | 356 | }
|
| 357 | + |
| 358 | + /// Concatenates `batches` together into a single record batch. |
| 359 | + pub fn concat(schema: &SchemaRef, batches: &[Self]) -> Result<Self> { |
| 360 | + if batches.is_empty() { |
| 361 | + return Ok(RecordBatch::new_empty(schema.clone())); |
| 362 | + } |
| 363 | + if let Some((i, _)) = batches |
| 364 | + .iter() |
| 365 | + .enumerate() |
| 366 | + .find(|&(_, batch)| batch.schema() != *schema) |
| 367 | + { |
| 368 | + return Err(ArrowError::InvalidArgumentError(format!( |
| 369 | + "batches[{}] schema is different with argument schema.", |
| 370 | + i |
| 371 | + ))); |
| 372 | + } |
| 373 | + let field_num = schema.fields().len(); |
| 374 | + let mut arrays = Vec::with_capacity(field_num); |
| 375 | + for i in 0..field_num { |
| 376 | + let array = concat( |
| 377 | + &batches |
| 378 | + .iter() |
| 379 | + .map(|batch| batch.column(i).as_ref()) |
| 380 | + .collect::<Vec<_>>(), |
| 381 | + )?; |
| 382 | + arrays.push(array); |
| 383 | + } |
| 384 | + Self::try_new(schema.clone(), arrays) |
| 385 | + } |
356 | 386 | }
|
357 | 387 |
|
358 | 388 | /// Options that control the behaviour used when creating a [`RecordBatch`].
|
@@ -639,4 +669,76 @@ mod tests {
|
639 | 669 | assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
|
640 | 670 | assert_eq!(batch.column(1).as_ref(), int.as_ref());
|
641 | 671 | }
|
| 672 | + |
#[test]
fn concat_record_batches() {
    // Two batches sharing one schema: concatenation must keep the schema,
    // add up the row counts, and preserve the row order within each column.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));
    let batch1 = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["a", "b"])),
        ],
    )
    .unwrap();
    let batch2 = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![3, 4])),
            Arc::new(StringArray::from(vec!["c", "d"])),
        ],
    )
    .unwrap();
    let new_batch = RecordBatch::concat(&schema, &[batch1, batch2]).unwrap();
    assert_eq!(new_batch.schema().as_ref(), schema.as_ref());
    assert_eq!(2, new_batch.num_columns());
    assert_eq!(4, new_batch.num_rows());

    // Counts alone would not catch reordered, duplicated or corrupted rows,
    // so compare the concatenated values as well.
    let expected_a = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
    let expected_b = Arc::new(StringArray::from(vec!["a", "b", "c", "d"]));
    assert_eq!(new_batch.column(0).as_ref(), expected_a.as_ref());
    assert_eq!(new_batch.column(1).as_ref(), expected_b.as_ref());
}
| 700 | + |
#[test]
fn concat_empty_record_batch() {
    // With no input batches the result is an empty batch carrying the
    // schema that was passed in.
    let fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ];
    let schema = Arc::new(Schema::new(fields));

    let empty = RecordBatch::concat(&schema, &[]).unwrap();

    assert_eq!(empty.schema().as_ref(), schema.as_ref());
    assert_eq!(0, empty.num_rows());
}
| 711 | + |
#[test]
fn concat_record_batches_of_different_schemas() {
    // Same column types but different column names, so the second batch's
    // schema is not equal to the one passed to `concat`.
    let expected_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));
    let other_schema = Arc::new(Schema::new(vec![
        Field::new("c", DataType::Int32, false),
        Field::new("d", DataType::Utf8, false),
    ]));

    let matching = RecordBatch::try_new(
        expected_schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["a", "b"])),
        ],
    )
    .unwrap();
    let mismatched = RecordBatch::try_new(
        other_schema,
        vec![
            Arc::new(Int32Array::from(vec![3, 4])),
            Arc::new(StringArray::from(vec!["c", "d"])),
        ],
    )
    .unwrap();

    // The error must point at index 1, the first batch with a differing schema.
    let message = RecordBatch::concat(&expected_schema, &[matching, mismatched])
        .unwrap_err()
        .to_string();
    assert_eq!(
        message,
        "Invalid argument error: batches[1] schema is different with argument schema.",
    );
}
642 | 744 | }
|
0 commit comments