diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5769f6b..47fab25 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -133,6 +133,7 @@ set(SPARROW_IPC_SRC
     ${SPARROW_IPC_SOURCE_DIR}/arrow_interface/arrow_schema/private_data.cpp
     ${SPARROW_IPC_SOURCE_DIR}/chunk_memory_serializer.cpp
     ${SPARROW_IPC_SOURCE_DIR}/compression.cpp
+    ${SPARROW_IPC_SOURCE_DIR}/compression_impl.hpp
     ${SPARROW_IPC_SOURCE_DIR}/deserialize_fixedsizebinary_array.cpp
     ${SPARROW_IPC_SOURCE_DIR}/deserialize_utils.cpp
     ${SPARROW_IPC_SOURCE_DIR}/deserialize.cpp
diff --git a/include/sparrow_ipc/chunk_memory_serializer.hpp b/include/sparrow_ipc/chunk_memory_serializer.hpp
index 4868a42..26e0277 100644
--- a/include/sparrow_ipc/chunk_memory_serializer.hpp
+++ b/include/sparrow_ipc/chunk_memory_serializer.hpp
@@ -8,10 +8,9 @@
 #include
 
-#include "Message_generated.h"
-
 #include "sparrow_ipc/any_output_stream.hpp"
 #include "sparrow_ipc/chunk_memory_output_stream.hpp"
+#include "sparrow_ipc/compression.hpp"
 #include "sparrow_ipc/config/config.hpp"
 #include "sparrow_ipc/memory_output_stream.hpp"
 #include "sparrow_ipc/serialize.hpp"
@@ -44,8 +43,7 @@ namespace sparrow_ipc
         * @param stream Reference to a chunked memory output stream that will receive the serialized chunks
         * @param compression Optional: The compression type to use for record batch bodies.
         */
-        // TODO Use enums and such to avoid including flatbuffers headers
-        chunk_serializer(chunked_memory_output_stream<std::vector<std::vector<uint8_t>>>& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression = std::nullopt);
+        chunk_serializer(chunked_memory_output_stream<std::vector<std::vector<uint8_t>>>& stream, std::optional<CompressionType> compression = std::nullopt);
 
        /**
         * @brief Writes a single record batch to the chunked stream.
@@ -131,7 +129,7 @@ namespace sparrow_ipc
         std::vector m_dtypes;
         chunked_memory_output_stream<std::vector<std::vector<uint8_t>>>* m_pstream;
         bool m_ended{false};
-        std::optional<org::apache::arrow::flatbuf::CompressionType> m_compression;
+        std::optional<CompressionType> m_compression;
     };
 
     // Implementation
diff --git a/include/sparrow_ipc/compression.hpp b/include/sparrow_ipc/compression.hpp
index 96b47ec..9c92a16 100644
--- a/include/sparrow_ipc/compression.hpp
+++ b/include/sparrow_ipc/compression.hpp
@@ -5,24 +5,21 @@
 #include
 #include
 
-#include "Message_generated.h"
-
 #include "sparrow_ipc/config/config.hpp"
 
 namespace sparrow_ipc
 {
-// TODO use these later if needed for wrapping purposes (flatbuffers/lz4)
-// enum class CompressionType
-// {
-//     NONE,
-//     LZ4,
-//     ZSTD
-// };
-
-// CompressionType to_compression_type(org::apache::arrow::flatbuf::CompressionType compression_type);
+    enum class CompressionType
+    {
+        LZ4_FRAME,
+        ZSTD
+    };
 
-    constexpr auto CompressionHeaderSize = sizeof(std::int64_t);
+    [[nodiscard]] SPARROW_IPC_API std::vector<uint8_t> compress(
+        const CompressionType compression_type,
+        std::span<const uint8_t> data);
 
-    [[nodiscard]] SPARROW_IPC_API std::vector<uint8_t> compress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const uint8_t> data);
-    [[nodiscard]] SPARROW_IPC_API std::variant<std::vector<uint8_t>, std::span<const uint8_t>> decompress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const uint8_t> data);
+    [[nodiscard]] SPARROW_IPC_API std::variant<std::vector<uint8_t>, std::span<const uint8_t>> decompress(
+        const CompressionType compression_type,
+        std::span<const uint8_t> data);
 }
diff --git a/include/sparrow_ipc/deserialize_primitive_array.hpp b/include/sparrow_ipc/deserialize_primitive_array.hpp
index 40e9076..b23b93f 100644
--- a/include/sparrow_ipc/deserialize_primitive_array.hpp
+++ b/include/sparrow_ipc/deserialize_primitive_array.hpp
@@ -43,16 +43,7 @@ namespace sparrow_ipc
 
         if (compression)
         {
-            // TODO Handle buffers emptiness thoroughly / which is and which is not allowed...
-            // Validity buffers can be empty
-            if (validity_buffer_span.empty())
-            {
-                buffers.push_back(validity_buffer_span);
-            }
-            else
-            {
-                buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression));
-            }
+            buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression));
             buffers.push_back(utils::get_decompressed_buffer(data_buffer_span, compression));
         }
         else
diff --git a/include/sparrow_ipc/deserialize_utils.hpp b/include/sparrow_ipc/deserialize_utils.hpp
index c83ad20..97f6be9 100644
--- a/include/sparrow_ipc/deserialize_utils.hpp
+++ b/include/sparrow_ipc/deserialize_utils.hpp
@@ -32,27 +32,22 @@ namespace sparrow_ipc::utils
     );
 
     /**
-     * @brief Extracts bitmap pointer and null count from a RecordBatch buffer.
-     *
-     * This function retrieves a bitmap buffer from the specified index in the RecordBatch's
-     * buffer list and calculates the number of null values represented by the bitmap.
+     * @brief Extracts a buffer from a RecordBatch's body.
      *
-     * @param record_batch The Arrow RecordBatch containing buffer metadata
-     * @param body The raw buffer data as a byte span
-     * @param index The index of the bitmap buffer in the RecordBatch's buffer list
+     * This function retrieves a buffer span from the specified index in the RecordBatch's
+     * buffer list and increments the index.
      *
-     * @return A pair containing:
-     * - First: Pointer to the bitmap data (nullptr if buffer is empty)
-     * - Second: Count of null values in the bitmap (0 if buffer is empty)
+     * @param record_batch The Arrow RecordBatch containing buffer metadata.
+     * @param body The raw buffer data as a byte span.
+     * @param buffer_index The index of the buffer to retrieve. This value is incremented by the function.
      *
-     * @note If the bitmap buffer has zero length, returns {nullptr, 0}
-     * @note The returned pointer is a non-const cast of the original const data
+     * @return A `std::span` viewing the extracted buffer data.
+     * @throws std::runtime_error if the buffer metadata indicates a buffer that exceeds the body size.
      */
-    // TODO to be removed when not used anymore (after adding compression to deserialize_fixedsizebinary_array)
-    [[nodiscard]] std::pair get_bitmap_pointer_and_null_count(
+    [[nodiscard]] std::span<const uint8_t> get_buffer(
         const org::apache::arrow::flatbuf::RecordBatch& record_batch,
         std::span<const uint8_t> body,
-        size_t index
+        size_t& buffer_index
     );
 
     /**
@@ -72,23 +67,4 @@ namespace sparrow_ipc::utils
         std::span<const uint8_t> buffer_span,
         const org::apache::arrow::flatbuf::BodyCompression* compression
     );
-
-    /**
-     * @brief Extracts a buffer from a RecordBatch's body.
-     *
-     * This function retrieves a buffer span from the specified index in the RecordBatch's
-     * buffer list and increments the index.
-     *
-     * @param record_batch The Arrow RecordBatch containing buffer metadata.
-     * @param body The raw buffer data as a byte span.
-     * @param buffer_index The index of the buffer to retrieve. This value is incremented by the function.
-     *
-     * @return A `std::span` viewing the extracted buffer data.
-     * @throws std::runtime_error if the buffer metadata indicates a buffer that exceeds the body size.
-     */
-    [[nodiscard]] std::span<const uint8_t> get_buffer(
-        const org::apache::arrow::flatbuf::RecordBatch& record_batch,
-        std::span<const uint8_t> body,
-        size_t& buffer_index
-    );
 }
diff --git a/include/sparrow_ipc/deserialize_variable_size_binary_array.hpp b/include/sparrow_ipc/deserialize_variable_size_binary_array.hpp
index dab2792..b56e6b5 100644
--- a/include/sparrow_ipc/deserialize_variable_size_binary_array.hpp
+++ b/include/sparrow_ipc/deserialize_variable_size_binary_array.hpp
@@ -41,15 +41,7 @@ namespace sparrow_ipc
 
         if (compression)
         {
-            // Validity buffers can be empty
-            if (validity_buffer_span.empty())
-            {
-                buffers.push_back(validity_buffer_span);
-            }
-            else
-            {
-                buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression));
-            }
+            buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression));
             buffers.push_back(utils::get_decompressed_buffer(offset_buffer_span, compression));
             buffers.push_back(utils::get_decompressed_buffer(data_buffer_span, compression));
         }
diff --git a/include/sparrow_ipc/flatbuffer_utils.hpp b/include/sparrow_ipc/flatbuffer_utils.hpp
index 78094bf..864055e 100644
--- a/include/sparrow_ipc/flatbuffer_utils.hpp
+++ b/include/sparrow_ipc/flatbuffer_utils.hpp
@@ -5,6 +5,9 @@
 #include
 #include
 
+#include "sparrow_ipc/compression.hpp"
+#include "sparrow_ipc/utils.hpp"
+
 namespace sparrow_ipc
 {
     // Creates a Flatbuffers Decimal type from a format string
@@ -164,6 +167,42 @@ namespace sparrow_ipc
     [[nodiscard]] std::vector<org::apache::arrow::flatbuf::FieldNode> create_fieldnodes(const sparrow::record_batch& record_batch);
 
+    namespace details
+    {
+        template <typename Func>
+        void fill_buffers_impl(
+            const sparrow::arrow_proxy& arrow_proxy,
+            std::vector<org::apache::arrow::flatbuf::Buffer>& flatbuf_buffers,
+            int64_t& offset,
+            Func&& get_buffer_size
+        )
+        {
+            const auto& buffers = arrow_proxy.buffers();
+            for (const auto& buffer : buffers)
+            {
+                int64_t size = get_buffer_size(buffer);
+                flatbuf_buffers.emplace_back(offset, size);
+                offset += utils::align_to_8(size);
+            }
+            for (const auto& child : arrow_proxy.children())
+            {
+                fill_buffers_impl(child, flatbuf_buffers, offset, get_buffer_size);
+            }
+        }
+
+        template <typename Func>
+        std::vector<org::apache::arrow::flatbuf::Buffer> get_buffers_impl(const sparrow::record_batch& record_batch, Func&& fill_buffers_func)
+        {
+            std::vector<org::apache::arrow::flatbuf::Buffer> buffers;
+            int64_t offset = 0;
+            for (const auto& column : record_batch.columns())
+            {
+                const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column);
+                fill_buffers_func(arrow_proxy, buffers, offset);
+            }
+            return buffers;
+        }
+    } // namespace details
 
     /**
      * @brief Recursively fills a vector of FlatBuffer Buffer objects with buffer information from an Arrow
@@ -205,6 +244,67 @@ namespace sparrow_ipc
     [[nodiscard]] std::vector<org::apache::arrow::flatbuf::Buffer> get_buffers(const sparrow::record_batch& record_batch);
 
+    /**
+     * @brief Recursively populates a vector with compressed buffer metadata from an Arrow proxy.
+     *
+     * This function traverses the Arrow proxy and its children, compressing each buffer and recording
+     * its metadata (offset and size) in the provided vector. The offset is updated to ensure proper
+     * alignment for each subsequent buffer.
+     *
+     * @param arrow_proxy The Arrow proxy containing the buffers to be compressed.
+     * @param flatbuf_compressed_buffers A vector to store the resulting compressed buffer metadata.
+     * @param offset The current offset in the buffer layout, which will be updated by the function.
+     * @param compression_type The compression algorithm to use.
+ */ + void fill_compressed_buffers( + const sparrow::arrow_proxy& arrow_proxy, + std::vector& flatbuf_compressed_buffers, + int64_t& offset, + const CompressionType compression_type + ); + + /** + * @brief Retrieves metadata describing the layout of compressed buffers within a record batch. + * + * This function processes a record batch to determine the metadata (offset and size) + * for each of its buffers, assuming they are compressed using the specified algorithm. + * This metadata accounts for each compressed buffer being prefixed by its 8-byte + * uncompressed size and padded to ensure 8-byte alignment. + * + * @param record_batch The record batch whose buffers' compressed metadata is to be retrieved. + * @param compression_type The compression algorithm that would be applied (e.g., LZ4_FRAME, ZSTD). + * @return A vector of FlatBuffer Buffer objects, each describing the offset and + * size of a corresponding compressed buffer within a larger message body. + */ + [[nodiscard]] std::vector + get_compressed_buffers(const sparrow::record_batch& record_batch, const CompressionType compression_type); + + /** + * @brief Calculates the total size of the body section for an Arrow array. + * + * This function recursively computes the total size needed for all buffers + * in an Arrow array structure, including buffers from child arrays. Each + * buffer size is aligned to 8-byte boundaries as required by the Arrow format. + * + * @param arrow_proxy The Arrow array proxy containing buffers and child arrays + * @param compression The compression type to use when serializing + * @return int64_t The total aligned size in bytes of all buffers in the array hierarchy + */ + [[nodiscard]] int64_t calculate_body_size(const sparrow::arrow_proxy& arrow_proxy, std::optional compression = std::nullopt); + + /** + * @brief Calculates the total body size of a record batch by summing the body sizes of all its columns. + * + * This function iterates through all columns in the given record batch and accumulates + * the body size of each column's underlying Arrow array proxy. The body size represents + * the total memory required for the serialized data content of the record batch. + * + * @param record_batch The sparrow record batch containing columns to calculate size for + * @param compression The compression type to use when serializing + * @return int64_t The total body size in bytes of all columns in the record batch + */ + [[nodiscard]] int64_t calculate_body_size(const sparrow::record_batch& record_batch, std::optional compression = std::nullopt); + /** * @brief Creates a FlatBuffer message containing a serialized Apache Arrow RecordBatch. 
* @@ -222,5 +322,5 @@ namespace sparrow_ipc * @note Variadic buffer counts is not currently implemented (set to 0) */ [[nodiscard]] flatbuffers::FlatBufferBuilder - get_record_batch_message_builder(const sparrow::record_batch& record_batch, std::optional compression = std::nullopt); + get_record_batch_message_builder(const sparrow::record_batch& record_batch, std::optional compression = std::nullopt); } diff --git a/include/sparrow_ipc/serialize.hpp b/include/sparrow_ipc/serialize.hpp index ab47646..b9a7334 100644 --- a/include/sparrow_ipc/serialize.hpp +++ b/include/sparrow_ipc/serialize.hpp @@ -6,6 +6,7 @@ #include "Message_generated.h" #include "sparrow_ipc/any_output_stream.hpp" +#include "sparrow_ipc/compression.hpp" #include "sparrow_ipc/config/config.hpp" #include "sparrow_ipc/magic_values.hpp" #include "sparrow_ipc/serialize_utils.hpp" @@ -36,7 +37,7 @@ namespace sparrow_ipc */ template requires std::same_as, sparrow::record_batch> - void serialize_record_batches_to_ipc_stream(const R& record_batches, any_output_stream& stream, std::optional compression) + void serialize_record_batches_to_ipc_stream(const R& record_batches, any_output_stream& stream, std::optional compression) { if (record_batches.empty()) { @@ -76,7 +77,7 @@ namespace sparrow_ipc */ SPARROW_IPC_API void - serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional compression); + serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional compression); /** * @brief Serializes a schema message for a record batch into a byte buffer. diff --git a/include/sparrow_ipc/serialize_utils.hpp b/include/sparrow_ipc/serialize_utils.hpp index 25f9a1e..4710f59 100644 --- a/include/sparrow_ipc/serialize_utils.hpp +++ b/include/sparrow_ipc/serialize_utils.hpp @@ -7,6 +7,7 @@ #include "Message_generated.h" #include "sparrow_ipc/any_output_stream.hpp" +#include "sparrow_ipc/compression.hpp" #include "sparrow_ipc/config/config.hpp" #include "sparrow_ipc/utils.hpp" @@ -43,7 +44,7 @@ namespace sparrow_ipc * @param compression The compression type to use when serializing */ SPARROW_IPC_API void - serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional compression); + serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional compression); /** * @brief Calculates the total serialized size of a schema message. @@ -77,7 +78,7 @@ namespace sparrow_ipc * @return The total size in bytes that the serialized record batch would occupy */ [[nodiscard]] SPARROW_IPC_API std::size_t - calculate_record_batch_message_size(const sparrow::record_batch& record_batch, std::optional compression = std::nullopt); + calculate_record_batch_message_size(const sparrow::record_batch& record_batch, std::optional compression = std::nullopt); /** * @brief Calculates the total serialized size for a collection of record batches. @@ -93,7 +94,7 @@ namespace sparrow_ipc */ template requires std::same_as, sparrow::record_batch> - [[nodiscard]] std::size_t calculate_total_serialized_size(const R& record_batches, std::optional compression = std::nullopt) + [[nodiscard]] std::size_t calculate_total_serialized_size(const R& record_batches, std::optional compression = std::nullopt) { if (record_batches.empty()) { @@ -118,22 +119,6 @@ namespace sparrow_ipc return total_size; } - /** - * @brief Generates the compressed message body and buffer metadata for a record batch. 
- * - * This function traverses the record batch, compresses each buffer using the specified - * compression algorithm, and constructs the message body. For each compressed buffer, - * it is prefixed by its 8-byte uncompressed size. Padding is added after each - * compressed buffer to ensure 8-byte alignment. - * - * @param record_batch The record batch to serialize. - * @param compression_type The compression algorithm to use (e.g., LZ4_FRAME, ZSTD). - * @return A vector of FlatBuffer Buffer objects describing the offset and - * size of each buffer within the compressed body. - */ - [[nodiscard]] SPARROW_IPC_API std::vector - generate_compressed_buffers(const sparrow::record_batch& record_batch, const org::apache::arrow::flatbuf::CompressionType compression_type); - /** * @brief Fills the body vector with serialized data from an arrow proxy and its children. * @@ -150,7 +135,7 @@ namespace sparrow_ipc * @param stream The output stream where the serialized body data will be written * @param compression The compression type to use when serializing */ - SPARROW_IPC_API void fill_body(const sparrow::arrow_proxy& arrow_proxy, any_output_stream& stream, std::optional compression = std::nullopt); + SPARROW_IPC_API void fill_body(const sparrow::arrow_proxy& arrow_proxy, any_output_stream& stream, std::optional compression = std::nullopt); /** * @brief Generates a serialized body from a record batch. @@ -163,33 +148,7 @@ namespace sparrow_ipc * @param stream The output stream where the serialized body will be written * @param compression The compression type to use when serializing */ - SPARROW_IPC_API void generate_body(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional compression = std::nullopt); - - /** - * @brief Calculates the total size of the body section for an Arrow array. - * - * This function recursively computes the total size needed for all buffers - * in an Arrow array structure, including buffers from child arrays. Each - * buffer size is aligned to 8-byte boundaries as required by the Arrow format. - * - * @param arrow_proxy The Arrow array proxy containing buffers and child arrays - * @param compression The compression type to use when serializing - * @return int64_t The total aligned size in bytes of all buffers in the array hierarchy - */ - [[nodiscard]] SPARROW_IPC_API int64_t calculate_body_size(const sparrow::arrow_proxy& arrow_proxy, std::optional compression = std::nullopt); - - /** - * @brief Calculates the total body size of a record batch by summing the body sizes of all its columns. - * - * This function iterates through all columns in the given record batch and accumulates - * the body size of each column's underlying Arrow array proxy. The body size represents - * the total memory required for the serialized data content of the record batch. 
- * - * @param record_batch The sparrow record batch containing columns to calculate size for - * @param compression The compression type to use when serializing - * @return int64_t The total body size in bytes of all columns in the record batch - */ - [[nodiscard]] SPARROW_IPC_API int64_t calculate_body_size(const sparrow::record_batch& record_batch, std::optional compression = std::nullopt); + SPARROW_IPC_API void generate_body(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional compression = std::nullopt); SPARROW_IPC_API std::vector get_column_dtypes(const sparrow::record_batch& rb); } diff --git a/include/sparrow_ipc/serializer.hpp b/include/sparrow_ipc/serializer.hpp index f0ebcb9..e867f76 100644 --- a/include/sparrow_ipc/serializer.hpp +++ b/include/sparrow_ipc/serializer.hpp @@ -4,6 +4,7 @@ #include #include "sparrow_ipc/any_output_stream.hpp" +#include "sparrow_ipc/compression.hpp" #include "sparrow_ipc/serialize_utils.hpp" namespace sparrow_ipc @@ -41,7 +42,7 @@ namespace sparrow_ipc * The serializer stores a pointer to this stream for later use. */ template - serializer(TStream& stream, std::optional compression = std::nullopt) + serializer(TStream& stream, std::optional compression = std::nullopt) : m_stream(stream), m_compression(compression) { } @@ -206,7 +207,7 @@ namespace sparrow_ipc std::vector m_dtypes; any_output_stream m_stream; bool m_ended{false}; - std::optional m_compression; + std::optional m_compression; }; inline serializer& end_stream(serializer& serializer) @@ -214,4 +215,4 @@ namespace sparrow_ipc serializer.end(); return serializer; } -} \ No newline at end of file +} diff --git a/src/chunk_memory_serializer.cpp b/src/chunk_memory_serializer.cpp index db2c8a2..bbe5ab4 100644 --- a/src/chunk_memory_serializer.cpp +++ b/src/chunk_memory_serializer.cpp @@ -6,7 +6,7 @@ namespace sparrow_ipc { - chunk_serializer::chunk_serializer(chunked_memory_output_stream>>& stream, std::optional compression) + chunk_serializer::chunk_serializer(chunked_memory_output_stream>>& stream, std::optional compression) : m_pstream(&stream), m_compression(compression) { } diff --git a/src/compression.cpp b/src/compression.cpp index 50e6c34..63cfe70 100644 --- a/src/compression.cpp +++ b/src/compression.cpp @@ -2,23 +2,38 @@ #include -#include "sparrow_ipc/compression.hpp" +#include "compression_impl.hpp" namespace sparrow_ipc { -// CompressionType to_compression_type(org::apache::arrow::flatbuf::CompressionType compression_type) -// { -// switch (compression_type) -// { -// case org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME: -// return CompressionType::LZ4; -// // case org::apache::arrow::flatbuf::CompressionType::ZSTD: -// // // TODO: Add ZSTD support -// // break; -// default: -// return CompressionType::NONE; -// } -// } + namespace details + { + org::apache::arrow::flatbuf::CompressionType to_fb_compression_type(CompressionType compression_type) + { + switch (compression_type) + { + case CompressionType::LZ4_FRAME: + return org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME; + case CompressionType::ZSTD: + throw std::invalid_argument("Compression using zstd is not supported yet."); + default: + throw std::invalid_argument("Unsupported compression type."); + } + } + + CompressionType from_fb_compression_type(org::apache::arrow::flatbuf::CompressionType compression_type) + { + switch (compression_type) + { + case org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME: + return CompressionType::LZ4_FRAME; + case 
org::apache::arrow::flatbuf::CompressionType::ZSTD:
+                    throw std::invalid_argument("Compression using zstd is not supported yet.");
+                default:
+                    throw std::invalid_argument("Unsupported compression type.");
+            }
+        }
+    } // namespace details
 
     namespace
     {
@@ -57,7 +72,7 @@
         std::vector<uint8_t> uncompressed_data_with_header(std::span<const uint8_t> data)
         {
             std::vector<uint8_t> result;
-            result.reserve(CompressionHeaderSize + data.size());
+            result.reserve(details::CompressionHeaderSize + data.size());
             const std::int64_t header = -1;
             result.insert(result.end(), reinterpret_cast(&header), reinterpret_cast(&header) + sizeof(header));
             result.insert(result.end(), data.begin(), data.end());
@@ -75,7 +90,7 @@
             }
 
             std::vector<uint8_t> result;
-            result.reserve(CompressionHeaderSize + compressed_body.size());
+            result.reserve(details::CompressionHeaderSize + compressed_body.size());
             result.insert(result.end(), reinterpret_cast(&original_size), reinterpret_cast(&original_size) + sizeof(original_size));
             result.insert(result.end(), compressed_body.begin(), compressed_body.end());
             return result;
@@ -83,12 +98,12 @@
         std::variant<std::vector<uint8_t>, std::span<const uint8_t>> lz4_decompress_with_header(std::span<const uint8_t> data)
         {
-            if (data.size() < CompressionHeaderSize)
+            if (data.size() < details::CompressionHeaderSize)
             {
                 throw std::runtime_error("Invalid compressed data: missing decompressed size");
             }
             const std::int64_t decompressed_size = *reinterpret_cast(data.data());
-            const auto compressed_data = data.subspan(CompressionHeaderSize);
+            const auto compressed_data = data.subspan(details::CompressionHeaderSize);
 
             if (decompressed_size == -1)
             {
@@ -100,32 +115,32 @@
         std::span<const uint8_t> get_body_from_uncompressed_data(std::span<const uint8_t> data)
         {
-            if (data.size() < CompressionHeaderSize)
+            if (data.size() < details::CompressionHeaderSize)
             {
                 throw std::runtime_error("Invalid data: missing header");
             }
-            return data.subspan(CompressionHeaderSize);
+            return data.subspan(details::CompressionHeaderSize);
         }
-    }
+    } // namespace
 
-    std::vector<uint8_t> compress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const uint8_t> data)
+    std::vector<uint8_t> compress(const CompressionType compression_type, std::span<const uint8_t> data)
     {
         switch (compression_type)
         {
-            case org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME:
+            case CompressionType::LZ4_FRAME:
             {
                 return lz4_compress_with_header(data);
             }
-            case org::apache::arrow::flatbuf::CompressionType::ZSTD:
+            case CompressionType::ZSTD:
            {
-                throw std::runtime_error("Compression using zstd is not supported yet.");
+                throw std::invalid_argument("Compression using zstd is not supported yet.");
            }
            default:
                return uncompressed_data_with_header(data);
        }
    }
 
-    std::variant<std::vector<uint8_t>, std::span<const uint8_t>> decompress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const uint8_t> data)
+    std::variant<std::vector<uint8_t>, std::span<const uint8_t>> decompress(const CompressionType compression_type, std::span<const uint8_t> data)
     {
         if (data.empty())
         {
@@ -134,13 +149,13 @@
         switch (compression_type)
         {
-            case org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME:
+            case CompressionType::LZ4_FRAME:
             {
                 return lz4_decompress_with_header(data);
             }
-            case org::apache::arrow::flatbuf::CompressionType::ZSTD:
+            case CompressionType::ZSTD:
             {
-                throw std::runtime_error("Decompression using zstd is not supported yet.");
+                throw std::invalid_argument("Decompression using zstd is not supported yet.");
             }
             default:
             {
diff --git a/src/compression_impl.hpp b/src/compression_impl.hpp
new file mode 100644
index 0000000..00737bc
--- /dev/null
+++ b/src/compression_impl.hpp
@@ -0,0 +1,18 @@
+#pragma once
+
+#include
+
+#include "Message_generated.h"
+
+#include "sparrow_ipc/compression.hpp"
+
+namespace sparrow_ipc
+{
+    namespace details
+    {
+        constexpr auto CompressionHeaderSize = sizeof(std::int64_t);
+
+        org::apache::arrow::flatbuf::CompressionType to_fb_compression_type(CompressionType compression_type);
+        CompressionType from_fb_compression_type(org::apache::arrow::flatbuf::CompressionType compression_type);
+    }
+}
diff --git a/src/deserialize_utils.cpp b/src/deserialize_utils.cpp
index 6476afe..f5e93e2 100644
--- a/src/deserialize_utils.cpp
+++ b/src/deserialize_utils.cpp
@@ -1,10 +1,9 @@
 #include "sparrow_ipc/deserialize_utils.hpp"
 
-#include "sparrow_ipc/compression.hpp"
+#include "compression_impl.hpp"
 
 namespace sparrow_ipc::utils
 {
-    // TODO check and remove unused fcts?
     std::pair get_bitmap_pointer_and_null_count(
         std::span<const uint8_t> validity_buffer_span,
         const int64_t length
     )
@@ -22,29 +21,6 @@
         return {ptr, bitmap_view.null_count()};
     }
 
-    std::pair get_bitmap_pointer_and_null_count(
-        const org::apache::arrow::flatbuf::RecordBatch& record_batch,
-        std::span<const uint8_t> body,
-        size_t index
-    )
-    {
-        const auto bitmap_metadata = record_batch.buffers()->Get(index);
-        if (bitmap_metadata->length() == 0)
-        {
-            return {nullptr, 0};
-        }
-        if (body.size() < (bitmap_metadata->offset() + bitmap_metadata->length()))
-        {
-            throw std::runtime_error("Bitmap buffer exceeds body size");
-        }
-        auto ptr = const_cast(body.data() + bitmap_metadata->offset());
-        const sparrow::dynamic_bitset_view bitmap_view{
-            ptr,
-            static_cast(record_batch.length())
-        };
-        return {ptr, bitmap_view.null_count()};
-    }
-
     std::span<const uint8_t> get_buffer(
         const org::apache::arrow::flatbuf::RecordBatch& record_batch,
         std::span<const uint8_t> body,
@@ -64,9 +40,9 @@
         const org::apache::arrow::flatbuf::BodyCompression* compression
     )
     {
-        if (compression)
+        if (compression && !buffer_span.empty())
         {
-            return decompress(compression->codec(), buffer_span);
+            return decompress(sparrow_ipc::details::from_fb_compression_type(compression->codec()), buffer_span);
         }
         else
         {
diff --git a/src/flatbuffer_utils.cpp b/src/flatbuffer_utils.cpp
index f3e418c..d3abdfe 100644
--- a/src/flatbuffer_utils.cpp
+++ b/src/flatbuffer_utils.cpp
@@ -1,8 +1,7 @@
-#include "sparrow_ipc/flatbuffer_utils.hpp"
 #include
 
-#include "sparrow_ipc/serialize_utils.hpp"
-#include "sparrow_ipc/utils.hpp"
+#include "compression_impl.hpp"
+#include "sparrow_ipc/flatbuffer_utils.hpp"
 
 namespace sparrow_ipc
 {
@@ -537,40 +536,88 @@
         int64_t& offset
     )
     {
-        const auto& buffers = arrow_proxy.buffers();
-        for (const auto& buffer : buffers)
+        details::fill_buffers_impl(arrow_proxy, flatbuf_buffers, offset, [](const auto& buffer) {
+            return static_cast<int64_t>(buffer.size());
+        });
+    }
+
+    std::vector<org::apache::arrow::flatbuf::Buffer> get_buffers(const sparrow::record_batch& record_batch)
+    {
+        return details::get_buffers_impl(record_batch, [](const sparrow::arrow_proxy& proxy, std::vector<org::apache::arrow::flatbuf::Buffer>& buffers, int64_t& offset) {
+            fill_buffers(proxy, buffers, offset);
+        });
+    }
+
+    void fill_compressed_buffers(
+        const sparrow::arrow_proxy& arrow_proxy,
+        std::vector<org::apache::arrow::flatbuf::Buffer>& flatbuf_compressed_buffers,
+        int64_t& offset,
+        const CompressionType compression_type
+    )
+    {
+        details::fill_buffers_impl(
+            arrow_proxy, flatbuf_compressed_buffers, offset, [&](const auto& buffer) {
+                return compress(compression_type, std::span(buffer.data(), buffer.size()))
+                    .size();
+            });
+    }
+
+    std::vector<org::apache::arrow::flatbuf::Buffer>
+    get_compressed_buffers(const sparrow::record_batch& record_batch, const CompressionType
compression_type) + { + return details::get_buffers_impl(record_batch, [&](const sparrow::arrow_proxy& proxy, std::vector& buffers, int64_t& offset) { + fill_compressed_buffers(proxy, buffers, offset, compression_type); + }); + } + + int64_t calculate_body_size(const sparrow::arrow_proxy& arrow_proxy, std::optional compression) + { + int64_t total_size = 0; + if (compression.has_value()) + { + for (const auto& buffer : arrow_proxy.buffers()) + { + total_size += utils::align_to_8(compress(compression.value(), std::span(buffer.data(), buffer.size())).size()); + } + } + else { - int64_t size = static_cast(buffer.size()); - flatbuf_buffers.emplace_back(offset, size); - offset += utils::align_to_8(size); + for (const auto& buffer : arrow_proxy.buffers()) + { + total_size += utils::align_to_8(buffer.size()); + } } + for (const auto& child : arrow_proxy.children()) { - fill_buffers(child, flatbuf_buffers, offset); + total_size += calculate_body_size(child, compression); } + return total_size; } - std::vector get_buffers(const sparrow::record_batch& record_batch) + int64_t calculate_body_size(const sparrow::record_batch& record_batch, std::optional compression) { - std::vector buffers; - std::int64_t offset = 0; - for (const auto& column : record_batch.columns()) - { - const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column); - fill_buffers(arrow_proxy, buffers, offset); - } - return buffers; + return std::accumulate( + record_batch.columns().begin(), + record_batch.columns().end(), + int64_t{0}, + [&](int64_t acc, const sparrow::array& arr) + { + const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(arr); + return acc + calculate_body_size(arrow_proxy, compression); + } + ); } - flatbuffers::FlatBufferBuilder get_record_batch_message_builder(const sparrow::record_batch& record_batch, std::optional compression) + flatbuffers::FlatBufferBuilder get_record_batch_message_builder(const sparrow::record_batch& record_batch, std::optional compression) { flatbuffers::FlatBufferBuilder record_batch_builder; flatbuffers::Offset compression_offset = 0; std::optional> compressed_buffers; if (compression) { - compressed_buffers = generate_compressed_buffers(record_batch, compression.value()); - compression_offset = org::apache::arrow::flatbuf::CreateBodyCompression(record_batch_builder, compression.value(), org::apache::arrow::flatbuf::BodyCompressionMethod::BUFFER); + compressed_buffers = get_compressed_buffers(record_batch, compression.value()); + compression_offset = org::apache::arrow::flatbuf::CreateBodyCompression(record_batch_builder, details::to_fb_compression_type(compression.value()), org::apache::arrow::flatbuf::BodyCompressionMethod::BUFFER); } const auto& buffers = compressed_buffers ? 
*compressed_buffers : get_buffers(record_batch); const std::vector nodes = create_fieldnodes(record_batch); diff --git a/src/serialize.cpp b/src/serialize.cpp index 016068a..397bacf 100644 --- a/src/serialize.cpp +++ b/src/serialize.cpp @@ -23,7 +23,7 @@ namespace sparrow_ipc common_serialize(get_schema_message_builder(record_batch), stream); } - void serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional compression) + void serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional compression) { common_serialize(get_record_batch_message_builder(record_batch, compression), stream); generate_body(record_batch, stream, compression); diff --git a/src/serialize_utils.cpp b/src/serialize_utils.cpp index 875e8be..193f2a2 100644 --- a/src/serialize_utils.cpp +++ b/src/serialize_utils.cpp @@ -1,4 +1,3 @@ -#include "sparrow_ipc/compression.hpp" #include "sparrow_ipc/flatbuffer_utils.hpp" #include "sparrow_ipc/magic_values.hpp" #include "sparrow_ipc/serialize.hpp" @@ -7,7 +6,7 @@ namespace sparrow_ipc { - void fill_body(const sparrow::arrow_proxy& arrow_proxy, any_output_stream& stream, std::optional compression) + void fill_body(const sparrow::arrow_proxy& arrow_proxy, any_output_stream& stream, std::optional compression) { std::for_each(arrow_proxy.buffers().begin(), arrow_proxy.buffers().end(), [&](const auto& buffer) { if (compression.has_value()) @@ -27,7 +26,7 @@ namespace sparrow_ipc }); } - void generate_body(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional compression) + void generate_body(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional compression) { std::for_each(record_batch.columns().begin(), record_batch.columns().end(), [&](const auto& column) { const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column); @@ -35,45 +34,6 @@ namespace sparrow_ipc }); } - int64_t calculate_body_size(const sparrow::arrow_proxy& arrow_proxy, std::optional compression) - { - int64_t total_size = 0; - if (compression.has_value()) - { - for (const auto& buffer : arrow_proxy.buffers()) - { - total_size += utils::align_to_8(compress(compression.value(), std::span(buffer.data(), buffer.size())).size()); - } - } - else - { - for (const auto& buffer : arrow_proxy.buffers()) - { - total_size += utils::align_to_8(buffer.size()); - } - } - - for (const auto& child : arrow_proxy.children()) - { - total_size += calculate_body_size(child, compression); - } - return total_size; - } - - int64_t calculate_body_size(const sparrow::record_batch& record_batch, std::optional compression) - { - return std::accumulate( - record_batch.columns().begin(), - record_batch.columns().end(), - int64_t{0}, - [&](int64_t acc, const sparrow::array& arr) - { - const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(arr); - return acc + calculate_body_size(arrow_proxy, compression); - } - ); - } - std::size_t calculate_schema_message_size(const sparrow::record_batch& record_batch) { // Build the schema message to get its exact size @@ -89,7 +49,7 @@ namespace sparrow_ipc return utils::align_to_8(total_size); } - std::size_t calculate_record_batch_message_size(const sparrow::record_batch& record_batch, std::optional compression) + std::size_t calculate_record_batch_message_size(const sparrow::record_batch& record_batch, std::optional compression) { // Build the record batch message to get its exact metadata size 
flatbuffers::FlatBufferBuilder record_batch_builder = get_record_batch_message_builder(record_batch, compression); @@ -109,26 +69,6 @@ namespace sparrow_ipc return metadata_size + actual_body_size; } - std::vector - generate_compressed_buffers(const sparrow::record_batch& record_batch, const org::apache::arrow::flatbuf::CompressionType compression_type) - { - std::vector compressed_buffers; - int64_t current_offset = 0; - - for (const auto& column : record_batch.columns()) - { - const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column); - for (const auto& buffer : arrow_proxy.buffers()) - { - std::vector compressed_buffer_with_header = compress(compression_type, std::span(buffer.data(), buffer.size())); - const size_t aligned_chunk_size = utils::align_to_8(compressed_buffer_with_header.size()); - compressed_buffers.emplace_back(current_offset, aligned_chunk_size); - current_offset += aligned_chunk_size; - } - } - return compressed_buffers; - } - std::vector get_column_dtypes(const sparrow::record_batch& rb) { std::vector dtypes; diff --git a/tests/include/sparrow_ipc_tests_helpers.hpp b/tests/include/sparrow_ipc_tests_helpers.hpp index 79cc84b..bfe5e30 100644 --- a/tests/include/sparrow_ipc_tests_helpers.hpp +++ b/tests/include/sparrow_ipc_tests_helpers.hpp @@ -71,4 +71,15 @@ namespace sparrow_ipc sp::array(sp::string_array(std::vector{"hello", "world", "test", "data", "batch"}))}} ); } + + // Helper function to create a compressible record batch for testing + inline sp::record_batch create_compressible_test_record_batch() + { + std::vector int_data(1000, 12345); + std::vector string_data(1000, "hello world"); + return sp::record_batch( + {{"int_col", sp::array(sp::primitive_array(int_data))}, + {"string_col", sp::array(sp::string_array(string_data))}} + ); + } } diff --git a/tests/test_chunk_memory_serializer.cpp b/tests/test_chunk_memory_serializer.cpp index 1230b97..a52abd0 100644 --- a/tests/test_chunk_memory_serializer.cpp +++ b/tests/test_chunk_memory_serializer.cpp @@ -15,20 +15,37 @@ namespace sparrow_ipc { TEST_CASE("construction with single record batch") { - SUBCASE("Valid record batch") + SUBCASE("Valid record batch, with and without compression") { - auto rb = create_test_record_batch(); - std::vector> chunks; - chunked_memory_output_stream stream(chunks); + auto rb = create_compressible_test_record_batch(); + std::vector> chunks_compressed; + chunked_memory_output_stream stream_compressed(chunks_compressed); - chunk_serializer serializer(stream); - serializer << rb; + chunk_serializer serializer_compressed(stream_compressed, CompressionType::LZ4_FRAME); + serializer_compressed << rb; // After construction with single record batch, should have schema + record batch - CHECK_EQ(chunks.size(), 2); - CHECK_GT(chunks[0].size(), 0); // Schema message - CHECK_GT(chunks[1].size(), 0); // Record batch message - CHECK_GT(stream.size(), 0); + CHECK_EQ(chunks_compressed.size(), 2); + CHECK_GT(chunks_compressed[0].size(), 0); // Schema message + CHECK_GT(chunks_compressed[1].size(), 0); // Record batch message + CHECK_GT(stream_compressed.size(), 0); + + std::vector> chunks_uncompressed; + chunked_memory_output_stream stream_uncompressed(chunks_uncompressed); + + chunk_serializer serializer_uncompressed(stream_uncompressed); + serializer_uncompressed << rb; + + CHECK_EQ(chunks_uncompressed.size(), 2); + CHECK_GT(chunks_uncompressed[0].size(), 0); // Schema message + CHECK_GT(chunks_uncompressed[1].size(), 0); // Record batch message + 
CHECK_GT(stream_uncompressed.size(), 0); + + // Check that schema size is the same + CHECK_EQ(chunks_compressed[0].size(), chunks_uncompressed[0].size()); + + // Check that compressed record batch is smaller + CHECK_LT(chunks_compressed[1].size(), chunks_uncompressed[1].size()); } SUBCASE("Empty record batch") diff --git a/tests/test_compression.cpp b/tests/test_compression.cpp index d190052..b037387 100644 --- a/tests/test_compression.cpp +++ b/tests/test_compression.cpp @@ -4,7 +4,7 @@ #include -#include +#include "../src/compression_impl.hpp" namespace sparrow_ipc { @@ -14,19 +14,19 @@ namespace sparrow_ipc { std::string original_string = "some data to compress"; std::vector original_data(original_string.begin(), original_string.end()); - const auto compression_type = org::apache::arrow::flatbuf::CompressionType::ZSTD; + const auto compression_type = CompressionType::ZSTD; // Test compression with ZSTD - CHECK_THROWS_WITH_AS(compress(compression_type, original_data), "Compression using zstd is not supported yet.", std::runtime_error); + CHECK_THROWS_WITH_AS(compress(compression_type, original_data), "Compression using zstd is not supported yet.", std::invalid_argument); // Test decompression with ZSTD - CHECK_THROWS_WITH_AS(decompress(compression_type, original_data), "Decompression using zstd is not supported yet.", std::runtime_error); + CHECK_THROWS_WITH_AS(decompress(compression_type, original_data), "Decompression using zstd is not supported yet.", std::invalid_argument); } TEST_CASE("Decompress empty data") { const std::vector empty_data; - const auto compression_type = org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME; + const auto compression_type = CompressionType::LZ4_FRAME; CHECK_THROWS_WITH_AS(decompress(compression_type, empty_data), "Trying to decompress empty data.", std::runtime_error); } @@ -34,11 +34,11 @@ namespace sparrow_ipc TEST_CASE("Empty data") { const std::vector empty_data; - const auto compression_type = org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME; + const auto compression_type = CompressionType::LZ4_FRAME; // Test compression of empty data auto compressed = compress(compression_type, empty_data); - CHECK_EQ(compressed.size(), CompressionHeaderSize); + CHECK_EQ(compressed.size(), details::CompressionHeaderSize); const std::int64_t header = *reinterpret_cast(compressed.data()); CHECK_EQ(header, -1); @@ -53,7 +53,7 @@ namespace sparrow_ipc std::vector original_data(original_string.begin(), original_string.end()); // Compress data - auto compression_type = org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME; + auto compression_type = CompressionType::LZ4_FRAME; std::vector compressed_data = compress(compression_type, original_data); // Decompress @@ -75,7 +75,7 @@ namespace sparrow_ipc std::vector original_data(original_string.begin(), original_string.end()); // Compress data - auto compression_type = org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME; + auto compression_type = CompressionType::LZ4_FRAME; std::vector compressed_data = compress(compression_type, original_data); // Decompress diff --git a/tests/test_de_serialization_with_files.cpp b/tests/test_de_serialization_with_files.cpp index 7b7d236..360d67c 100644 --- a/tests/test_de_serialization_with_files.cpp +++ b/tests/test_de_serialization_with_files.cpp @@ -227,7 +227,7 @@ TEST_SUITE("Integration tests") std::vector serialized_data; sparrow_ipc::memory_output_stream stream(serialized_data); - sparrow_ipc::serializer serializer(stream, 
org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME); + sparrow_ipc::serializer serializer(stream, sparrow_ipc::CompressionType::LZ4_FRAME); serializer << record_batches_from_json << sparrow_ipc::end_stream; const auto deserialized_serialized_data = sparrow_ipc::deserialize_stream( std::span(serialized_data) diff --git a/tests/test_flatbuffer_utils.cpp b/tests/test_flatbuffer_utils.cpp index b508140..48a78df 100644 --- a/tests/test_flatbuffer_utils.cpp +++ b/tests/test_flatbuffer_utils.cpp @@ -178,25 +178,55 @@ namespace sparrow_ipc } } + void test_fill_buffers_variant( + const std::function&, int64_t&)>& fill_func) + { + auto array = sp::primitive_array({1, 2, 3, 4, 5}); + auto proxy = sp::detail::array_access::get_arrow_proxy(array); + + std::vector buffers; + int64_t offset = 0; + fill_func(proxy, buffers, offset); + + CHECK_GT(buffers.size(), 0); + CHECK_GT(offset, 0); + + // Verify offsets are aligned + for (const auto& buffer : buffers) + { + CHECK_EQ(buffer.offset() % 8, 0); + } + } + TEST_CASE("fill_buffers") { SUBCASE("Simple primitive array") { - auto array = sp::primitive_array({1, 2, 3, 4, 5}); - auto proxy = sp::detail::array_access::get_arrow_proxy(array); - - std::vector buffers; - int64_t offset = 0; - fill_buffers(proxy, buffers, offset); + test_fill_buffers_variant([](const sparrow::arrow_proxy& proxy, std::vector& buffers, int64_t& offset) { + fill_buffers(proxy, buffers, offset); + }); + } + } - CHECK_GT(buffers.size(), 0); - CHECK_GT(offset, 0); + TEST_CASE("fill_compressed_buffers") + { + SUBCASE("Simple primitive array") + { + test_fill_buffers_variant([](const sparrow::arrow_proxy& proxy, std::vector& buffers, int64_t& offset) { + fill_compressed_buffers(proxy, buffers, offset, CompressionType::LZ4_FRAME); + }); + } + } - // Verify offsets are aligned - for (const auto& buffer : buffers) - { - CHECK_EQ(buffer.offset() % 8, 0); - } + void test_get_buffers_variant(const std::function(const sparrow::record_batch&)>& get_func) + { + auto record_batch = create_test_record_batch(); + auto buffers = get_func(record_batch); + CHECK_GT(buffers.size(), 0); + // Verify all offsets are properly calculated and aligned + for (size_t i = 1; i < buffers.size(); ++i) + { + CHECK_GE(buffers[i].offset(), buffers[i - 1].offset() + buffers[i - 1].length()); } } @@ -204,14 +234,19 @@ namespace sparrow_ipc { SUBCASE("Record batch with multiple columns") { - auto record_batch = create_test_record_batch(); - auto buffers = get_buffers(record_batch); - CHECK_GT(buffers.size(), 0); - // Verify all offsets are properly calculated and aligned - for (size_t i = 1; i < buffers.size(); ++i) - { - CHECK_GE(buffers[i].offset(), buffers[i - 1].offset() + buffers[i - 1].length()); - } + test_get_buffers_variant([](const sparrow::record_batch& record_batch) { + return get_buffers(record_batch); + }); + } + } + + TEST_CASE("get_compressed_buffers") + { + SUBCASE("Record batch with multiple columns") + { + test_get_buffers_variant([](const sparrow::record_batch& record_batch) { + return get_compressed_buffers(record_batch, CompressionType::LZ4_FRAME); + }); } } @@ -523,13 +558,23 @@ namespace sparrow_ipc TEST_CASE("get_record_batch_message_builder") { - SUBCASE("Valid record batch with field nodes and buffers") + auto test_get_record_batch_message_builder = [](std::optional compression) { auto record_batch = create_test_record_batch(); - auto builder = get_record_batch_message_builder(record_batch); + auto builder = get_record_batch_message_builder(record_batch, compression); 
CHECK_GT(builder.GetSize(), 0); CHECK_NE(builder.GetBufferPointer(), nullptr); + }; + + SUBCASE("Valid record batch with field nodes and buffers (Without compression)") + { + test_get_record_batch_message_builder(std::nullopt); + } + + SUBCASE("Valid record batch with field nodes and buffers (With compression)") + { + test_get_record_batch_message_builder(CompressionType::LZ4_FRAME); } } } -} \ No newline at end of file +} diff --git a/tests/test_serialize_utils.cpp b/tests/test_serialize_utils.cpp index 194f958..8717172 100644 --- a/tests/test_serialize_utils.cpp +++ b/tests/test_serialize_utils.cpp @@ -4,6 +4,7 @@ #include #include "sparrow_ipc/any_output_stream.hpp" +#include "sparrow_ipc/flatbuffer_utils.hpp" #include "sparrow_ipc/magic_values.hpp" #include "sparrow_ipc/memory_output_stream.hpp" #include "sparrow_ipc/serialize_utils.hpp" @@ -40,8 +41,6 @@ namespace sparrow_ipc } } - // TODO after the used fcts are stable regarding compression, add tests for fcts having it as an additional argument - // cf. fill_body example TEST_CASE("fill_body") { SUBCASE("Simple primitive array (uncompressed)") @@ -67,13 +66,19 @@ namespace sparrow_ipc std::vector body_compressed; sparrow_ipc::memory_output_stream stream_compressed(body_compressed); sparrow_ipc::any_output_stream astream_compressed(stream_compressed); - fill_body(proxy, astream_compressed, org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME); + fill_body(proxy, astream_compressed, CompressionType::LZ4_FRAME); + CHECK_GT(body_compressed.size(), 0); + // Body size should be aligned + CHECK_EQ(body_compressed.size() % 8, 0); // Uncompressed std::vector body_uncompressed; sparrow_ipc::memory_output_stream stream_uncompressed(body_uncompressed); sparrow_ipc::any_output_stream astream_uncompressed(stream_uncompressed); fill_body(proxy, astream_uncompressed, std::nullopt); + CHECK_GT(body_uncompressed.size(), 0); + // Body size should be aligned + CHECK_EQ(body_uncompressed.size() % 8, 0); // Check that compressed size is smaller than uncompressed size CHECK_LT(body_compressed.size(), body_uncompressed.size()); } @@ -81,43 +86,74 @@ namespace sparrow_ipc TEST_CASE("generate_body") { - SUBCASE("Record batch with multiple columns") + auto record_batch = create_test_record_batch(); + SUBCASE("Record batch with multiple columns (uncompressed)") { - auto record_batch = create_test_record_batch(); std::vector serialized; memory_output_stream stream(serialized); any_output_stream astream(stream); - generate_body(record_batch, astream); + generate_body(record_batch, astream, std::nullopt); + CHECK_GT(serialized.size(), 0); + CHECK_EQ(serialized.size() % 8, 0); + } + + SUBCASE("Record batch with multiple columns (compressed)") + { + std::vector serialized; + memory_output_stream stream(serialized); + any_output_stream astream(stream); + generate_body(record_batch, astream, CompressionType::LZ4_FRAME); CHECK_GT(serialized.size(), 0); CHECK_EQ(serialized.size() % 8, 0); } } +#if defined(SPARROW_IPC_STATIC_LIB) TEST_CASE("calculate_body_size") { - SUBCASE("Single array") - { - auto array = sp::primitive_array({1, 2, 3, 4, 5}); - auto proxy = sp::detail::array_access::get_arrow_proxy(array); + auto array = sp::primitive_array({1, 2, 3, 4, 5}); + auto proxy = sp::detail::array_access::get_arrow_proxy(array); + SUBCASE("Single array (uncompressed)") + { auto size = calculate_body_size(proxy); CHECK_GT(size, 0); CHECK_EQ(size % 8, 0); } - SUBCASE("Record batch") + SUBCASE("Single array (compressed)") + { + auto size = calculate_body_size(proxy, 
CompressionType::LZ4_FRAME); + CHECK_GT(size, 0); + CHECK_EQ(size % 8, 0); + } + + auto record_batch = create_test_record_batch(); + SUBCASE("Record batch (uncompressed)") { - auto record_batch = create_test_record_batch(); auto size = calculate_body_size(record_batch); CHECK_GT(size, 0); CHECK_EQ(size % 8, 0); std::vector serialized; memory_output_stream stream(serialized); - any_output_stream astream(stream); + any_output_stream astream(stream); generate_body(record_batch, astream); CHECK_EQ(size, static_cast(serialized.size())); } + + SUBCASE("Record batch (compressed)") + { + auto size = calculate_body_size(record_batch, CompressionType::LZ4_FRAME); + CHECK_GT(size, 0); + CHECK_EQ(size % 8, 0); + std::vector serialized; + memory_output_stream stream(serialized); + any_output_stream astream(stream); + generate_body(record_batch, astream, CompressionType::LZ4_FRAME); + CHECK_EQ(size, static_cast(serialized.size())); + } } +#endif TEST_CASE("calculate_schema_message_size") { @@ -158,55 +194,59 @@ namespace sparrow_ipc TEST_CASE("calculate_record_batch_message_size") { - SUBCASE("Single column record batch") + auto test_calculate_record_batch_message_size = [](const sp::record_batch& record_batch, std::optional compression) { - auto array = sp::primitive_array({1, 2, 3, 4, 5}); - auto record_batch = sp::record_batch({{"column1", sp::array(std::move(array))}}); - - auto estimated_size = calculate_record_batch_message_size(record_batch); + auto estimated_size = calculate_record_batch_message_size(record_batch, compression); CHECK_GT(estimated_size, 0); CHECK_EQ(estimated_size % 8, 0); std::vector serialized; memory_output_stream stream(serialized); any_output_stream astream(stream); - serialize_record_batch(record_batch, astream, std::nullopt); + serialize_record_batch(record_batch, astream, compression); CHECK_EQ(estimated_size, serialized.size()); + }; + + SUBCASE("Single column record batch") + { + auto array = sp::primitive_array({1, 2, 3, 4, 5}); + auto record_batch = sp::record_batch({{"column1", sp::array(std::move(array))}}); + test_calculate_record_batch_message_size(record_batch, std::nullopt); + test_calculate_record_batch_message_size(record_batch, CompressionType::LZ4_FRAME); } SUBCASE("Multi-column record batch") { auto record_batch = create_test_record_batch(); - - auto estimated_size = calculate_record_batch_message_size(record_batch); - CHECK_GT(estimated_size, 0); - CHECK_EQ(estimated_size % 8, 0); - - // Verify by actual serialization - std::vector serialized; - memory_output_stream stream(serialized); - any_output_stream astream(stream); - serialize_record_batch(record_batch, astream, std::nullopt); - - CHECK_EQ(estimated_size, serialized.size()); + test_calculate_record_batch_message_size(record_batch, std::nullopt); + test_calculate_record_batch_message_size(record_batch, CompressionType::LZ4_FRAME); } } TEST_CASE("calculate_total_serialized_size") { + auto test_calculate_total_serialized_size = [](const std::vector& batches, std::optional compression) + { + auto estimated_size = calculate_total_serialized_size(batches, compression); + CHECK_GT(estimated_size, 0); + + // Should be equal to schema size + sum of record batch sizes + auto schema_size = calculate_schema_message_size(batches[0]); + int64_t batches_size = 0; + for(const auto& batch : batches) + { + batches_size += calculate_record_batch_message_size(batch, compression); + } + CHECK_EQ(estimated_size, schema_size + batches_size); + }; + SUBCASE("Single record batch") { auto record_batch = 
create_test_record_batch(); std::vector batches = {record_batch}; - - auto estimated_size = calculate_total_serialized_size(batches); - CHECK_GT(estimated_size, 0); - - // Should equal schema size + record batch size - auto schema_size = calculate_schema_message_size(record_batch); - auto batch_size = calculate_record_batch_message_size(record_batch); - CHECK_EQ(estimated_size, schema_size + batch_size); + test_calculate_total_serialized_size(batches, std::nullopt); + test_calculate_total_serialized_size(batches, CompressionType::LZ4_FRAME); } SUBCASE("Multiple record batches") @@ -224,15 +264,8 @@ namespace sparrow_ipc ); std::vector batches = {record_batch1, record_batch2}; - - auto estimated_size = calculate_total_serialized_size(batches); - CHECK_GT(estimated_size, 0); - - // Should equal schema size + sum of record batch sizes - auto schema_size = calculate_schema_message_size(batches[0]); - auto batch1_size = calculate_record_batch_message_size(batches[0]); - auto batch2_size = calculate_record_batch_message_size(batches[1]); - CHECK_EQ(estimated_size, schema_size + batch1_size + batch2_size); + test_calculate_total_serialized_size(batches, std::nullopt); + test_calculate_total_serialized_size(batches, CompressionType::LZ4_FRAME); } SUBCASE("Empty collection") @@ -255,19 +288,19 @@ namespace sparrow_ipc std::vector batches = {record_batch1, record_batch2}; CHECK_THROWS_AS(auto size = calculate_total_serialized_size(batches), std::invalid_argument); + CHECK_THROWS_AS(auto size = calculate_total_serialized_size(batches, CompressionType::LZ4_FRAME), std::invalid_argument); } } TEST_CASE("serialize_record_batch") { - SUBCASE("Valid record batch") + auto test_serialize_record_batch = [](const sp::record_batch& record_batch_to_serialize, std::optional compression) { - auto record_batch = create_test_record_batch(); std::vector serialized; memory_output_stream stream(serialized); any_output_stream astream(stream); - serialize_record_batch(record_batch, astream, std::nullopt); - CHECK_GT(serialized.size(), 0); + serialize_record_batch(record_batch_to_serialize, astream, compression); + CHECK_GT(serialized.size(), 0); // Check that it starts with continuation bytes CHECK_GE(serialized.size(), continuation.size()); @@ -293,17 +326,23 @@ namespace sparrow_ipc // Verify alignment CHECK_EQ(aligned_metadata_end % 8, 0); CHECK_LE(aligned_metadata_end, serialized.size()); + + return serialized.size(); + }; + + SUBCASE("Valid record batch") + { + auto record_batch = create_compressible_test_record_batch(); + auto compressed_size = test_serialize_record_batch(record_batch, CompressionType::LZ4_FRAME); + auto uncompressed_size = test_serialize_record_batch(record_batch, std::nullopt); + CHECK_LT(compressed_size, uncompressed_size); } SUBCASE("Empty record batch") { auto empty_batch = sp::record_batch({}); - std::vector serialized; - memory_output_stream stream(serialized); - any_output_stream astream(stream); - serialize_record_batch(empty_batch, astream, std::nullopt); - CHECK_GT(serialized.size(), 0); - CHECK_GE(serialized.size(), continuation.size()); + test_serialize_record_batch(empty_batch, std::nullopt); + test_serialize_record_batch(empty_batch, CompressionType::LZ4_FRAME); } } } diff --git a/tests/test_serializer.cpp b/tests/test_serializer.cpp index c35bcaa..ccb6a72 100644 --- a/tests/test_serializer.cpp +++ b/tests/test_serializer.cpp @@ -35,15 +35,22 @@ namespace sparrow_ipc { TEST_CASE_TEMPLATE("construction and write single record batch", StreamWrapper, memory_stream_wrapper, 
ostringstream_wrapper) { - SUBCASE("Valid record batch") + SUBCASE("Valid record batch, with and without compression") { - auto rb = create_test_record_batch(); - StreamWrapper wrapper; - serializer ser(wrapper.get_stream()); - ser.write(rb); + auto rb = create_compressible_test_record_batch(); + StreamWrapper wrapper_compressed; + serializer ser_compressed(wrapper_compressed.get_stream(), CompressionType::LZ4_FRAME); + ser_compressed.write(rb); // After writing first record batch, should have schema + record batch - CHECK_GT(wrapper.size(), 0); + CHECK_GT(wrapper_compressed.size(), 0); + + StreamWrapper wrapper_uncompressed; + serializer ser_uncompressed(wrapper_uncompressed.get_stream()); + ser_uncompressed.write(rb); + CHECK_GT(wrapper_uncompressed.size(), 0); + + CHECK_LT(wrapper_compressed.size(), wrapper_uncompressed.size()); } SUBCASE("Empty record batch")
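Usage note (not part of the patch): with this change, callers select compression through sparrow_ipc::CompressionType rather than the flatbuffers enum, so Message_generated.h no longer leaks into the public headers. Below is a minimal serialization sketch modeled on the test code in this diff (test_de_serialization_with_files.cpp and test_serializer.cpp); the byte type of the output vector and the availability of a sparrow::record_batch instance are assumptions, not part of the diff itself.

    #include <cstdint>
    #include <vector>

    #include "sparrow_ipc/compression.hpp"
    #include "sparrow_ipc/memory_output_stream.hpp"
    #include "sparrow_ipc/serializer.hpp"

    // Serializes one record batch with LZ4-compressed buffer bodies; pass std::nullopt
    // as the second serializer argument to write the stream uncompressed.
    std::vector<std::uint8_t> to_lz4_ipc(const sparrow::record_batch& rb)  // record_batch include omitted here
    {
        std::vector<std::uint8_t> bytes;
        sparrow_ipc::memory_output_stream stream(bytes);
        sparrow_ipc::serializer ser(stream, sparrow_ipc::CompressionType::LZ4_FRAME);
        ser << rb << sparrow_ipc::end_stream;  // schema message, record batch message, end-of-stream marker
        return bytes;
    }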