From 7815ac09c1df77e688c10ba008dd1eea6b995d91 Mon Sep 17 00:00:00 2001 From: hzuo Date: Mon, 25 Mar 2024 19:50:12 -0400 Subject: [PATCH 01/12] fix --- arrow-buffer/src/buffer/scalar.rs | 8 +- arrow-flight/src/decode.rs | 1 + arrow-flight/src/sql/client.rs | 1 + arrow-flight/src/utils.rs | 1 + .../integration_test.rs | 1 + .../integration_test.rs | 2 + arrow-ipc/src/reader.rs | 166 ++++++++++--- arrow-ipc/src/reader/stream.rs | 2 + arrow-ipc/src/writer.rs | 220 +++++++++++++++--- 9 files changed, 343 insertions(+), 59 deletions(-) diff --git a/arrow-buffer/src/buffer/scalar.rs b/arrow-buffer/src/buffer/scalar.rs index 2019cc79830d..4c3bd44fd1ad 100644 --- a/arrow-buffer/src/buffer/scalar.rs +++ b/arrow-buffer/src/buffer/scalar.rs @@ -63,9 +63,15 @@ impl ScalarBuffer { /// This method will panic if /// /// * `offset` or `len` would result in overflow - /// * `buffer` is not aligned to a multiple of `std::mem::size_of::` + /// * `buffer` is not aligned to a multiple of `std::mem::align_of::` /// * `bytes` is not large enough for the requested slice pub fn new(buffer: Buffer, offset: usize, len: usize) -> Self { + assert_eq!( + buffer.as_ptr().align_offset(std::mem::align_of::()), + 0, + "buffer is not aligned" + ); + let size = std::mem::size_of::(); let byte_offset = offset.checked_mul(size).expect("offset overflow"); let byte_len = len.checked_mul(size).expect("length overflow"); diff --git a/arrow-flight/src/decode.rs b/arrow-flight/src/decode.rs index afbf033eb06d..d2f155070551 100644 --- a/arrow-flight/src/decode.rs +++ b/arrow-flight/src/decode.rs @@ -308,6 +308,7 @@ impl FlightDataDecoder { &state.schema, &mut state.dictionaries_by_field, &message.version(), + false, ) .map_err(|e| { FlightError::DecodeError(format!("Error decoding ipc dictionary: {e}")) diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index a014137f6fa9..fd20b5c813e2 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -590,6 +590,7 @@ pub fn arrow_data_from_flight_data( &dictionaries_by_field, None, &ipc_message.version(), + false, )?; Ok(ArrowFlightData::RecordBatch(record_batch)) } diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index 32716a52eb0d..3c30448146b4 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -99,6 +99,7 @@ pub fn flight_data_to_arrow_batch( dictionaries_by_id, None, &message.version(), + false, ) })? 
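// A minimal standalone sketch (not part of this patch) of the alignment check
// the ScalarBuffer change above relies on: `align_offset` reports how many
// bytes a pointer must advance to reach the next multiple of the requested
// alignment, so a return value of 0 means the bytes can be reinterpreted as
// `[T]` without copying.
fn is_aligned_for<T>(bytes: &[u8]) -> bool {
    bytes.as_ptr().align_offset(std::mem::align_of::<T>()) == 0
}

fn main() {
    let v: Vec<u64> = vec![1, 2, 3];
    // View the Vec's storage as raw bytes; Vec<u64> storage is 8-byte aligned.
    let bytes = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const u8, 24) };
    assert!(is_aligned_for::<u64>(bytes));
    // One byte in, the same data is no longer usable as a zero-copy &[u64].
    assert!(!is_aligned_for::<u64>(&bytes[1..]));
}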
} diff --git a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs index c6b5a72ca6e2..ab5c674dfd20 100644 --- a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs @@ -269,6 +269,7 @@ async fn receive_batch_flight_data( &schema, dictionaries_by_id, &message.version(), + false, ) .expect("Error reading dictionary"); diff --git a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs index 25203ecb7697..7fe35b8acc96 100644 --- a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs @@ -308,6 +308,7 @@ async fn record_batch_from_message( dictionaries_by_id, None, &message.version(), + false, ); arrow_batch_result @@ -330,6 +331,7 @@ async fn dictionary_from_message( &schema_ref, dictionaries_by_id, &message.version(), + false, ); dictionary_batch_result .map_err(|e| Status::internal(format!("Could not convert to Dictionary: {e:?}"))) diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index dd0365da4bc7..efd484ff6416 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -71,7 +71,11 @@ fn read_buffer( /// - check if the bit width of non-64-bit numbers is 64, and /// - read the buffer as 64-bit (signed integer or float), and /// - cast the 64-bit array to the appropriate data type -fn create_array(reader: &mut ArrayReader, field: &Field) -> Result { +fn create_array( + reader: &mut ArrayReader, + field: &Field, + enforce_zero_copy: bool, +) -> Result { let data_type = field.data_type(); match data_type { Utf8 | Binary | LargeBinary | LargeUtf8 => create_primitive_array( @@ -82,23 +86,37 @@ fn create_array(reader: &mut ArrayReader, field: &Field) -> Result create_primitive_array( reader.next_node(field)?, data_type, &[reader.next_buffer()?, reader.next_buffer()?], + enforce_zero_copy, ), List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => { let list_node = reader.next_node(field)?; let list_buffers = [reader.next_buffer()?, reader.next_buffer()?]; - let values = create_array(reader, list_field)?; - create_list_array(list_node, data_type, &list_buffers, values) + let values = create_array(reader, list_field, enforce_zero_copy)?; + create_list_array( + list_node, + data_type, + &list_buffers, + values, + enforce_zero_copy, + ) } FixedSizeList(ref list_field, _) => { let list_node = reader.next_node(field)?; let list_buffers = [reader.next_buffer()?]; - let values = create_array(reader, list_field)?; - create_list_array(list_node, data_type, &list_buffers, values) + let values = create_array(reader, list_field, enforce_zero_copy)?; + create_list_array( + list_node, + data_type, + &list_buffers, + values, + enforce_zero_copy, + ) } Struct(struct_fields) => { let struct_node = reader.next_node(field)?; @@ -109,7 +127,7 @@ fn create_array(reader: &mut ArrayReader, field: &Field) -> Result Result { let run_node = reader.next_node(field)?; - let run_ends = create_array(reader, run_ends_field)?; - let values = create_array(reader, values_field)?; + let run_ends = create_array(reader, run_ends_field, enforce_zero_copy)?; + let values = create_array(reader, values_field, enforce_zero_copy)?; let run_array_length = run_node.length() as usize; - let data = 
ArrayData::builder(data_type.clone()) + let builder = ArrayData::builder(data_type.clone()) .len(run_array_length) .offset(0) .add_child_data(run_ends.into_data()) - .add_child_data(values.into_data()) - .build_aligned()?; + .add_child_data(values.into_data()); + + let array_data = if enforce_zero_copy { + builder.build()? + } else { + builder.build_aligned()? + }; - Ok(make_array(data)) + Ok(make_array(array_data)) } // Create dictionary array from RecordBatch Dictionary(_, _) => { @@ -151,7 +174,13 @@ fn create_array(reader: &mut ArrayReader, field: &Field) -> Result { let union_node = reader.next_node(field)?; @@ -177,7 +206,7 @@ fn create_array(reader: &mut ArrayReader, field: &Field) -> Result Result create_primitive_array( reader.next_node(field)?, data_type, &[reader.next_buffer()?, reader.next_buffer()?], + enforce_zero_copy, ), } } @@ -218,17 +253,17 @@ fn create_primitive_array( field_node: &FieldNode, data_type: &DataType, buffers: &[Buffer], + enforce_zero_copy: bool, ) -> Result { let length = field_node.length() as usize; let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone()); - let array_data = match data_type { + let builder = match data_type { Utf8 | Binary | LargeBinary | LargeUtf8 => { // read 3 buffers: null buffer (optional), offsets buffer and data buffer ArrayData::builder(data_type.clone()) .len(length) .buffers(buffers[1..3].to_vec()) .null_bit_buffer(null_buffer) - .build_aligned()? } _ if data_type.is_primitive() || matches!(data_type, Boolean | FixedSizeBinary(_)) => { // read 2 buffers: null buffer (optional) and data buffer @@ -236,11 +271,16 @@ fn create_primitive_array( .len(length) .add_buffer(buffers[1].clone()) .null_bit_buffer(null_buffer) - .build_aligned()? } t => unreachable!("Data type {:?} either unsupported or not primitive", t), }; + let array_data = if enforce_zero_copy { + builder.build()? + } else { + builder.build_aligned()? + }; + Ok(make_array(array_data)) } @@ -251,6 +291,7 @@ fn create_list_array( data_type: &DataType, buffers: &[Buffer], child_array: ArrayRef, + enforce_zero_copy: bool, ) -> Result { let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone()); let length = field_node.length() as usize; @@ -269,7 +310,14 @@ fn create_list_array( _ => unreachable!("Cannot create list or map array from {:?}", data_type), }; - Ok(make_array(builder.build_aligned()?)) + + let array_data = if enforce_zero_copy { + builder.build()? + } else { + builder.build_aligned()? + }; + + Ok(make_array(array_data)) } /// Reads the correct number of buffers based on list type and null_count, and creates a @@ -279,6 +327,7 @@ fn create_dictionary_array( data_type: &DataType, buffers: &[Buffer], value_array: ArrayRef, + enforce_zero_copy: bool, ) -> Result { if let Dictionary(_, _) = *data_type { let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone()); @@ -288,7 +337,13 @@ fn create_dictionary_array( .add_child_data(value_array.into_data()) .null_bit_buffer(null_buffer); - Ok(make_array(builder.build_aligned()?)) + let array_data = if enforce_zero_copy { + builder.build()? + } else { + builder.build_aligned()? 
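+            // build_aligned() copies any misaligned buffer into a freshly
+            // allocated aligned buffer, whereas build() validates and errors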
+ }; + + Ok(make_array(array_data)) } else { unreachable!("Cannot create dictionary array from {:?}", data_type) } @@ -396,6 +451,7 @@ pub fn read_record_batch( dictionaries_by_id: &HashMap, projection: Option<&[usize]>, metadata: &MetadataVersion, + enforce_zero_copy: bool, ) -> Result { let buffers = batch.buffers().ok_or_else(|| { ArrowError::IpcError("Unable to get buffers from IPC RecordBatch".to_string()) @@ -425,7 +481,7 @@ pub fn read_record_batch( for (idx, field) in schema.fields().iter().enumerate() { // Create array for projected field if let Some(proj_idx) = projection.iter().position(|p| p == &idx) { - let child = create_array(&mut reader, field)?; + let child = create_array(&mut reader, field, enforce_zero_copy)?; arrays.push((proj_idx, child)); } else { reader.skip_field(field)?; @@ -441,7 +497,7 @@ pub fn read_record_batch( let mut children = vec![]; // keep track of index as lists require more than one node for field in schema.fields() { - let child = create_array(&mut reader, field)?; + let child = create_array(&mut reader, field, enforce_zero_copy)?; children.push(child); } RecordBatch::try_new_with_options(schema, children, &options) @@ -456,6 +512,7 @@ pub fn read_dictionary( schema: &Schema, dictionaries_by_id: &mut HashMap, metadata: &MetadataVersion, + enforce_zero_copy: bool, ) -> Result<(), ArrowError> { if batch.isDelta() { return Err(ArrowError::InvalidArgumentError( @@ -485,6 +542,7 @@ pub fn read_dictionary( dictionaries_by_id, None, metadata, + enforce_zero_copy, )?; Some(record_batch.column(0).clone()) } @@ -609,6 +667,7 @@ pub struct FileDecoder { dictionaries: HashMap, version: MetadataVersion, projection: Option>, + enforce_zero_copy: bool, } impl FileDecoder { @@ -619,6 +678,7 @@ impl FileDecoder { version, dictionaries: Default::default(), projection: None, + enforce_zero_copy: false, } } @@ -628,6 +688,11 @@ impl FileDecoder { self } + pub fn with_enforce_zero_copy(mut self, enforce_zero_copy: bool) -> Self { + self.enforce_zero_copy = enforce_zero_copy; + self + } + fn read_message<'a>(&self, buf: &'a [u8]) -> Result, ArrowError> { let message = parse_message(buf)?; @@ -652,6 +717,7 @@ impl FileDecoder { &self.schema, &mut self.dictionaries, &message.version(), + self.enforce_zero_copy, ) } t => Err(ArrowError::ParseError(format!( @@ -683,6 +749,7 @@ impl FileDecoder { &self.dictionaries, self.projection.as_deref(), &message.version(), + self.enforce_zero_copy, ) .map(Some) } @@ -1125,6 +1192,7 @@ impl StreamReader { &self.dictionaries_by_id, self.projection.as_ref().map(|x| x.0.as_ref()), &message.version(), + false, ) .map(Some) } @@ -1144,6 +1212,7 @@ impl StreamReader { &self.schema, &mut self.dictionaries_by_id, &message.version(), + false, )?; // read the next message until we encounter a RecordBatch @@ -1801,11 +1870,56 @@ mod tests { &Default::default(), None, &message.version(), + false, ) .unwrap(); assert_eq!(batch, roundtrip); } + #[test] + fn test_unaligned_throws_error_with_enforce_zero_copy() { + let batch = RecordBatch::try_from_iter(vec![( + "i32", + Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _, + )]) + .unwrap(); + + let gen = IpcDataGenerator {}; + let mut dict_tracker = DictionaryTracker::new(false); + let (_, encoded) = gen + .encoded_batch(&batch, &mut dict_tracker, &Default::default()) + .unwrap(); + + let message = root_as_message(&encoded.ipc_message).unwrap(); + + // Construct an unaligned buffer + let mut buffer = MutableBuffer::with_capacity(encoded.arrow_data.len() + 1); + buffer.push(0_u8); + 
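+        // One junk byte at the front shifts the real IPC body; slicing the
+        // frozen Buffer at offset 1 (below) therefore yields a payload that is
+        // misaligned for any type whose alignment exceeds one byte.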
buffer.extend_from_slice(&encoded.arrow_data); + let b = Buffer::from(buffer).slice(1); + assert_ne!(b.as_ptr().align_offset(8), 0); + + let ipc_batch = message.header_as_record_batch().unwrap(); + let result = read_record_batch( + &b, + ipc_batch, + batch.schema(), + &Default::default(), + None, + &message.version(), + true, + ); + + let error = result.unwrap_err(); + match error { + ArrowError::InvalidArgumentError(e) => { + assert!(e.contains("Misaligned")); + assert!(e.contains("offset from expected alignment of")); + } + _ => panic!("Expected InvalidArgumentError"), + } + } + #[test] fn test_file_with_massive_column_count() { // 499_999 is upper limit for default settings (1_000_000) diff --git a/arrow-ipc/src/reader/stream.rs b/arrow-ipc/src/reader/stream.rs index 7807228175ac..56fe37bbaff1 100644 --- a/arrow-ipc/src/reader/stream.rs +++ b/arrow-ipc/src/reader/stream.rs @@ -199,6 +199,7 @@ impl StreamDecoder { &self.dictionaries, None, &version, + false, )?; self.state = DecoderState::default(); return Ok(Some(batch)); @@ -214,6 +215,7 @@ impl StreamDecoder { schema, &mut self.dictionaries, &version, + false, )?; self.state = DecoderState::default(); } diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 22edfbc2454d..586012a1cfce 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -43,8 +43,9 @@ use crate::CONTINUATION_MARKER; #[derive(Debug, Clone)] pub struct IpcWriteOptions { /// Write padding after memory buffers to this multiple of bytes. - /// Generally 8 or 64, defaults to 64 - alignment: usize, + /// Generally 8 or 64, defaults to 64. + /// Must be a multiple of 8 and must be between 8 and 64 inclusive. + alignment: u8, /// The legacy format is for releases before 0.15.0, and uses metadata V4 write_legacy_ipc_format: bool, /// The metadata version to write. 
The Rust IPC writer supports V4+ @@ -83,13 +84,13 @@ impl IpcWriteOptions { } /// Try create IpcWriteOptions, checking for incompatible settings pub fn try_new( - alignment: usize, + alignment: u8, write_legacy_ipc_format: bool, metadata_version: crate::MetadataVersion, ) -> Result { - if alignment == 0 || alignment % 8 != 0 { + if 8 <= alignment && alignment <= 64 && alignment % 8 != 0 { return Err(ArrowError::InvalidArgumentError( - "Alignment should be greater than 0 and be a multiple of 8".to_string(), + "Alignment should be a multiple of 8 in the range [8, 64]".to_string(), )); } match metadata_version { @@ -428,8 +429,8 @@ impl IpcDataGenerator { } // pad the tail of body data let len = arrow_data.len(); - let pad_len = pad_to_8(len as u32); - arrow_data.extend_from_slice(&vec![0u8; pad_len][..]); + let pad_len = pad_to_alignment(write_options.alignment, len); + arrow_data.extend_from_slice(&PADDING[..pad_len]); // write data let buffers = fbb.create_vector(&buffers); @@ -503,8 +504,8 @@ impl IpcDataGenerator { // pad the tail of body data let len = arrow_data.len(); - let pad_len = pad_to_8(len as u32); - arrow_data.extend_from_slice(&vec![0u8; pad_len][..]); + let pad_len = pad_to_alignment(write_options.alignment, len); + arrow_data.extend_from_slice(&PADDING[..pad_len]); // write data let buffers = fbb.create_vector(&buffers); @@ -716,11 +717,11 @@ impl FileWriter { ) -> Result { let data_gen = IpcDataGenerator::default(); let mut writer = BufWriter::new(writer); - // write magic to header aligned on 8 byte boundary - let header_size = super::ARROW_MAGIC.len() + 2; - assert_eq!(header_size, 8); - writer.write_all(&super::ARROW_MAGIC[..])?; - writer.write_all(&[0, 0])?; + // write magic to header aligned on alignment boundary + let pad_len = pad_to_alignment(write_options.alignment, super::ARROW_MAGIC.len()); + let header_size = super::ARROW_MAGIC.len() + pad_len; + writer.write_all(&super::ARROW_MAGIC)?; + writer.write_all(&PADDING[..pad_len])?; // write the schema, set the written bytes to the schema + header let encoded_message = data_gen.schema_to_bytes(schema, &write_options); let (meta, data) = write_message(&mut writer, encoded_message, &write_options)?; @@ -1017,13 +1018,13 @@ pub fn write_message( write_options: &IpcWriteOptions, ) -> Result<(usize, usize), ArrowError> { let arrow_data_len = encoded.arrow_data.len(); - if arrow_data_len % 8 != 0 { + if arrow_data_len % usize::from(write_options.alignment) != 0 { return Err(ArrowError::MemoryError( "Arrow data not aligned".to_string(), )); } - let a = write_options.alignment - 1; + let a = usize::from(write_options.alignment - 1); let buffer = encoded.ipc_message; let flatbuf_size = buffer.len(); let prefix_size = if write_options.write_legacy_ipc_format { @@ -1045,11 +1046,11 @@ pub fn write_message( writer.write_all(&buffer)?; } // write padding - writer.write_all(&vec![0; padding_bytes])?; + writer.write_all(&PADDING[..padding_bytes])?; // write arrow data let body_len = if arrow_data_len > 0 { - write_body_buffers(&mut writer, &encoded.arrow_data)? + write_body_buffers(&mut writer, &encoded.arrow_data, write_options.alignment)? 
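+        // (write_body_buffers pads the body itself out to a multiple of
+        // write_options.alignment, so the returned length stays aligned.)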
} else { 0 }; @@ -1057,15 +1058,19 @@ pub fn write_message( Ok((aligned_size, body_len)) } -fn write_body_buffers(mut writer: W, data: &[u8]) -> Result { - let len = data.len() as u32; - let pad_len = pad_to_8(len) as u32; +fn write_body_buffers( + mut writer: W, + data: &[u8], + alignment: u8, +) -> Result { + let len = data.len(); + let pad_len = pad_to_alignment(alignment, len); let total_len = len + pad_len; // write body buffer writer.write_all(data)?; if pad_len > 0 { - writer.write_all(&vec![0u8; pad_len as usize][..])?; + writer.write_all(&PADDING[..pad_len])?; } writer.flush()?; @@ -1234,6 +1239,7 @@ fn write_array_data( arrow_data, offset, compression_codec, + write_options.alignment, )?; } @@ -1247,6 +1253,7 @@ fn write_array_data( arrow_data, offset, compression_codec, + write_options.alignment, )?; } } else if matches!(data_type, DataType::LargeBinary | DataType::LargeUtf8) { @@ -1258,6 +1265,7 @@ fn write_array_data( arrow_data, offset, compression_codec, + write_options.alignment, )?; } } else if DataType::is_numeric(data_type) @@ -1283,7 +1291,14 @@ fn write_array_data( } else { buffer.as_slice() }; - offset = write_buffer(buffer_slice, buffers, arrow_data, offset, compression_codec)?; + offset = write_buffer( + buffer_slice, + buffers, + arrow_data, + offset, + compression_codec, + write_options.alignment, + )?; } else if matches!(data_type, DataType::Boolean) { // Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes). // The array data may not start at the physical boundary of the underlying buffer, so we need to shift bits around. @@ -1291,7 +1306,14 @@ fn write_array_data( let buffer = &array_data.buffers()[0]; let buffer = buffer.bit_slice(array_data.offset(), array_data.len()); - offset = write_buffer(&buffer, buffers, arrow_data, offset, compression_codec)?; + offset = write_buffer( + &buffer, + buffers, + arrow_data, + offset, + compression_codec, + write_options.alignment, + )?; } else if matches!( data_type, DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _) @@ -1312,6 +1334,7 @@ fn write_array_data( arrow_data, offset, compression_codec, + write_options.alignment, )?; offset = write_array_data( &sliced_child_data, @@ -1327,7 +1350,14 @@ fn write_array_data( return Ok(offset); } else { for buffer in array_data.buffers() { - offset = write_buffer(buffer, buffers, arrow_data, offset, compression_codec)?; + offset = write_buffer( + buffer, + buffers, + arrow_data, + offset, + compression_codec, + write_options.alignment, + )?; } } @@ -1391,6 +1421,7 @@ fn write_buffer( arrow_data: &mut Vec, // output stream offset: i64, // current output stream offset compression_codec: Option, + alignment: u8, ) -> Result { let len: i64 = match compression_codec { Some(compressor) => compressor.compress_to_vec(buffer, arrow_data)?, @@ -1406,17 +1437,20 @@ fn write_buffer( // make new index entry buffers.push(crate::Buffer::new(offset, len)); - // padding and make offset 8 bytes aligned - let pad_len = pad_to_8(len as u32) as i64; - arrow_data.extend_from_slice(&vec![0u8; pad_len as usize][..]); + // padding and make offset aligned + let pad_len = pad_to_alignment(alignment, len as usize); + arrow_data.extend_from_slice(&PADDING[..pad_len]); - Ok(offset + len + pad_len) + Ok(offset + len + (pad_len as i64)) } -/// Calculate an 8-byte boundary and return the number of bytes needed to pad to 8 bytes +const PADDING: [u8; 64] = [0; 64]; + +/// Calculate an alignment boundary and return the number of bytes needed to pad to 
the alignment boundary #[inline] -fn pad_to_8(len: u32) -> usize { - (((len + 7) & !7) - len) as usize +fn pad_to_alignment(alignment: u8, len: usize) -> usize { + let a = usize::from(alignment - 1); + ((len + a) & !a) - len } #[cfg(test)] @@ -1430,7 +1464,9 @@ mod tests { use arrow_array::builder::{PrimitiveRunBuilder, UInt32Builder}; use arrow_array::types::*; + use crate::convert::fb_to_schema; use crate::reader::*; + use crate::root_as_footer; use crate::MetadataVersion; use super::*; @@ -2234,4 +2270,124 @@ mod tests { let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap(); roundtrip_ensure_sliced_smaller(in_batch, 1000); } + + #[test] + fn test_decimal128_alignment16() { + const IPC_ALIGNMENT: u8 = 16; + + for num_cols in 1..100 { + let num_rows = (num_cols * 7 + 11) % 100; // Deterministic swizzle + + let mut fields = Vec::new(); + let mut arrays = Vec::new(); + for i in 0..num_cols { + let field = Field::new(&format!("col_{}", i), DataType::Decimal128(38, 10), true); + let array = Decimal128Array::from(vec![num_cols as i128; num_rows]); + fields.push(field); + arrays.push(Arc::new(array) as Arc); + } + let schema = Schema::new(fields); + let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap(); + + let mut writer = FileWriter::try_new_with_options( + Vec::new(), + batch.schema_ref(), + IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(), + ) + .unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + + let out: Vec = writer.into_inner().unwrap(); + + let buffer = Buffer::from_vec(out); + let trailer_start = buffer.len() - 10; + let footer_len = + read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap(); + let footer = + root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap(); + + let schema = fb_to_schema(footer.schema().unwrap()); + assert_eq!(&schema, batch.schema().as_ref()); + + let decoder = + FileDecoder::new(Arc::new(schema), footer.version()).with_enforce_zero_copy(true); + + assert_eq!(footer.dictionaries().unwrap().len(), 0); + + let batches = footer.recordBatches().unwrap(); + assert_eq!(batches.len(), 1); + + let block = batches.get(0); + let block_len = block.bodyLength() as usize + block.metaDataLength() as usize; + let data = buffer.slice_with_length(block.offset() as _, block_len); + + let batch2 = decoder.read_record_batch(block, &data).unwrap().unwrap(); + + assert_eq!(batch, batch2); + } + } + + #[test] + fn test_decimal128_alignment8_is_unaligned() { + const IPC_ALIGNMENT: u8 = 8; + + let num_cols = 2; + let num_rows = 1; + + let mut fields = Vec::new(); + let mut arrays = Vec::new(); + for i in 0..num_cols { + let field = Field::new(&format!("col_{}", i), DataType::Decimal128(38, 10), true); + let array = Decimal128Array::from(vec![num_cols as i128; num_rows]); + fields.push(field); + arrays.push(Arc::new(array) as Arc); + } + let schema = Schema::new(fields); + let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap(); + + let mut writer = FileWriter::try_new_with_options( + Vec::new(), + batch.schema_ref(), + IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(), + ) + .unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + + let out: Vec = writer.into_inner().unwrap(); + + let buffer = Buffer::from_vec(out); + let trailer_start = buffer.len() - 10; + let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap(); + let footer = root_as_footer(&buffer[trailer_start - 
footer_len..trailer_start]).unwrap(); + + let schema = fb_to_schema(footer.schema().unwrap()); + assert_eq!(&schema, batch.schema().as_ref()); + + // Importantly we enforce zero copy, otherwise the error later is suppressed due to copying + // to an aligned buffer in `ArrayDataBuilder.build_aligned`. + let decoder = + FileDecoder::new(Arc::new(schema), footer.version()).with_enforce_zero_copy(true); + + assert_eq!(footer.dictionaries().unwrap().len(), 0); + + let batches = footer.recordBatches().unwrap(); + assert_eq!(batches.len(), 1); + + let block = batches.get(0); + let block_len = block.bodyLength() as usize + block.metaDataLength() as usize; + let data = buffer.slice_with_length(block.offset() as _, block_len); + + let result = decoder.read_record_batch(block, &data); + + let error = result.unwrap_err(); + match error { + ArrowError::InvalidArgumentError(e) => { + assert!(e.contains("Misaligned")); + assert!(e.contains("offset from expected alignment of")); + } + _ => panic!("Expected InvalidArgumentError"), + } + } } From ccca72256e7213606a39d5b093346004b19b37ca Mon Sep 17 00:00:00 2001 From: hzuo Date: Wed, 27 Mar 2024 00:08:19 -0400 Subject: [PATCH 02/12] Remove redundant alignment check --- arrow-buffer/src/buffer/scalar.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/arrow-buffer/src/buffer/scalar.rs b/arrow-buffer/src/buffer/scalar.rs index 4c3bd44fd1ad..343b8549e93d 100644 --- a/arrow-buffer/src/buffer/scalar.rs +++ b/arrow-buffer/src/buffer/scalar.rs @@ -66,12 +66,6 @@ impl ScalarBuffer { /// * `buffer` is not aligned to a multiple of `std::mem::align_of::` /// * `bytes` is not large enough for the requested slice pub fn new(buffer: Buffer, offset: usize, len: usize) -> Self { - assert_eq!( - buffer.as_ptr().align_offset(std::mem::align_of::()), - 0, - "buffer is not aligned" - ); - let size = std::mem::size_of::(); let byte_offset = offset.checked_mul(size).expect("offset overflow"); let byte_len = len.checked_mul(size).expect("length overflow"); From 224df9f2a76393c8679b835ec46f2c75046dd00e Mon Sep 17 00:00:00 2001 From: hzuo Date: Wed, 27 Mar 2024 00:11:41 -0400 Subject: [PATCH 03/12] fix typo --- arrow-ipc/src/writer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 586012a1cfce..7d7685b28686 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -88,7 +88,7 @@ impl IpcWriteOptions { write_legacy_ipc_format: bool, metadata_version: crate::MetadataVersion, ) -> Result { - if 8 <= alignment && alignment <= 64 && alignment % 8 != 0 { + if alignment < 8 || alignment > 64 || alignment % 8 != 0 { return Err(ArrowError::InvalidArgumentError( "Alignment should be a multiple of 8 in the range [8, 64]".to_string(), )); From c936ef0d76ef0228af81db2059aa27b78e0bcae3 Mon Sep 17 00:00:00 2001 From: hzuo Date: Wed, 27 Mar 2024 00:16:26 -0400 Subject: [PATCH 04/12] Add comment about randomized testing --- arrow-ipc/src/writer.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 7d7685b28686..310e48678ce4 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -2275,6 +2275,10 @@ mod tests { fn test_decimal128_alignment16() { const IPC_ALIGNMENT: u8 = 16; + // Test a bunch of different dimensions to ensure alignment is never an issue. 
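+        // (Each column count drives a different total body size, and hence a
+        // different amount of padding between the serialized buffers.)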
+ // For example, if we only test `num_cols = 1` then even with alignment 8 this + // test would _happen_ to pass, even though for different dimensions like + // `num_cols = 2` it would fail. for num_cols in 1..100 { let num_rows = (num_cols * 7 + 11) % 100; // Deterministic swizzle From f0e0b7307727c3ad926edb819d5ec50389ba8518 Mon Sep 17 00:00:00 2001 From: hzuo Date: Wed, 27 Mar 2024 00:27:12 -0400 Subject: [PATCH 05/12] add doc comment on enforce_zero_copy --- arrow-ipc/src/reader.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index efd484ff6416..6960ec8187dd 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -688,6 +688,14 @@ impl FileDecoder { self } + /// Specify whether to enforce zero copy, which means the array data inside the record batches + /// produced by this decoder will always reference the input buffer. + /// + /// In particular, this means that if there is data that is not aligned properly in the + /// input buffer, the decoder will throw rather than copy the data to an aligned buffer. + /// + /// By default `enforce_zero_copy` is false, meaning it will allocate new buffers and copy + /// data if necessary, e.g. if the alignment is not correct. pub fn with_enforce_zero_copy(mut self, enforce_zero_copy: bool) -> Self { self.enforce_zero_copy = enforce_zero_copy; self From 308fd7c652b829033e5e25a26f64d75e96f58002 Mon Sep 17 00:00:00 2001 From: hzuo Date: Mon, 1 Apr 2024 17:07:53 -0400 Subject: [PATCH 06/12] address alamb feedback --- arrow-ipc/src/reader.rs | 107 ++++++++++++++++++++++------------------ arrow-ipc/src/writer.rs | 39 ++++++--------- 2 files changed, 75 insertions(+), 71 deletions(-) diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index bea33d7f26ad..e01fb449fa09 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -78,7 +78,7 @@ fn create_array( reader: &mut ArrayReader, field: &Field, variadic_counts: &mut VecDeque, - enforce_zero_copy: bool, + require_alignment: bool, ) -> Result { let data_type = field.data_type(); match data_type { @@ -90,7 +90,7 @@ fn create_array( reader.next_buffer()?, reader.next_buffer()?, ], - enforce_zero_copy, + require_alignment, ), BinaryView | Utf8View => { let count = variadic_counts @@ -106,37 +106,37 @@ fn create_array( reader.next_node(field)?, data_type, &buffers, - enforce_zero_copy, + require_alignment, ) } FixedSizeBinary(_) => create_primitive_array( reader.next_node(field)?, data_type, &[reader.next_buffer()?, reader.next_buffer()?], - enforce_zero_copy, + require_alignment, ), List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => { let list_node = reader.next_node(field)?; let list_buffers = [reader.next_buffer()?, reader.next_buffer()?]; - let values = create_array(reader, list_field, variadic_counts, enforce_zero_copy)?; + let values = create_array(reader, list_field, variadic_counts, require_alignment)?; create_list_array( list_node, data_type, &list_buffers, values, - enforce_zero_copy, + require_alignment, ) } FixedSizeList(ref list_field, _) => { let list_node = reader.next_node(field)?; let list_buffers = [reader.next_buffer()?]; - let values = create_array(reader, list_field, variadic_counts, enforce_zero_copy)?; + let values = create_array(reader, list_field, variadic_counts, require_alignment)?; create_list_array( list_node, data_type, &list_buffers, values, - enforce_zero_copy, + require_alignment, ) } Struct(struct_fields) => { @@ -148,7 +148,7 @@ fn create_array( // 
TODO investigate whether just knowing the number of buffers could // still work for struct_field in struct_fields { - let child = create_array(reader, struct_field, variadic_counts, enforce_zero_copy)?; + let child = create_array(reader, struct_field, variadic_counts, require_alignment)?; struct_arrays.push((struct_field.clone(), child)); } let null_count = struct_node.null_count() as usize; @@ -163,8 +163,8 @@ fn create_array( RunEndEncoded(run_ends_field, values_field) => { let run_node = reader.next_node(field)?; let run_ends = - create_array(reader, run_ends_field, variadic_counts, enforce_zero_copy)?; - let values = create_array(reader, values_field, variadic_counts, enforce_zero_copy)?; + create_array(reader, run_ends_field, variadic_counts, require_alignment)?; + let values = create_array(reader, values_field, variadic_counts, require_alignment)?; let run_array_length = run_node.length() as usize; let builder = ArrayData::builder(data_type.clone()) @@ -173,7 +173,7 @@ fn create_array( .add_child_data(run_ends.into_data()) .add_child_data(values.into_data()); - let array_data = if enforce_zero_copy { + let array_data = if require_alignment { builder.build()? } else { builder.build_aligned()? @@ -201,7 +201,7 @@ fn create_array( data_type, &index_buffers, value_array.clone(), - enforce_zero_copy, + require_alignment, ) } Union(fields, mode) => { @@ -228,7 +228,7 @@ fn create_array( let mut ids = Vec::with_capacity(fields.len()); for (id, field) in fields.iter() { - let child = create_array(reader, field, variadic_counts, enforce_zero_copy)?; + let child = create_array(reader, field, variadic_counts, require_alignment)?; children.push((field.as_ref().clone(), child)); ids.push(id); } @@ -251,7 +251,7 @@ fn create_array( .len(length as usize) .offset(0); - let array_data = if enforce_zero_copy { + let array_data = if require_alignment { builder.build()? } else { builder.build_aligned()? @@ -264,7 +264,7 @@ fn create_array( reader.next_node(field)?, data_type, &[reader.next_buffer()?, reader.next_buffer()?], - enforce_zero_copy, + require_alignment, ), } } @@ -275,7 +275,7 @@ fn create_primitive_array( field_node: &FieldNode, data_type: &DataType, buffers: &[Buffer], - enforce_zero_copy: bool, + require_alignment: bool, ) -> Result { let length = field_node.length() as usize; let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone()); @@ -301,7 +301,7 @@ fn create_primitive_array( t => unreachable!("Data type {:?} either unsupported or not primitive", t), }; - let array_data = if enforce_zero_copy { + let array_data = if require_alignment { builder.build()? } else { builder.build_aligned()? @@ -317,7 +317,7 @@ fn create_list_array( data_type: &DataType, buffers: &[Buffer], child_array: ArrayRef, - enforce_zero_copy: bool, + require_alignment: bool, ) -> Result { let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone()); let length = field_node.length() as usize; @@ -337,7 +337,7 @@ fn create_list_array( _ => unreachable!("Cannot create list or map array from {:?}", data_type), }; - let array_data = if enforce_zero_copy { + let array_data = if require_alignment { builder.build()? } else { builder.build_aligned()? 
@@ -353,7 +353,7 @@ fn create_dictionary_array( data_type: &DataType, buffers: &[Buffer], value_array: ArrayRef, - enforce_zero_copy: bool, + require_alignment: bool, ) -> Result { if let Dictionary(_, _) = *data_type { let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone()); @@ -363,7 +363,7 @@ fn create_dictionary_array( .add_child_data(value_array.into_data()) .null_bit_buffer(null_buffer); - let array_data = if enforce_zero_copy { + let array_data = if require_alignment { builder.build()? } else { builder.build_aligned()? @@ -485,7 +485,16 @@ impl<'a> ArrayReader<'a> { } } -/// Creates a record batch from binary data using the `crate::RecordBatch` indexes and the `Schema` +/// Creates a record batch from binary data using the `crate::RecordBatch` indexes and the `Schema`. +/// +/// If `require_alignment` is true, this function will return an error if any array data in the +/// input `buf` is not properly aligned. +/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct [`arrow_data::ArrayData`]. +/// +/// If `require_alignment` is false, this function will automatically allocate a new aligned buffer +/// and copy over the data if any array data in the input `buf` is not properly aligned. +/// (Properly aligned array data will remain zero-copy.) +/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct [`arrow_data::ArrayData`]. pub fn read_record_batch( buf: &Buffer, batch: crate::RecordBatch, @@ -493,7 +502,7 @@ pub fn read_record_batch( dictionaries_by_id: &HashMap, projection: Option<&[usize]>, metadata: &MetadataVersion, - enforce_zero_copy: bool, + require_alignment: bool, ) -> Result { let buffers = batch.buffers().ok_or_else(|| { ArrowError::IpcError("Unable to get buffers from IPC RecordBatch".to_string()) @@ -528,7 +537,7 @@ pub fn read_record_batch( // Create array for projected field if let Some(proj_idx) = projection.iter().position(|p| p == &idx) { let child = - create_array(&mut reader, field, &mut variadic_counts, enforce_zero_copy)?; + create_array(&mut reader, field, &mut variadic_counts, require_alignment)?; arrays.push((proj_idx, child)); } else { reader.skip_field(field, &mut variadic_counts)?; @@ -545,7 +554,7 @@ pub fn read_record_batch( let mut children = vec![]; // keep track of index as lists require more than one node for field in schema.fields() { - let child = create_array(&mut reader, field, &mut variadic_counts, enforce_zero_copy)?; + let child = create_array(&mut reader, field, &mut variadic_counts, require_alignment)?; children.push(child); } assert!(variadic_counts.is_empty()); @@ -561,7 +570,7 @@ pub fn read_dictionary( schema: &Schema, dictionaries_by_id: &mut HashMap, metadata: &MetadataVersion, - enforce_zero_copy: bool, + require_alignment: bool, ) -> Result<(), ArrowError> { if batch.isDelta() { return Err(ArrowError::InvalidArgumentError( @@ -591,7 +600,7 @@ pub fn read_dictionary( dictionaries_by_id, None, metadata, - enforce_zero_copy, + require_alignment, )?; Some(record_batch.column(0).clone()) } @@ -716,7 +725,7 @@ pub struct FileDecoder { dictionaries: HashMap, version: MetadataVersion, projection: Option>, - enforce_zero_copy: bool, + require_alignment: bool, } impl FileDecoder { @@ -727,7 +736,7 @@ impl FileDecoder { version, dictionaries: Default::default(), projection: None, - enforce_zero_copy: false, + require_alignment: false, } } @@ -737,16 +746,20 @@ impl FileDecoder { self } - /// Specify whether to enforce zero copy, which means the array data 
inside the record batches - /// produced by this decoder will always reference the input buffer. + /// Specifies whether or not array data in input buffers is required to be properly aligned. /// - /// In particular, this means that if there is data that is not aligned properly in the - /// input buffer, the decoder will throw rather than copy the data to an aligned buffer. + /// If `require_alignment` is true, this decoder will return an error if any array data in the + /// input `buf` is not properly aligned. + /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct + /// [`arrow_data::ArrayData`]. /// - /// By default `enforce_zero_copy` is false, meaning it will allocate new buffers and copy - /// data if necessary, e.g. if the alignment is not correct. - pub fn with_enforce_zero_copy(mut self, enforce_zero_copy: bool) -> Self { - self.enforce_zero_copy = enforce_zero_copy; + /// If `require_alignment` is false (the default), this decoder will automatically allocate a + /// new aligned buffer and copy over the data if any array data in the input `buf` is not + /// properly aligned. (Properly aligned array data will remain zero-copy.) + /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct + /// [`arrow_data::ArrayData`]. + pub fn with_require_alignment(mut self, require_alignment: bool) -> Self { + self.require_alignment = require_alignment; self } @@ -774,7 +787,7 @@ impl FileDecoder { &self.schema, &mut self.dictionaries, &message.version(), - self.enforce_zero_copy, + self.require_alignment, ) } t => Err(ArrowError::ParseError(format!( @@ -806,7 +819,7 @@ impl FileDecoder { &self.dictionaries, self.projection.as_deref(), &message.version(), - self.enforce_zero_copy, + self.require_alignment, ) .map(Some) } @@ -2049,7 +2062,7 @@ mod tests { } #[test] - fn test_unaligned_throws_error_with_enforce_zero_copy() { + fn test_unaligned_throws_error_with_require_alignment() { let batch = RecordBatch::try_from_iter(vec![( "i32", Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _, @@ -2083,13 +2096,11 @@ mod tests { ); let error = result.unwrap_err(); - match error { - ArrowError::InvalidArgumentError(e) => { - assert!(e.contains("Misaligned")); - assert!(e.contains("offset from expected alignment of")); - } - _ => panic!("Expected InvalidArgumentError"), - } + assert_eq!( + error.to_string(), + "Invalid argument error: Misaligned buffers[0] in array of type Int32, \ + offset from expected alignment of 4 by 1" + ); } #[test] diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 09dad4ca9887..5ed8250021a6 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -84,7 +84,7 @@ impl IpcWriteOptions { } /// Try create IpcWriteOptions, checking for incompatible settings pub fn try_new( - alignment: u8, + alignment: usize, write_legacy_ipc_format: bool, metadata_version: crate::MetadataVersion, ) -> Result { @@ -93,6 +93,7 @@ impl IpcWriteOptions { "Alignment should be a multiple of 8 in the range [8, 64]".to_string(), )); } + let alignment: u8 = u8::try_from(alignment).expect("range already checked"); match metadata_version { crate::MetadataVersion::V1 | crate::MetadataVersion::V2 @@ -2384,14 +2385,14 @@ mod tests { } #[test] - fn test_decimal128_alignment16() { - const IPC_ALIGNMENT: u8 = 16; + fn test_decimal128_alignment16_is_sufficient() { + const IPC_ALIGNMENT: usize = 16; // Test a bunch of different dimensions to ensure alignment is never an issue. 
// For example, if we only test `num_cols = 1` then even with alignment 8 this // test would _happen_ to pass, even though for different dimensions like // `num_cols = 2` it would fail. - for num_cols in 1..100 { + for num_cols in [1, 2, 3, 17, 50, 73, 99] { let num_rows = (num_cols * 7 + 11) % 100; // Deterministic swizzle let mut fields = Vec::new(); @@ -2424,15 +2425,13 @@ mod tests { root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap(); let schema = fb_to_schema(footer.schema().unwrap()); - assert_eq!(&schema, batch.schema().as_ref()); + // Importantly we set `require_alignment`, checking that 16-byte alignment is sufficient + // for `read_record_batch` later on to read the data in a zero-copy manner. let decoder = - FileDecoder::new(Arc::new(schema), footer.version()).with_enforce_zero_copy(true); - - assert_eq!(footer.dictionaries().unwrap().len(), 0); + FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true); let batches = footer.recordBatches().unwrap(); - assert_eq!(batches.len(), 1); let block = batches.get(0); let block_len = block.bodyLength() as usize + block.metaDataLength() as usize; @@ -2446,7 +2445,7 @@ mod tests { #[test] fn test_decimal128_alignment8_is_unaligned() { - const IPC_ALIGNMENT: u8 = 8; + const IPC_ALIGNMENT: usize = 8; let num_cols = 2; let num_rows = 1; @@ -2479,17 +2478,13 @@ mod tests { let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap(); let schema = fb_to_schema(footer.schema().unwrap()); - assert_eq!(&schema, batch.schema().as_ref()); - // Importantly we enforce zero copy, otherwise the error later is suppressed due to copying + // Importantly we set `require_alignment`, otherwise the error later is suppressed due to copying // to an aligned buffer in `ArrayDataBuilder.build_aligned`. let decoder = - FileDecoder::new(Arc::new(schema), footer.version()).with_enforce_zero_copy(true); - - assert_eq!(footer.dictionaries().unwrap().len(), 0); + FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true); let batches = footer.recordBatches().unwrap(); - assert_eq!(batches.len(), 1); let block = batches.get(0); let block_len = block.bodyLength() as usize + block.metaDataLength() as usize; @@ -2498,12 +2493,10 @@ mod tests { let result = decoder.read_record_batch(block, &data); let error = result.unwrap_err(); - match error { - ArrowError::InvalidArgumentError(e) => { - assert!(e.contains("Misaligned")); - assert!(e.contains("offset from expected alignment of")); - } - _ => panic!("Expected InvalidArgumentError"), - } + assert_eq!( + error.to_string(), + "Invalid argument error: Misaligned buffers[0] in array of type Decimal128(38, 10), \ + offset from expected alignment of 16 by 8" + ); } } From 7675444f9e5fe81bb29c15ee1b11b735f7795c85 Mon Sep 17 00:00:00 2001 From: hzuo Date: Mon, 1 Apr 2024 17:17:27 -0400 Subject: [PATCH 07/12] be explicit --- arrow-ipc/src/writer.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 5ed8250021a6..c356c9e15237 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -43,8 +43,7 @@ use crate::CONTINUATION_MARKER; #[derive(Debug, Clone)] pub struct IpcWriteOptions { /// Write padding after memory buffers to this multiple of bytes. - /// Generally 8 or 64, defaults to 64. - /// Must be a multiple of 8 and must be between 8 and 64 inclusive. + /// Must be 8, 16, 32, or 64 - defaults to 64. 
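+    /// (Powers of two are required because `pad_to_alignment` rounds lengths
+    /// up with `(len + a) & !a`, which is only correct for power-of-two
+    /// alignments; a plain multiple of 8 such as 24 would mis-pad.)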
alignment: u8, /// The legacy format is for releases before 0.15.0, and uses metadata V4 write_legacy_ipc_format: bool, @@ -88,9 +87,9 @@ impl IpcWriteOptions { write_legacy_ipc_format: bool, metadata_version: crate::MetadataVersion, ) -> Result { - if alignment < 8 || alignment > 64 || alignment % 8 != 0 { + if alignment != 8 || alignment != 16 || alignment != 32 || alignment != 64 { return Err(ArrowError::InvalidArgumentError( - "Alignment should be a multiple of 8 in the range [8, 64]".to_string(), + "Alignment should be 8, 16, 32, or 64.".to_string(), )); } let alignment: u8 = u8::try_from(alignment).expect("range already checked"); From 8c0a3f23330c6f143a87e6434d4ad4bd28392710 Mon Sep 17 00:00:00 2001 From: hzuo Date: Tue, 2 Apr 2024 13:51:32 -0400 Subject: [PATCH 08/12] fix unit tests --- arrow-ipc/src/writer.rs | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index c356c9e15237..9bcc332333f9 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -87,7 +87,9 @@ impl IpcWriteOptions { write_legacy_ipc_format: bool, metadata_version: crate::MetadataVersion, ) -> Result { - if alignment != 8 || alignment != 16 || alignment != 32 || alignment != 64 { + let is_alignment_valid = + alignment == 8 || alignment == 16 || alignment == 32 || alignment == 64; + if !is_alignment_valid { return Err(ArrowError::InvalidArgumentError( "Alignment should be 8, 16, 32, or 64.".to_string(), )); @@ -1545,7 +1547,17 @@ mod tests { } fn serialize_stream(record: &RecordBatch) -> Vec { - let mut stream_writer = StreamWriter::try_new(vec![], record.schema_ref()).unwrap(); + // Use a smaller-than-default IPC alignment so that the various `truncate_*` tests can be + // compactly written, without needing to construct a giant array to spill over the 64-byte + // default alignment boundary. 
+ const IPC_ALIGNMENT: usize = 8; + + let mut stream_writer = StreamWriter::try_new_with_options( + vec![], + record.schema_ref(), + IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(), + ) + .unwrap(); stream_writer.write(record).unwrap(); stream_writer.finish().unwrap(); stream_writer.into_inner().unwrap() From f1dc10e95be2b68fcf67cc34576ac9f22e90a3e0 Mon Sep 17 00:00:00 2001 From: hzuo Date: Tue, 2 Apr 2024 15:52:49 -0400 Subject: [PATCH 09/12] fix arrow-flight tests --- arrow-flight/src/encode.rs | 6 +++++- arrow-ipc/src/writer.rs | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index efd688129485..7604f3cd4d62 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -627,6 +627,7 @@ mod tests { use arrow_array::{cast::downcast_array, types::*}; use arrow_buffer::Buffer; use arrow_cast::pretty::pretty_format_batches; + use arrow_ipc::MetadataVersion; use arrow_schema::UnionMode; use std::collections::HashMap; @@ -638,7 +639,8 @@ mod tests { /// ensure only the batch's used data (not the allocated data) is sent /// fn test_encode_flight_data() { - let options = IpcWriteOptions::default(); + // use 8-byte alignment - default alignment is 64 which produces bigger ipc data + let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap(); let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as ArrayRef)]) @@ -1343,6 +1345,8 @@ mod tests { let mut stream = FlightDataEncoderBuilder::new() .with_max_flight_data_size(max_flight_data_size) + // use 8-byte alignment - default alignment is 64 which produces bigger ipc data + .with_options(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()) .build(futures::stream::iter([Ok(batch.clone())])); let mut i = 0; diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 9bcc332333f9..e541c9fd1d31 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -1547,9 +1547,9 @@ mod tests { } fn serialize_stream(record: &RecordBatch) -> Vec { - // Use a smaller-than-default IPC alignment so that the various `truncate_*` tests can be - // compactly written, without needing to construct a giant array to spill over the 64-byte - // default alignment boundary. + // Use 8-byte alignment so that the various `truncate_*` tests can be compactly written, + // without needing to construct a giant array to spill over the 64-byte default alignment + // boundary. 
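+        // (Every IPC message body is padded to a multiple of the configured
+        // alignment, so the 64-byte default would inflate these tiny batches.)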
const IPC_ALIGNMENT: usize = 8; let mut stream_writer = StreamWriter::try_new_with_options( From a1cb1a615fe9ee902017970d56aac6fbe8b914f2 Mon Sep 17 00:00:00 2001 From: hzuo Date: Tue, 2 Apr 2024 16:26:53 -0400 Subject: [PATCH 10/12] clippy --- arrow-ipc/src/writer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index e541c9fd1d31..97136bd97c2f 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -1120,7 +1120,7 @@ fn write_body_buffers( } writer.flush()?; - Ok(total_len as usize) + Ok(total_len) } /// Write a record batch to the writer, writing the message size before the message From 368f8235624280345322445b1311433e5fa9e6f3 Mon Sep 17 00:00:00 2001 From: hzuo Date: Wed, 3 Apr 2024 16:55:42 -0400 Subject: [PATCH 11/12] hide api change + add require_alignment to StreamDecoder too --- arrow-flight/src/decode.rs | 1 - arrow-flight/src/sql/client.rs | 1 - arrow-flight/src/utils.rs | 1 - .../integration_test.rs | 1 - .../integration_test.rs | 2 - arrow-ipc/src/reader.rs | 47 +++++++++++++++---- arrow-ipc/src/reader/stream.rs | 29 ++++++++++-- 7 files changed, 62 insertions(+), 20 deletions(-) diff --git a/arrow-flight/src/decode.rs b/arrow-flight/src/decode.rs index d2f155070551..afbf033eb06d 100644 --- a/arrow-flight/src/decode.rs +++ b/arrow-flight/src/decode.rs @@ -308,7 +308,6 @@ impl FlightDataDecoder { &state.schema, &mut state.dictionaries_by_field, &message.version(), - false, ) .map_err(|e| { FlightError::DecodeError(format!("Error decoding ipc dictionary: {e}")) diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index 7bd7e46a6a95..44250fbe63e2 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -613,7 +613,6 @@ pub fn arrow_data_from_flight_data( &dictionaries_by_field, None, &ipc_message.version(), - false, )?; Ok(ArrowFlightData::RecordBatch(record_batch)) } diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index 3c30448146b4..32716a52eb0d 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -99,7 +99,6 @@ pub fn flight_data_to_arrow_batch( dictionaries_by_id, None, &message.version(), - false, ) })? 
} diff --git a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs index ab5c674dfd20..c6b5a72ca6e2 100644 --- a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs @@ -269,7 +269,6 @@ async fn receive_batch_flight_data( &schema, dictionaries_by_id, &message.version(), - false, ) .expect("Error reading dictionary"); diff --git a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs index 7fe35b8acc96..25203ecb7697 100644 --- a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs @@ -308,7 +308,6 @@ async fn record_batch_from_message( dictionaries_by_id, None, &message.version(), - false, ); arrow_batch_result @@ -331,7 +330,6 @@ async fn dictionary_from_message( &schema_ref, dictionaries_by_id, &message.version(), - false, ); dictionary_batch_result .map_err(|e| Status::internal(format!("Could not convert to Dictionary: {e:?}"))) diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index e01fb449fa09..d31e9af1489b 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -485,6 +485,35 @@ impl<'a> ArrayReader<'a> { } } +pub fn read_record_batch( + buf: &Buffer, + batch: crate::RecordBatch, + schema: SchemaRef, + dictionaries_by_id: &HashMap, + projection: Option<&[usize]>, + metadata: &MetadataVersion, +) -> Result { + read_record_batch2( + buf, + batch, + schema, + dictionaries_by_id, + projection, + metadata, + false, + ) +} + +pub fn read_dictionary( + buf: &Buffer, + batch: crate::DictionaryBatch, + schema: &Schema, + dictionaries_by_id: &mut HashMap, + metadata: &MetadataVersion, +) -> Result<(), ArrowError> { + read_dictionary2(buf, batch, schema, dictionaries_by_id, metadata, false) +} + /// Creates a record batch from binary data using the `crate::RecordBatch` indexes and the `Schema`. /// /// If `require_alignment` is true, this function will return an error if any array data in the @@ -495,7 +524,7 @@ impl<'a> ArrayReader<'a> { /// and copy over the data if any array data in the input `buf` is not properly aligned. /// (Properly aligned array data will remain zero-copy.) /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct [`arrow_data::ArrayData`]. 
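// A minimal sketch of the two construction paths named in the doc comment
// above, using arrow-data directly. It assumes the arrow-data, arrow-buffer,
// and arrow-schema crates roughly as of this change; treat the builder calls
// as illustrative rather than exact.
use arrow_buffer::Buffer;
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType};

fn build_int32(values: Buffer, len: usize, require_alignment: bool) -> Result<ArrayData, ArrowError> {
    let builder = ArrayData::builder(DataType::Int32).len(len).add_buffer(values);
    if require_alignment {
        builder.build() // zero-copy, or an error on misaligned input
    } else {
        builder.build_aligned() // silently copies misaligned buffers
    }
}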
-pub fn read_record_batch( +fn read_record_batch2( buf: &Buffer, batch: crate::RecordBatch, schema: SchemaRef, @@ -564,7 +593,7 @@ pub fn read_record_batch( /// Read the dictionary from the buffer and provided metadata, /// updating the `dictionaries_by_id` with the resulting dictionary -pub fn read_dictionary( +fn read_dictionary2( buf: &Buffer, batch: crate::DictionaryBatch, schema: &Schema, @@ -593,7 +622,7 @@ pub fn read_dictionary( let value = value_type.as_ref().clone(); let schema = Schema::new(vec![Field::new("", value, true)]); // Read a single column - let record_batch = read_record_batch( + let record_batch = read_record_batch2( buf, batch.data().unwrap(), Arc::new(schema), @@ -781,7 +810,7 @@ impl FileDecoder { match message.header_type() { crate::MessageHeader::DictionaryBatch => { let batch = message.header_as_dictionary_batch().unwrap(); - read_dictionary( + read_dictionary2( &buf.slice(block.metaDataLength() as _), batch, &self.schema, @@ -812,7 +841,7 @@ impl FileDecoder { ArrowError::IpcError("Unable to read IPC message as record batch".to_string()) })?; // read the block that makes up the record batch into a buffer - read_record_batch( + read_record_batch2( &buf.slice(block.metaDataLength() as _), batch, self.schema.clone(), @@ -1255,7 +1284,7 @@ impl StreamReader { let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize); self.reader.read_exact(&mut buf)?; - read_record_batch( + read_record_batch2( &buf.into(), batch, self.schema(), @@ -1276,7 +1305,7 @@ impl StreamReader { let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize); self.reader.read_exact(&mut buf)?; - read_dictionary( + read_dictionary2( &buf.into(), batch, &self.schema, @@ -2048,7 +2077,7 @@ mod tests { assert_ne!(b.as_ptr().align_offset(8), 0); let ipc_batch = message.header_as_record_batch().unwrap(); - let roundtrip = read_record_batch( + let roundtrip = read_record_batch2( &b, ipc_batch, batch.schema(), @@ -2085,7 +2114,7 @@ mod tests { assert_ne!(b.as_ptr().align_offset(8), 0); let ipc_batch = message.header_as_record_batch().unwrap(); - let result = read_record_batch( + let result = read_record_batch2( &b, ipc_batch, batch.schema(), diff --git a/arrow-ipc/src/reader/stream.rs b/arrow-ipc/src/reader/stream.rs index 56fe37bbaff1..9fde3ee42c5d 100644 --- a/arrow-ipc/src/reader/stream.rs +++ b/arrow-ipc/src/reader/stream.rs @@ -24,7 +24,7 @@ use arrow_buffer::{Buffer, MutableBuffer}; use arrow_schema::{ArrowError, SchemaRef}; use crate::convert::MessageBuffer; -use crate::reader::{read_dictionary, read_record_batch}; +use crate::reader::{read_dictionary2, read_record_batch2}; use crate::{MessageHeader, CONTINUATION_MARKER}; /// A low-level interface for reading [`RecordBatch`] data from a stream of bytes @@ -40,6 +40,8 @@ pub struct StreamDecoder { state: DecoderState, /// A scratch buffer when a read is split across multiple `Buffer` buf: MutableBuffer, + /// Whether or not array data in input buffers are required to be aligned + require_alignment: bool, } #[derive(Debug)] @@ -83,6 +85,23 @@ impl StreamDecoder { Self::default() } + /// Specifies whether or not array data in input buffers is required to be properly aligned. + /// + /// If `require_alignment` is true, this decoder will return an error if any array data in the + /// input `buf` is not properly aligned. + /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct + /// [`arrow_data::ArrayData`]. 
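+    /// (Such a build failure surfaces as an `ArrowError::InvalidArgumentError`
+    /// naming the misaligned buffer and the expected alignment.)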
+    ///
+    /// If `require_alignment` is false (the default), this decoder will automatically allocate a
+    /// new aligned buffer and copy over the data if any array data in the input `buf` is not
+    /// properly aligned. (Properly aligned array data will remain zero-copy.)
+    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct
+    /// [`arrow_data::ArrayData`].
+    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
+        self.require_alignment = require_alignment;
+        self
+    }
+
     /// Try to read the next [`RecordBatch`] from the provided [`Buffer`]
     ///
     /// [`Buffer::advance`] will be called on `buffer` for any consumed bytes.
@@ -192,14 +211,14 @@ impl StreamDecoder {
                        let schema = self.schema.clone().ok_or_else(|| {
                            ArrowError::IpcError("Missing schema".to_string())
                        })?;
-                        let batch = read_record_batch(
+                        let batch = read_record_batch2(
                            &body,
                            batch,
                            schema,
                            &self.dictionaries,
                            None,
                            &version,
-                            false,
+                            self.require_alignment,
                        )?;
                        self.state = DecoderState::default();
                        return Ok(Some(batch));
@@ -209,13 +228,13 @@ impl StreamDecoder {
                        let schema = self.schema.as_deref().ok_or_else(|| {
                            ArrowError::IpcError("Missing schema".to_string())
                        })?;
-                        read_dictionary(
+                        read_dictionary2(
                            &body,
                            dictionary,
                            schema,
                            &mut self.dictionaries,
                            &version,
-                            false,
+                            self.require_alignment,
                        )?;
                        self.state = DecoderState::default();
                    }

From 4d726bde57b32458911f3b28e77fcb20face6293 Mon Sep 17 00:00:00 2001
From: Raphael Taylor-Davies
Date: Thu, 4 Apr 2024 11:12:28 +0100
Subject: [PATCH 12/12] Preserve docs

---
 arrow-ipc/src/reader.rs        | 46 +++++++++++++++++-----------------
 arrow-ipc/src/reader/stream.rs |  6 ++---
 2 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index d31e9af1489b..8eac17e20761 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -485,6 +485,16 @@ impl<'a> ArrayReader<'a> {
     }
 }
 
+/// Creates a record batch from binary data using the `crate::RecordBatch` indexes and the `Schema`.
+///
+/// If `require_alignment` is true, this function will return an error if any array data in the
+/// input `buf` is not properly aligned.
+/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct [`arrow_data::ArrayData`].
+///
+/// If `require_alignment` is false, this function will automatically allocate a new aligned buffer
+/// and copy over the data if any array data in the input `buf` is not properly aligned.
+/// (Properly aligned array data will remain zero-copy.)
+/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct [`arrow_data::ArrayData`].
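With the field and builder above in place, opting a `StreamDecoder` into strict alignment checking is a one-line change at construction time. A usage sketch, assuming `StreamDecoder` is re-exported from `arrow_ipc::reader` and the input bytes hold a complete IPC stream (the function name and counting logic are invented for illustration):

    use arrow_buffer::Buffer;
    use arrow_ipc::reader::StreamDecoder;
    use arrow_schema::ArrowError;

    /// Decode every batch in an IPC stream, erroring on misaligned array
    /// data instead of silently copying it into an aligned allocation.
    fn decode_strict(ipc_bytes: Vec<u8>) -> Result<usize, ArrowError> {
        let mut decoder = StreamDecoder::new().with_require_alignment(true);
        let mut buffer = Buffer::from_vec(ipc_bytes);
        let mut batches = 0;
        // `decode` consumes from `buffer` and yields a batch once one is complete.
        while let Some(batch) = decoder.decode(&mut buffer)? {
            println!("batch with {} rows", batch.num_rows());
            batches += 1;
        }
        Ok(batches)
    }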
 pub fn read_record_batch(
     buf: &Buffer,
     batch: crate::RecordBatch,
@@ -493,7 +503,7 @@ pub fn read_record_batch(
     projection: Option<&[usize]>,
     metadata: &MetadataVersion,
 ) -> Result<RecordBatch, ArrowError> {
-    read_record_batch2(
+    read_record_batch_impl(
         buf,
         batch,
         schema,
@@ -504,6 +514,8 @@ pub fn read_record_batch(
     )
 }
 
+/// Read the dictionary from the buffer and provided metadata,
+/// updating the `dictionaries_by_id` with the resulting dictionary
 pub fn read_dictionary(
     buf: &Buffer,
     batch: crate::DictionaryBatch,
@@ -511,20 +523,10 @@ pub fn read_dictionary(
     dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
     metadata: &MetadataVersion,
 ) -> Result<(), ArrowError> {
-    read_dictionary2(buf, batch, schema, dictionaries_by_id, metadata, false)
+    read_dictionary_impl(buf, batch, schema, dictionaries_by_id, metadata, false)
 }
 
-/// Creates a record batch from binary data using the `crate::RecordBatch` indexes and the `Schema`.
-///
-/// If `require_alignment` is true, this function will return an error if any array data in the
-/// input `buf` is not properly aligned.
-/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct [`arrow_data::ArrayData`].
-///
-/// If `require_alignment` is false, this function will automatically allocate a new aligned buffer
-/// and copy over the data if any array data in the input `buf` is not properly aligned.
-/// (Properly aligned array data will remain zero-copy.)
-/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct [`arrow_data::ArrayData`].
-fn read_record_batch2(
+fn read_record_batch_impl(
     buf: &Buffer,
     batch: crate::RecordBatch,
     schema: SchemaRef,
@@ -591,9 +593,7 @@ fn read_record_batch2(
     }
 }
 
-/// Read the dictionary from the buffer and provided metadata,
-/// updating the `dictionaries_by_id` with the resulting dictionary
-fn read_dictionary2(
+fn read_dictionary_impl(
     buf: &Buffer,
     batch: crate::DictionaryBatch,
     schema: &Schema,
@@ -622,7 +622,7 @@ fn read_dictionary2(
             let value = value_type.as_ref().clone();
             let schema = Schema::new(vec![Field::new("", value, true)]);
             // Read a single column
-            let record_batch = read_record_batch2(
+            let record_batch = read_record_batch_impl(
                 buf,
                 batch.data().unwrap(),
                 Arc::new(schema),
@@ -810,7 +810,7 @@ impl FileDecoder {
         match message.header_type() {
             crate::MessageHeader::DictionaryBatch => {
                 let batch = message.header_as_dictionary_batch().unwrap();
-                read_dictionary2(
+                read_dictionary_impl(
                     &buf.slice(block.metaDataLength() as _),
                     batch,
                     &self.schema,
@@ -841,7 +841,7 @@ impl FileDecoder {
                     ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
                 })?;
                 // read the block that makes up the record batch into a buffer
-                read_record_batch2(
+                read_record_batch_impl(
                     &buf.slice(block.metaDataLength() as _),
                     batch,
                     self.schema.clone(),
@@ -1284,7 +1284,7 @@ impl<R: Read> StreamReader<R> {
                 let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
                 self.reader.read_exact(&mut buf)?;
 
-                read_record_batch2(
+                read_record_batch_impl(
                     &buf.into(),
                     batch,
                     self.schema(),
@@ -1305,7 +1305,7 @@ impl<R: Read> StreamReader<R> {
                 let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
                 self.reader.read_exact(&mut buf)?;
 
-                read_dictionary2(
+                read_dictionary_impl(
                     &buf.into(),
                     batch,
                     &self.schema,
@@ -2077,7 +2077,7 @@ mod tests {
         assert_ne!(b.as_ptr().align_offset(8), 0);
 
         let ipc_batch = message.header_as_record_batch().unwrap();
-        let roundtrip = read_record_batch2(
+        let roundtrip = read_record_batch_impl(
            &b,
            ipc_batch,
            batch.schema(),
@@ -2114,7 +2114,7 @@ mod tests {
         assert_ne!(b.as_ptr().align_offset(8), 0);
 
         let ipc_batch = message.header_as_record_batch().unwrap();
-        let result = read_record_batch2(
+        let result = read_record_batch_impl(
            &b,
            ipc_batch,
            batch.schema(),
diff --git a/arrow-ipc/src/reader/stream.rs b/arrow-ipc/src/reader/stream.rs
index 9fde3ee42c5d..64191a22b33e 100644
--- a/arrow-ipc/src/reader/stream.rs
+++ b/arrow-ipc/src/reader/stream.rs
@@ -24,7 +24,7 @@ use arrow_buffer::{Buffer, MutableBuffer};
 use arrow_schema::{ArrowError, SchemaRef};
 
 use crate::convert::MessageBuffer;
-use crate::reader::{read_dictionary2, read_record_batch2};
+use crate::reader::{read_dictionary_impl, read_record_batch_impl};
 use crate::{MessageHeader, CONTINUATION_MARKER};
 
 /// A low-level interface for reading [`RecordBatch`] data from a stream of bytes
@@ -211,7 +211,7 @@ impl StreamDecoder {
                        let schema = self.schema.clone().ok_or_else(|| {
                            ArrowError::IpcError("Missing schema".to_string())
                        })?;
-                        let batch = read_record_batch2(
+                        let batch = read_record_batch_impl(
                            &body,
                            batch,
                            schema,
@@ -228,7 +228,7 @@ impl StreamDecoder {
                        let schema = self.schema.as_deref().ok_or_else(|| {
                            ArrowError::IpcError("Missing schema".to_string())
                        })?;
-                        read_dictionary2(
+                        read_dictionary_impl(
                            &body,
                            dictionary,
                            schema,
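This final patch settles on a common API-evolution pattern: the published functions keep their pre-existing signatures (and their doc comments), and forward to private `_impl` variants that carry the new `require_alignment` flag with a conservative default, so existing callers compile unchanged. A self-contained toy sketch of that shape (not arrow code; `parse` and `parse_impl` are invented names):

    use std::num::ParseIntError;

    /// The stable public entry point: signature unchanged, forwards a default.
    pub fn parse(input: &str) -> Result<u32, ParseIntError> {
        parse_impl(input, false)
    }

    /// The private implementation carries the new behavioral flag.
    fn parse_impl(input: &str, strict: bool) -> Result<u32, ParseIntError> {
        // In strict mode, surrounding whitespace is an error rather than trimmed away.
        let cleaned = if strict { input } else { input.trim() };
        cleaned.parse()
    }

    fn main() {
        assert_eq!(parse(" 42 ").unwrap(), 42);      // old callers keep working
        assert!(parse_impl(" 42 ", true).is_err());  // new strict behavior is opt-in
    }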