diff --git a/CMakeLists.txt b/CMakeLists.txt
index 86b80e3..86f5ae6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -95,38 +95,48 @@
 set(SPARROW_IPC_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include)
 set(SPARROW_IPC_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src)
 
 set(SPARROW_IPC_HEADERS
+    ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/any_output_stream.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/arrow_interface/arrow_array_schema_common_release.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/arrow_interface/arrow_array.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/arrow_interface/arrow_array/private_data.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/arrow_interface/arrow_schema.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/arrow_interface/arrow_schema/private_data.hpp
+    ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/chunk_memory_output_stream.hpp
+    ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/chunk_memory_serializer.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/config/config.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/config/sparrow_ipc_version.hpp
-    ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_variable_size_binary_array.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_fixedsizebinary_array.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_primitive_array.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_utils.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_variable_size_binary_array.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/encapsulated_message.hpp
+    ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/flatbuffer_utils.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/magic_values.hpp
+    ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/memory_output_stream.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/metadata.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/serialize_utils.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/serialize.hpp
+    ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/serializer.hpp
     ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/utils.hpp
 )
 
 set(SPARROW_IPC_SRC
-    ${SPARROW_IPC_SOURCE_DIR}/serialize_utils.cpp
+    ${SPARROW_IPC_SOURCE_DIR}/any_output_stream.cpp
     ${SPARROW_IPC_SOURCE_DIR}/arrow_interface/arrow_array.cpp
     ${SPARROW_IPC_SOURCE_DIR}/arrow_interface/arrow_array/private_data.cpp
     ${SPARROW_IPC_SOURCE_DIR}/arrow_interface/arrow_schema.cpp
     ${SPARROW_IPC_SOURCE_DIR}/arrow_interface/arrow_schema/private_data.cpp
+    ${SPARROW_IPC_SOURCE_DIR}/chunk_memory_serializer.cpp
     ${SPARROW_IPC_SOURCE_DIR}/deserialize_fixedsizebinary_array.cpp
     ${SPARROW_IPC_SOURCE_DIR}/deserialize_utils.cpp
     ${SPARROW_IPC_SOURCE_DIR}/deserialize.cpp
     ${SPARROW_IPC_SOURCE_DIR}/encapsulated_message.cpp
+    ${SPARROW_IPC_SOURCE_DIR}/flatbuffer_utils.cpp
     ${SPARROW_IPC_SOURCE_DIR}/metadata.cpp
+    ${SPARROW_IPC_SOURCE_DIR}/serialize_utils.cpp
+    ${SPARROW_IPC_SOURCE_DIR}/serialize.cpp
+    ${SPARROW_IPC_SOURCE_DIR}/serializer.cpp
     ${SPARROW_IPC_SOURCE_DIR}/utils.cpp
 )
diff --git a/include/sparrow_ipc/any_output_stream.hpp b/include/sparrow_ipc/any_output_stream.hpp
new file mode 100644
index 0000000..110cecb
--- /dev/null
+++ b/include/sparrow_ipc/any_output_stream.hpp
@@ -0,0 +1,349 @@
+#pragma once
+
+// NOTE(review): the original include targets were lost in transfer; the set
+// below is reconstructed from usage in this header — verify against upstream.
+#include <cstdint>
+#include <functional>
+#include <ios>
+#include <memory>
+#include <span>
+#include <typeinfo>
+
+#include "sparrow_ipc/config/config.hpp"
+
+namespace sparrow_ipc
+{
+    /**
+     * @brief Concept for stream-like types that support write operations.
+     *
+     * A type satisfies this concept if it has a write method that accepts
+     * a pointer to characters and a byte count.
+     */
+    template <typename T>
+    concept writable_stream = requires(T& t, const char* s, std::streamsize count) {
+        { t.write(s, count) };
+    };
+
+    /**
+     * @brief Type-erased wrapper for any stream-like object.
+     *
+     * This class provides type erasure for ANY type that supports stream operations.
+     * It uses the concept-based type erasure
+     * pattern to wrap any stream-like object polymorphically.
+ * + * @details This implementation uses the classic type erasure pattern with: + * - An abstract base class (stream_concept) defining the interface + * - A templated model class that adapts any stream type to the interface + * - A wrapper class that stores the model polymorphically + * + * Usage: + * @code + * std::vector buffer; + * memory_output_stream> mem_stream(buffer); + * any_output_stream stream1(mem_stream); + * + * // Also works with standard streams + * std::ostringstream oss; + * any_output_stream stream2(oss); + * + * // Or any custom type with a write method + * my_custom_stream custom; + * any_output_stream stream3(custom); + * @endcode + * + * The class provides a common interface that works with any stream type. + */ + class SPARROW_IPC_API any_output_stream + { + public: + + /** + * @brief Constructs a type-erased stream from any stream-like object. + * + * @tparam TStream The concrete stream type (must satisfy writable_stream concept) + * @param stream The stream object to wrap + * + * The stream is stored by reference, so the caller must ensure the stream + * lifetime exceeds that of the any_output_stream object. + */ + template + any_output_stream(TStream& stream); + + /** + * @brief Default destructor. + */ + ~any_output_stream() = default; + + any_output_stream(any_output_stream&&) noexcept = default; + any_output_stream& operator=(any_output_stream&&) noexcept = default; + + any_output_stream(const any_output_stream&) = delete; + any_output_stream& operator=(const any_output_stream&) = delete; + + /** + * @brief Writes a span of bytes to the underlying stream. + * + * @param span The bytes to write + * @return The number of bytes written + * @throws std::runtime_error if write operation fails + */ + void write(std::span span); + + /** + * @brief Writes a single byte value multiple times. 
+ * + * @param value The byte value to write + * @param count Number of times to write the value (default: 1) + * @return The number of bytes written + */ + void write(uint8_t value, std::size_t count = 1); + + /** + * @brief Adds padding to align to 8-byte boundary. + */ + void add_padding(); + + /** + * @brief Reserves capacity if supported by the underlying stream. + * + * @param size The number of bytes to reserve + */ + void reserve(std::size_t size); + + /** + * @brief Reserves capacity using a lazy calculation function. + * + * @param calculate_reserve_size Function that calculates the size to reserve + */ + void reserve(const std::function& calculate_reserve_size); + + /** + * @brief Gets the current size of the stream. + * + * @return The current number of bytes written + */ + [[nodiscard]] size_t size() const; + + /** + * @brief Gets a reference to the underlying stream cast to the specified type. + * + * @tparam TStream The expected concrete type of the underlying stream + * @return Reference to the underlying stream as TStream + * @throws std::bad_cast if the underlying stream is not of type TStream + */ + template + TStream& get(); + + /** + * @brief Gets a const reference to the underlying stream cast to the specified type. + * + * @tparam TStream The expected concrete type of the underlying stream + * @return Const reference to the underlying stream as TStream + * @throws std::bad_cast if the underlying stream is not of type TStream + */ + template + const TStream& get() const; + + private: + + /** + * @brief Abstract interface for type-erased streams. 
+ */ + struct stream_concept + { + virtual ~stream_concept() = default; + virtual void write(const char* s, std::streamsize count) = 0; + virtual void write(std::span span) = 0; + virtual void write(uint8_t value, std::size_t count) = 0; + virtual void put(uint8_t value) = 0; + virtual void add_padding() = 0; + virtual void reserve(std::size_t size) = 0; + virtual void reserve(const std::function& calculate_reserve_size) = 0; + [[nodiscard]] virtual size_t size() const = 0; + }; + + /** + * @brief Concrete model that adapts a specific stream type to the interface. + * + * @tparam TStream The concrete stream type + */ + template + class stream_model : public stream_concept + { + public: + + stream_model(TStream& stream); + + void write(const char* s, std::streamsize count) final; + + void write(std::span span) final; + + void write(uint8_t value, std::size_t count) final; + + void put(uint8_t value) final; + + void add_padding() final; + + void reserve(std::size_t size) final; + + void reserve(const std::function& calculate_reserve_size) final; + + [[nodiscard]] size_t size() const final; + + TStream& get_stream(); + + const TStream& get_stream() const; + + private: + + TStream* m_stream; + size_t m_size = 0; + }; + + std::unique_ptr m_impl; + }; + + // Implementation + + template + any_output_stream::any_output_stream(TStream& stream) + : m_impl(std::make_unique>(stream)) + { + } + + template + TStream& any_output_stream::get() + { + auto* model = dynamic_cast*>(m_impl.get()); + if (!model) + { + throw std::bad_cast(); + } + return model->get_stream(); + } + + template + const TStream& any_output_stream::get() const + { + const auto* model = dynamic_cast*>(m_impl.get()); + if (!model) + { + throw std::bad_cast(); + } + return model->get_stream(); + } + + // stream_model implementation + + template + any_output_stream::stream_model::stream_model(TStream& stream) + : m_stream(&stream) + { + } + + template + void any_output_stream::stream_model::write(const char* s, 
std::streamsize count) + { + m_stream->write(s, count); + m_size += static_cast(count); + } + + template + void any_output_stream::stream_model::write(std::span span) + { + m_stream->write(reinterpret_cast(span.data()), static_cast(span.size())); + m_size += span.size(); + } + + template + void any_output_stream::stream_model::write(uint8_t value, std::size_t count) + { + if constexpr (requires(TStream& t, uint8_t v, std::size_t c) { t.write(v, c); }) + { + m_stream->write(value, count); + } + else + { + // Fallback: write one byte at a time + for (std::size_t i = 0; i < count; ++i) + { + m_stream->put(value); + } + } + m_size += count; + } + + template + void any_output_stream::stream_model::put(uint8_t value) + { + m_stream->put(value); + m_size ++; + } + + template + void any_output_stream::stream_model::add_padding() + { + const size_t current_size = size(); + const size_t padding_needed = (8 - (current_size % 8)) % 8; + if (padding_needed > 0) + { + static constexpr char padding_value = 0; + for (size_t i = 0; i < padding_needed; ++i) + { + m_stream->write(&padding_value, 1); + } + m_size += padding_needed; + } + } + + template + void any_output_stream::stream_model::reserve(std::size_t size) + { + if constexpr (requires(TStream& t, std::size_t s) { t.reserve(s); }) + { + m_stream->reserve(size); + } + // If not reservable, do nothing + } + + template + void any_output_stream::stream_model::reserve(const std::function& calculate_reserve_size) + { + if constexpr (requires(TStream& t, const std::function& func) { + { t.reserve(func) }; + }) + { + m_stream->reserve(calculate_reserve_size); + } + else if constexpr (requires(TStream& t, std::size_t s) { t.reserve(s); }) + { + m_stream->reserve(calculate_reserve_size()); + } + // If not reservable, do nothing + } + + template + size_t any_output_stream::stream_model::size() const + { + if constexpr (requires(const TStream& t) { + { t.size() } -> std::convertible_to; + }) + { + return m_stream->size(); + } + else + { 
+ return m_size; + } + } + + template + TStream& any_output_stream::stream_model::get_stream() + { + return *m_stream; + } + + template + const TStream& any_output_stream::stream_model::get_stream() const + { + return *m_stream; + } +} // namespace sparrow_ipc diff --git a/include/sparrow_ipc/chunk_memory_output_stream.hpp b/include/sparrow_ipc/chunk_memory_output_stream.hpp new file mode 100644 index 0000000..2978d04 --- /dev/null +++ b/include/sparrow_ipc/chunk_memory_output_stream.hpp @@ -0,0 +1,217 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace sparrow_ipc +{ + /** + * @brief An output stream that writes data into separate memory chunks. + * + * This template class stores data in discrete memory chunks + * rather than a single contiguous buffer. Each write operation creates a new chunk, making it + * suitable for scenarios where data needs to be processed or transmitted in separate units. + * + * @tparam R A random access range type where each element is itself a random access range of uint8_t. + * Typically std::vector> or similar nested container types. + * + * @details The chunked approach offers several benefits: + * - Avoids large contiguous memory allocations + * - Enables efficient chunk-by-chunk processing or transmission + * - Supports memory reservation for the chunk container (not individual chunks) + * + * @note Each write operation creates a new chunk in the container, regardless of the write size. + */ + template + requires std::ranges::random_access_range + && std::ranges::random_access_range> + && std::same_as::value_type, uint8_t> + class chunked_memory_output_stream + { + public: + + /** + * @brief Constructs a chunked memory output stream with a reference to a chunk container. + * + * @param chunks Reference to the container that will store the memory chunks. + * The stream stores a pointer to this container for write operations. 
+ */ + explicit chunked_memory_output_stream(R& chunks) + : m_chunks(&chunks) {}; + + /** + * @brief Writes character data as a new chunk. + * + * Creates a new chunk containing the specified character data. + * + * @param s Pointer to the character data to write + * @param count Number of characters to write + * @return Reference to this stream for method chaining + */ + chunked_memory_output_stream& write(const char* s, std::streamsize count); + + /** + * @brief Writes a span of bytes as a new chunk. + * + * Creates a new chunk containing the data from the provided span. + * + * @param span A span of bytes to write as a new chunk + * @return Reference to this stream for method chaining + */ + chunked_memory_output_stream& write(std::span span); + + /** + * @brief Writes a buffer by moving it into the chunk container. + * + * This is an optimized write operation that moves an existing buffer into the chunk + * container, avoiding a copy operation. + * + * @param buffer A vector of bytes to move into the chunk container + * @return Reference to this stream for method chaining + */ + chunked_memory_output_stream& write(std::vector&& buffer); + + /** + * @brief Writes a byte value repeated a specified number of times as a new chunk. + * + * Creates a new chunk filled with the specified byte value. + * + * @param value The byte value to write + * @param count Number of times to repeat the value + * @return Reference to this stream for method chaining + */ + chunked_memory_output_stream& write(uint8_t value, std::size_t count); + + /** + * @brief Writes a single character as a new chunk. + * + * Creates a new chunk containing a single byte. + * + * @param value The character value to write + * @return Reference to this stream for method chaining + */ + chunked_memory_output_stream& put(char value); + + /** + * @brief Reserves capacity in the chunk container. + * + * Reserves space for the specified number of chunks in the container. 
+ * This does not reserve space within individual chunks. + * + * @param size Number of chunks to reserve space for + */ + void reserve(std::size_t size); + + /** + * @brief Reserves capacity using a lazy calculation function. + * + * Reserves space for chunks by calling the provided function to determine the count. + * + * @param calculate_reserve_size Function that returns the number of chunks to reserve + */ + void reserve(const std::function& calculate_reserve_size); + + /** + * @brief Gets the total size of all chunks. + * + * Calculates and returns the sum of sizes of all chunks in the container. + * + * @return The total number of bytes across all chunks + */ + [[nodiscard]] size_t size() const; + + private: + + R* m_chunks; + }; + + // Implementation + + template + requires std::ranges::random_access_range + && std::ranges::random_access_range> + && std::same_as::value_type, uint8_t> + chunked_memory_output_stream& chunked_memory_output_stream::write(const char* s, std::streamsize count) + { + m_chunks->emplace_back(s, s + count); + return *this; + } + + template + requires std::ranges::random_access_range + && std::ranges::random_access_range> + && std::same_as::value_type, uint8_t> + chunked_memory_output_stream& chunked_memory_output_stream::write(std::span span) + { + m_chunks->emplace_back(span.begin(), span.end()); + return *this; + } + + template + requires std::ranges::random_access_range + && std::ranges::random_access_range> + && std::same_as::value_type, uint8_t> + chunked_memory_output_stream& chunked_memory_output_stream::write(std::vector&& buffer) + { + m_chunks->emplace_back(std::move(buffer)); + return *this; + } + + template + requires std::ranges::random_access_range + && std::ranges::random_access_range> + && std::same_as::value_type, uint8_t> + chunked_memory_output_stream& chunked_memory_output_stream::write(uint8_t value, std::size_t count) + { + m_chunks->emplace_back(count, value); + return *this; + } + + template + requires 
std::ranges::random_access_range + && std::ranges::random_access_range> + && std::same_as::value_type, uint8_t> + chunked_memory_output_stream& chunked_memory_output_stream::put(char value) + { + m_chunks->emplace_back(std::vector{static_cast(value)}); + return *this; + } + + template + requires std::ranges::random_access_range + && std::ranges::random_access_range> + && std::same_as::value_type, uint8_t> + void chunked_memory_output_stream::reserve(std::size_t size) + { + m_chunks->reserve(size); + } + + template + requires std::ranges::random_access_range + && std::ranges::random_access_range> + && std::same_as::value_type, uint8_t> + void chunked_memory_output_stream::reserve(const std::function& calculate_reserve_size) + { + m_chunks->reserve(calculate_reserve_size()); + } + + template + requires std::ranges::random_access_range + && std::ranges::random_access_range> + && std::same_as::value_type, uint8_t> + size_t chunked_memory_output_stream::size() const + { + return std::accumulate( + m_chunks->begin(), + m_chunks->end(), + 0, + [](size_t acc, const auto& chunk) + { + return acc + chunk.size(); + } + ); + } +} \ No newline at end of file diff --git a/include/sparrow_ipc/chunk_memory_serializer.hpp b/include/sparrow_ipc/chunk_memory_serializer.hpp new file mode 100644 index 0000000..3b241f8 --- /dev/null +++ b/include/sparrow_ipc/chunk_memory_serializer.hpp @@ -0,0 +1,172 @@ +#pragma once + +#include + +#include "sparrow_ipc/chunk_memory_output_stream.hpp" +#include "sparrow_ipc/config/config.hpp" +#include "sparrow_ipc/memory_output_stream.hpp" +#include "sparrow_ipc/serialize.hpp" +#include "sparrow_ipc/serialize_utils.hpp" + +namespace sparrow_ipc +{ + /** + * @brief A serializer that writes record batches to chunked memory streams. + * + * The chunk_serializer class provides functionality to serialize Apache Arrow record batches + * into separate memory chunks. 
Each record batch (and the schema) is written as an independent + * chunk in the output stream, making it suitable for scenarios where data needs to be processed + * or transmitted in discrete units. + * + * @details The serializer maintains schema consistency across all record batches: + * - The schema is written once as the first chunk when the first record batch is processed + * - All subsequent record batches must have the same schema + * - Each record batch is serialized into its own independent memory chunk + * + * @note Once end() is called, no further record batches can be written to this serializer. + */ + class SPARROW_IPC_API chunk_serializer + { + public: + + /** + * @brief Constructs a chunk serializer with a reference to a chunked memory output stream. + * + * @param stream Reference to a chunked memory output stream that will receive the serialized chunks + */ + chunk_serializer(chunked_memory_output_stream>>& stream); + + /** + * @brief Writes a single record batch to the chunked stream. + * + * This method serializes a record batch into the chunked output stream. If this is the first + * record batch written, the schema is automatically serialized first as a separate chunk. + * + * @param rb The record batch to serialize + * @throws std::runtime_error if the serializer has been ended via end() + * @throws std::invalid_argument if the record batch schema doesn't match previously written batches + */ + void write(const sparrow::record_batch& rb); + + /** + * @brief Writes a range of record batches to the chunked stream. + * + * This template method efficiently serializes multiple record batches to the chunked output stream. + * If this is the first write operation, the schema is automatically serialized first as a separate chunk. + * Each record batch is then serialized into its own independent chunk. 
+ * + * @tparam R The range type containing record batches (must satisfy std::ranges::input_range) + * @param record_batches A range of record batches to serialize + * @throws std::runtime_error if the serializer has been ended via end() + * @throws std::invalid_argument if any record batch schema doesn't match previously written batches + */ + template + requires std::same_as, sparrow::record_batch> + void write(const R& record_batches); + + /** + * @brief Appends a record batch using the stream insertion operator. + * + * This operator provides a convenient stream-like interface for appending + * record batches to the serializer. It delegates to the append() method + * and returns a reference to the serializer to enable method chaining. + * + * @param rb The record batch to append to the serializer + * @return A reference to this serializer for method chaining + * @throws std::invalid_argument if the record batch schema doesn't match + * @throws std::runtime_error if the serializer has been ended + * + * @example + * chunk_serializer ser(initial_batch, stream); + * ser << batch1 << batch2 << batch3; + */ + chunk_serializer& operator<<(const sparrow::record_batch& rb); + + /** + * @brief Appends a range of record batches using the stream insertion operator. + * + * This operator provides a convenient stream-like interface for appending + * multiple record batches to the serializer at once. It delegates to the + * append() method and returns a reference to the serializer to enable method chaining. 
+ * + * @tparam R The type of the record batch collection (must be an input range) + * @param record_batches A range of record batches to append to the serializer + * @return A reference to this serializer for method chaining + * @throws std::invalid_argument if any record batch schema doesn't match + * @throws std::runtime_error if the serializer has been ended + * + * @example + * chunk_serializer ser(initial_batch, stream); + * std::vector batches = {batch1, batch2, batch3}; + * ser << batches << another_batch; + */ + template + requires std::same_as, sparrow::record_batch> + chunk_serializer& operator<<(const R& record_batches); + + /** + * @brief Finalizes the chunk serialization by writing an end-of-stream marker. + * + * This method signals the end of the serialization process. After calling this method, + * no further record batches can be written to this serializer. + * + * @throws std::runtime_error if attempting to write after this method has been called + */ + void end(); + + private: + + bool m_schema_received{false}; + std::vector m_dtypes; + chunked_memory_output_stream>>* m_pstream; + bool m_ended{false}; + }; + + // Implementation + + template + requires std::same_as, sparrow::record_batch> + void chunk_serializer::write(const R& record_batches) + { + if (m_ended) + { + throw std::runtime_error("Cannot append record batches to a serializer that has been ended"); + } + + m_pstream->reserve((m_schema_received ? 
0 : 1) + m_pstream->size() + record_batches.size()); + + if (!m_schema_received) + { + m_schema_received = true; + m_dtypes = get_column_dtypes(*record_batches.begin()); + std::vector schema_buffer; + memory_output_stream stream(schema_buffer); + any_output_stream astream(stream); + serialize_schema_message(*record_batches.begin(), astream); + m_pstream->write(std::move(schema_buffer)); + } + + for (const auto& rb : record_batches) + { + std::vector buffer; + memory_output_stream stream(buffer); + any_output_stream astream(stream); + serialize_record_batch(rb, astream); + m_pstream->write(std::move(buffer)); + } + } + + inline chunk_serializer& chunk_serializer::operator<<(const sparrow::record_batch& rb) + { + write(rb); + return *this; + } + + template + requires std::same_as, sparrow::record_batch> + chunk_serializer& chunk_serializer::operator<<(const R& record_batches) + { + write(record_batches); + return *this; + } +} \ No newline at end of file diff --git a/include/sparrow_ipc/encapsulated_message.hpp b/include/sparrow_ipc/encapsulated_message.hpp index 7e95339..cea09a6 100644 --- a/include/sparrow_ipc/encapsulated_message.hpp +++ b/include/sparrow_ipc/encapsulated_message.hpp @@ -2,6 +2,7 @@ #include #include +#include #include "Message_generated.h" diff --git a/include/sparrow_ipc/flatbuffer_utils.hpp b/include/sparrow_ipc/flatbuffer_utils.hpp new file mode 100644 index 0000000..4ec4ef7 --- /dev/null +++ b/include/sparrow_ipc/flatbuffer_utils.hpp @@ -0,0 +1,217 @@ +#pragma once +#include +#include + +#include +#include + +namespace sparrow_ipc +{ + // Creates a Flatbuffers Decimal type from a format string + // The format string is expected to be in the format "d:precision,scale" + [[nodiscard]] std::pair> + get_flatbuffer_decimal_type( + flatbuffers::FlatBufferBuilder& builder, + std::string_view format_str, + const int32_t bitWidth + ); + + // Creates a Flatbuffers type from a format string + // This function maps a sparrow data type to the 
corresponding Flatbuffers type + [[nodiscard]] std::pair> + get_flatbuffer_type(flatbuffers::FlatBufferBuilder& builder, std::string_view format_str); + + /** + * @brief Creates a FlatBuffers vector of KeyValue pairs from ArrowSchema metadata. + * + * This function converts metadata from an ArrowSchema into a FlatBuffers representation + * suitable for serialization. It processes key-value pairs from the schema's metadata + * and creates corresponding FlatBuffers KeyValue objects. + * + * @param builder Reference to the FlatBufferBuilder used for creating FlatBuffers objects + * @param arrow_schema The ArrowSchema containing metadata to be serialized + * + * @return A FlatBuffers offset to a vector of KeyValue pairs. Returns 0 if the schema + * has no metadata (metadata is nullptr). + * + * @note The function reserves memory for the vector based on the metadata size for + * optimal performance. + */ + [[nodiscard]] flatbuffers::Offset>> + create_metadata(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema); + + /** + * @brief Creates a FlatBuffer Field object from an ArrowSchema. + * + * This function converts an ArrowSchema structure into a FlatBuffer Field representation + * suitable for Apache Arrow IPC serialization. It handles the creation of all necessary + * components including field name, type information, metadata, children, and nullable flag. 
+ * + * @param builder Reference to the FlatBufferBuilder used for creating FlatBuffer objects + * @param arrow_schema The ArrowSchema structure containing the field definition to convert + * + * @return A FlatBuffer offset to the created Field object that can be used in further + * FlatBuffer construction operations + * + * @note Dictionary encoding is not currently supported (TODO item) + * @note The function checks the NULLABLE flag from the ArrowSchema flags to determine nullability + */ + [[nodiscard]] ::flatbuffers::Offset + create_field(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema); + + /** + * @brief Creates a FlatBuffers vector of Field objects from an ArrowSchema's children. + * + * This function iterates through all children of the given ArrowSchema and converts + * each child to a FlatBuffers Field object. The resulting fields are collected into + * a FlatBuffers vector. + * + * @param builder Reference to the FlatBufferBuilder used for creating FlatBuffers objects + * @param arrow_schema The ArrowSchema containing the children to convert + * + * @return A FlatBuffers offset to a vector of Field objects, or 0 if no children exist + * + * @throws std::invalid_argument If any child pointer in the ArrowSchema is null + * + * @note The function reserves space for all children upfront for performance optimization + * @note Returns 0 (null offset) when the schema has no children, otherwise returns a valid vector offset + */ + [[nodiscard]] ::flatbuffers::Offset< + ::flatbuffers::Vector<::flatbuffers::Offset>> + create_children(flatbuffers::FlatBufferBuilder& builder, sparrow::record_batch::column_range columns); + + /** + * @brief Creates a FlatBuffers vector of Field objects from a range of columns. + * + * This function iterates through the provided column range, extracts the Arrow schema + * from each column's proxy, and creates corresponding FlatBuffers Field objects. 
+ * The resulting fields are collected into a vector and converted to a FlatBuffers + * vector offset. + * + * @param builder Reference to the FlatBuffers builder used for creating the vector + * @param columns Range of columns to process, each containing an Arrow schema proxy + * + * @return FlatBuffers offset to a vector of Field objects, or 0 if the input range is empty + * + * @note The function reserves space in the children vector based on the column count + * for performance optimization + */ + [[nodiscard]] ::flatbuffers::Offset< + ::flatbuffers::Vector<::flatbuffers::Offset>> + create_children(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema); + + /** + * @brief Creates a FlatBuffer builder containing a serialized Arrow schema message. + * + * This function constructs an Arrow IPC schema message from a record batch by: + * 1. Creating field definitions from the record batch columns + * 2. Building a Schema flatbuffer with little-endian byte order + * 3. Wrapping the schema in a Message with metadata version V5 + * 4. Finalizing the buffer for serialization + * + * @param record_batch The source record batch containing column definitions + * @return flatbuffers::FlatBufferBuilder A completed FlatBuffer containing the schema message, + * ready for Arrow IPC serialization + * + * @note The schema message has zero body length as it contains only metadata + * @note Currently uses little-endian byte order (marked as TODO for configurability) + */ + [[nodiscard]] flatbuffers::FlatBufferBuilder + get_schema_message_builder(const sparrow::record_batch& record_batch); + + /** + * @brief Recursively fills a vector of FieldNode objects from an arrow_proxy and its children. + * + * This function creates FieldNode objects containing length and null count information + * from the given arrow_proxy and recursively processes all its children, appending + * them to the provided nodes vector in depth-first order. 
+ * + * @param arrow_proxy The arrow proxy object containing array metadata (length, null_count) + * and potential child arrays + * @param nodes Reference to a vector that will be populated with FieldNode objects. + * Each FieldNode contains the length and null count of the corresponding array. + * + * @note The function reserves space in the nodes vector to optimize memory allocation + * when processing children arrays. + * @note The traversal order is depth-first, with parent nodes added before their children. + */ + void fill_fieldnodes( + const sparrow::arrow_proxy& arrow_proxy, + std::vector& nodes + ); + + /** + * @brief Creates a vector of Apache Arrow FieldNode objects from a record batch. + * + * This function iterates through all columns in the provided record batch and + * generates corresponding FieldNode flatbuffer objects. Each column's arrow proxy + * is used to populate the field nodes vector through the fill_fieldnodes function. + * + * @param record_batch The sparrow record batch containing columns to process + * @return std::vector Vector of FieldNode + * objects representing the structure and metadata of each column + */ + [[nodiscard]] std::vector + create_fieldnodes(const sparrow::record_batch& record_batch); + + + /** + * @brief Recursively fills a vector of FlatBuffer Buffer objects with buffer information from an Arrow + * proxy. + * + * This function traverses an Arrow proxy structure and creates FlatBuffer Buffer entries for each buffer + * found in the proxy and its children. The buffers are processed in a depth-first manner, first handling + * the buffers of the current proxy, then recursively processing all child proxies. 
+ * + * @param arrow_proxy The Arrow proxy object containing buffers and potential child proxies to process + * @param flatbuf_buffers Vector of FlatBuffer Buffer objects to be populated with buffer information + * @param offset Reference to the current byte offset, updated as buffers are processed and aligned to + * 8-byte boundaries + * + * @note The offset is automatically aligned to 8-byte boundaries using utils::align_to_8() for each + * buffer + * @note This function modifies both the flatbuf_buffers vector and the offset parameter + */ + void fill_buffers( + const sparrow::arrow_proxy& arrow_proxy, + std::vector& flatbuf_buffers, + int64_t& offset + ); + + /** + * @brief Extracts buffer information from a record batch for serialization. + * + * This function iterates through all columns in the provided record batch and + * collects their buffer information into a vector of Arrow FlatBuffer Buffer objects. + * The buffers are processed sequentially with cumulative offset tracking. + * + * @param record_batch The sparrow record batch containing columns to extract buffers from + * @return std::vector A vector containing all buffer + * descriptors from the record batch columns, with properly calculated offsets + * + * @note This function relies on the fill_buffers helper function to process individual + * column buffers and maintain offset consistency across all buffers. + */ + [[nodiscard]] std::vector + get_buffers(const sparrow::record_batch& record_batch); + + /** + * @brief Creates a FlatBuffer message containing a serialized Apache Arrow RecordBatch. + * + * This function builds a complete Arrow IPC message by serializing a record batch + * along with its metadata (field nodes and buffer information) into a FlatBuffer + * format that conforms to the Arrow IPC specification. 
+ * + * @param record_batch The source record batch containing the data to be serialized + * + * @return A FlatBufferBuilder containing the complete serialized message ready for + * transmission or storage. The builder is finished and ready to be accessed + * via GetBufferPointer() and GetSize(). + * + * @note The returned message uses Arrow IPC format version V5 + * @note Compression and variadic buffer counts are not currently implemented (set to 0) + * @note The body size is automatically calculated based on the record batch contents + */ + [[nodiscard]] flatbuffers::FlatBufferBuilder + get_record_batch_message_builder(const sparrow::record_batch& record_batch); +} \ No newline at end of file diff --git a/include/sparrow_ipc/memory_output_stream.hpp b/include/sparrow_ipc/memory_output_stream.hpp new file mode 100644 index 0000000..27e2e06 --- /dev/null +++ b/include/sparrow_ipc/memory_output_stream.hpp @@ -0,0 +1,173 @@ + +#include +#include +#include + +namespace sparrow_ipc +{ + /** + * @brief An output stream that writes data to a contiguous memory buffer. + * + * This template class implements an output_stream that appends data to a contiguous + * random-access range (typically std::vector). All write operations append + * data to the end of the buffer, making it grow as needed. + * + * @tparam R A random access range type with uint8_t as its value type. + * Typically std::vector or similar contiguous container types. + * + * @details The memory output stream: + * - Supports efficient append operations + * - Can reserve capacity to minimize reallocations + * - Always operates on a contiguous memory buffer + * - Stores a non-owning reference to the buffer + * + * @note The caller must ensure the buffer remains valid for the lifetime of this stream + */ + template + requires std::ranges::random_access_range && std::same_as + class memory_output_stream + { + public: + + /** + * @brief Constructs a memory output stream with a reference to a buffer. 
+ * + * @param buffer Reference to the container that will store the written data. + * The stream stores a non-owning pointer to this buffer for write operations. + * + * @note The caller must ensure the buffer remains valid for the lifetime of this stream + */ + memory_output_stream(R& buffer) + : m_buffer(&buffer) {}; + + /** + * @brief Writes character data to the buffer. + * + * Appends the specified character data to the end of the buffer. + * + * @param s Pointer to the character data to write + * @param count Number of characters to write + * @return Reference to this stream for method chaining + * + * @note The characters are converted to uint8_t and appended to the buffer + */ + memory_output_stream& write(const char* s, std::streamsize count); + + /** + * @brief Writes a span of bytes to the buffer. + * + * Appends the data from the provided span to the end of the buffer. + * + * @param span A span of bytes to write + * @return Reference to this stream for method chaining + */ + memory_output_stream& write(std::span span); + + /** + * @brief Writes a byte value repeated a specified number of times. + * + * Appends the specified byte value repeated count times to the end of the buffer. + * This is useful for padding operations or filling with a specific value. + * + * @param value The byte value to write + * @param count Number of times to repeat the value + * @return Reference to this stream for method chaining + */ + memory_output_stream& write(uint8_t value, std::size_t count); + + /** + * @brief Writes a single character to the buffer. + * + * Appends a single byte to the end of the buffer. The character is cast to uint8_t. + * + * @param value The character value to write + * @return Reference to this stream for method chaining + */ + memory_output_stream& put(char value); + + /** + * @brief Reserves capacity in the underlying buffer. + * + * Reserves space for at least the specified number of bytes in the buffer. 
+ * This can help minimize reallocations during subsequent write operations. + * + * @param size Number of bytes to reserve + */ + void reserve(std::size_t size); + + /** + * @brief Reserves capacity using a lazy calculation function. + * + * Calls the provided function to determine the buffer size to reserve. + * + * @param calculate_reserve_size Function that returns the number of bytes to reserve + */ + void reserve(const std::function& calculate_reserve_size); + + /** + * @brief Gets the current size of the buffer. + * + * @return The number of bytes currently in the buffer + */ + [[nodiscard]] size_t size() const; + + private: + + R* m_buffer; + }; + + // Implementation + + template + requires std::ranges::random_access_range && std::same_as + memory_output_stream& memory_output_stream::write(const char* s, std::streamsize count) + { + m_buffer->insert(m_buffer->end(), s, s + count); + return *this; + } + + template + requires std::ranges::random_access_range && std::same_as + memory_output_stream& memory_output_stream::write(std::span span) + { + m_buffer->insert(m_buffer->end(), span.begin(), span.end()); + return *this; + } + + template + requires std::ranges::random_access_range && std::same_as + memory_output_stream& memory_output_stream::write(uint8_t value, std::size_t count) + { + m_buffer->insert(m_buffer->end(), count, value); + return *this; + } + + template + requires std::ranges::random_access_range && std::same_as + memory_output_stream& memory_output_stream::put(char value) + { + m_buffer->push_back(static_cast(value)); + return *this; + } + + template + requires std::ranges::random_access_range && std::same_as + void memory_output_stream::reserve(std::size_t size) + { + m_buffer->reserve(size); + } + + template + requires std::ranges::random_access_range && std::same_as + void memory_output_stream::reserve(const std::function& calculate_reserve_size) + { + m_buffer->reserve(calculate_reserve_size()); + } + + template + requires 
std::ranges::random_access_range && std::same_as + size_t memory_output_stream::size() const + { + return m_buffer->size(); + } +} diff --git a/include/sparrow_ipc/serialize.hpp b/include/sparrow_ipc/serialize.hpp index 1ab8003..4a18e57 100644 --- a/include/sparrow_ipc/serialize.hpp +++ b/include/sparrow_ipc/serialize.hpp @@ -1,12 +1,11 @@ #pragma once -#include #include -#include #include #include "Message_generated.h" +#include "sparrow_ipc/any_output_stream.hpp" #include "sparrow_ipc/config/config.hpp" #include "sparrow_ipc/magic_values.hpp" #include "sparrow_ipc/serialize_utils.hpp" @@ -26,9 +25,7 @@ namespace sparrow_ipc * @tparam R Container type that holds record batches (must support empty(), operator[], begin(), end()) * @param record_batches Collection of record batches to serialize. All batches must have identical * schemas. - * - * @return std::vector Binary serialized data containing schema, record batches, and - * end-of-stream marker. Returns empty vector if input collection is empty. + * @param stream The output stream where the serialized data will be written. * * @throws std::invalid_argument If record batches have inconsistent schemas or if the collection * contains batches that cannot be serialized together. @@ -38,27 +35,60 @@ namespace sparrow_ipc */ template requires std::same_as, sparrow::record_batch> - std::vector serialize(const R& record_batches) + void serialize_record_batches_to_ipc_stream(const R& record_batches, any_output_stream& stream) { if (record_batches.empty()) { - return {}; + return; } + if (!utils::check_record_batches_consistency(record_batches)) { throw std::invalid_argument( "All record batches must have the same schema to be serialized together." 
); } - std::vector serialized_schema = serialize_schema_message(record_batches[0]); - std::vector serialized_record_batches = serialize_record_batches_without_schema_message(record_batches); - serialized_schema.insert( - serialized_schema.end(), - std::make_move_iterator(serialized_record_batches.begin()), - std::make_move_iterator(serialized_record_batches.end()) - ); - // End of stream message - serialized_schema.insert(serialized_schema.end(), end_of_stream.begin(), end_of_stream.end()); - return serialized_schema; + serialize_schema_message(record_batches[0], stream); + for (const auto& rb : record_batches) + { + serialize_record_batch(rb, stream); + } + stream.write(end_of_stream); } + + /** + * @brief Serializes a record batch into a binary format following the Arrow IPC specification. + * + * This function converts a sparrow record batch into a serialized byte vector that includes: + * - A continuation marker + * - The record batch message length (4 bytes) + * - The flatbuffer-encoded record batch metadata + * - Padding to align to 8-byte boundaries + * - The record batch body containing the actual data buffers + * + * @param record_batch The sparrow record batch to serialize + * @param stream The output stream where the serialized record batch will be written + * + * @note The output follows Arrow IPC message format with proper alignment and + * includes both metadata and data portions of the record batch + */ + + SPARROW_IPC_API void + serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream); + + /** + * @brief Serializes a schema message for a record batch into a byte buffer. + * + * This function creates a serialized schema message following the Arrow IPC format. + * The resulting buffer contains: + * 1. Continuation bytes at the beginning + * 2. A 4-byte length prefix indicating the size of the schema message + * 3. The actual FlatBuffer schema message bytes + * 4. 
Padding bytes to align the total size to 8-byte boundaries + * + * @param record_batch The record batch containing the schema to serialize + * @param stream The output stream where the serialized schema message will be written + */ + SPARROW_IPC_API void + serialize_schema_message(const sparrow::record_batch& record_batch, any_output_stream& stream); } diff --git a/include/sparrow_ipc/serialize_utils.hpp b/include/sparrow_ipc/serialize_utils.hpp index 9ead8ea..ae881a5 100644 --- a/include/sparrow_ipc/serialize_utils.hpp +++ b/include/sparrow_ipc/serialize_utils.hpp @@ -1,14 +1,13 @@ #pragma once -#include #include #include #include #include "Message_generated.h" +#include "sparrow_ipc/any_output_stream.hpp" #include "sparrow_ipc/config/config.hpp" -#include "sparrow_ipc/magic_values.hpp" #include "sparrow_ipc/utils.hpp" namespace sparrow_ipc @@ -21,11 +20,10 @@ namespace sparrow_ipc * The resulting format follows the Arrow IPC specification for schema messages. * * @param record_batch The record batch containing the schema to be serialized - * @return std::vector A byte vector containing the complete serialized schema message - * with continuation bytes, 4-byte length prefix, schema data, and 8-byte alignment padding + * @param stream The output stream where the serialized schema message will be written */ - [[nodiscard]] SPARROW_IPC_API std::vector - serialize_schema_message(const sparrow::record_batch& record_batch); + SPARROW_IPC_API void + serialize_schema_message(const sparrow::record_batch& record_batch, any_output_stream& stream); /** * @brief Serializes a record batch into a binary format following the Arrow IPC specification. @@ -41,236 +39,84 @@ namespace sparrow_ipc * consists of a metadata section followed by a body section containing the actual data. 
* * @param record_batch The sparrow record batch to be serialized - * @return std::vector A byte vector containing the complete serialized record batch - * in Arrow IPC format, ready for transmission or storage + * @param stream The output stream where the serialized record batch will be written */ - [[nodiscard]] SPARROW_IPC_API std::vector - serialize_record_batch(const sparrow::record_batch& record_batch); - - template - requires std::same_as, sparrow::record_batch> - /** - * @brief Serializes a collection of record batches into a single byte vector. - * - * This function takes a range or container of record batches and serializes each one - * individually, then concatenates all the serialized data into a single output vector. - * The serialization is performed by calling serialize_record_batch() for each record batch - * in the input collection. - * - * @tparam R The type of the record batch container/range (must be iterable) - * @param record_batches A collection of record batches to be serialized - * @return std::vector A byte vector containing the serialized data of all record batches - * - * @note The function uses move iterators to efficiently transfer the serialized data - * from individual record batches to the output vector. - */ - [[nodiscard]] std::vector serialize_record_batches_without_schema_message(const R& record_batches) - { - std::vector output; - for (const auto& record_batch : record_batches) - { - const auto rb_serialized = serialize_record_batch(record_batch); - output.insert( - output.end(), - std::make_move_iterator(rb_serialized.begin()), - std::make_move_iterator(rb_serialized.end()) - ); - } - return output; - } - - /** - * @brief Creates a FlatBuffers vector of KeyValue pairs from ArrowSchema metadata. - * - * This function converts metadata from an ArrowSchema into a FlatBuffers representation - * suitable for serialization. 
It processes key-value pairs from the schema's metadata - * and creates corresponding FlatBuffers KeyValue objects. - * - * @param builder Reference to the FlatBufferBuilder used for creating FlatBuffers objects - * @param arrow_schema The ArrowSchema containing metadata to be serialized - * - * @return A FlatBuffers offset to a vector of KeyValue pairs. Returns 0 if the schema - * has no metadata (metadata is nullptr). - * - * @note The function reserves memory for the vector based on the metadata size for - * optimal performance. - */ - [[nodiscard]] SPARROW_IPC_API - flatbuffers::Offset>> - create_metadata(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema); - - /** - * @brief Creates a FlatBuffer Field object from an ArrowSchema. - * - * This function converts an ArrowSchema structure into a FlatBuffer Field representation - * suitable for Apache Arrow IPC serialization. It handles the creation of all necessary - * components including field name, type information, metadata, children, and nullable flag. - * - * @param builder Reference to the FlatBufferBuilder used for creating FlatBuffer objects - * @param arrow_schema The ArrowSchema structure containing the field definition to convert - * - * @return A FlatBuffer offset to the created Field object that can be used in further - * FlatBuffer construction operations - * - * @note Dictionary encoding is not currently supported (TODO item) - * @note The function checks the NULLABLE flag from the ArrowSchema flags to determine nullability - */ - [[nodiscard]] SPARROW_IPC_API ::flatbuffers::Offset - create_field(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema); - - /** - * @brief Creates a FlatBuffers vector of Field objects from an ArrowSchema's children. - * - * This function iterates through all children of the given ArrowSchema and converts - * each child to a FlatBuffers Field object. The resulting fields are collected into - * a FlatBuffers vector. 
- * - * @param builder Reference to the FlatBufferBuilder used for creating FlatBuffers objects - * @param arrow_schema The ArrowSchema containing the children to convert - * - * @return A FlatBuffers offset to a vector of Field objects, or 0 if no children exist - * - * @throws std::invalid_argument If any child pointer in the ArrowSchema is null - * - * @note The function reserves space for all children upfront for performance optimization - * @note Returns 0 (null offset) when the schema has no children, otherwise returns a valid vector offset - */ - [[nodiscard]] SPARROW_IPC_API ::flatbuffers::Offset< - ::flatbuffers::Vector<::flatbuffers::Offset>> - create_children(flatbuffers::FlatBufferBuilder& builder, sparrow::record_batch::column_range columns); + SPARROW_IPC_API void + serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream); /** - * @brief Creates a FlatBuffers vector of Field objects from a range of columns. - * - * This function iterates through the provided column range, extracts the Arrow schema - * from each column's proxy, and creates corresponding FlatBuffers Field objects. - * The resulting fields are collected into a vector and converted to a FlatBuffers - * vector offset. + * @brief Calculates the total serialized size of a schema message. 
* - * @param builder Reference to the FlatBuffers builder used for creating the vector - * @param columns Range of columns to process, each containing an Arrow schema proxy + * This function computes the complete size that would be produced by serialize_schema_message(), + * including: + * - Continuation bytes (4 bytes) + * - Message length prefix (4 bytes) + * - FlatBuffer schema message data + * - Padding to 8-byte alignment * - * @return FlatBuffers offset to a vector of Field objects, or 0 if the input range is empty - * - * @note The function reserves space in the children vector based on the column count - * for performance optimization + * @param record_batch The record batch containing the schema to be measured + * @return The total size in bytes that the serialized schema message would occupy */ - [[nodiscard]] SPARROW_IPC_API ::flatbuffers::Offset< - ::flatbuffers::Vector<::flatbuffers::Offset>> - create_children(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema); + [[nodiscard]] SPARROW_IPC_API std::size_t + calculate_schema_message_size(const sparrow::record_batch& record_batch); /** - * @brief Creates a FlatBuffer builder containing a serialized Arrow schema message. - * - * This function constructs an Arrow IPC schema message from a record batch by: - * 1. Creating field definitions from the record batch columns - * 2. Building a Schema flatbuffer with little-endian byte order - * 3. Wrapping the schema in a Message with metadata version V5 - * 4. Finalizing the buffer for serialization + * @brief Calculates the total serialized size of a record batch message. 
* - * @param record_batch The source record batch containing column definitions - * @return flatbuffers::FlatBufferBuilder A completed FlatBuffer containing the schema message, - * ready for Arrow IPC serialization + * This function computes the complete size that would be produced by serialize_record_batch(), + * including: + * - Continuation bytes (4 bytes) + * - Message length prefix (4 bytes) + * - FlatBuffer record batch metadata + * - Padding to 8-byte alignment after metadata + * - Body data with 8-byte alignment between buffers * - * @note The schema message has zero body length as it contains only metadata - * @note Currently uses little-endian byte order (marked as TODO for configurability) + * @param record_batch The record batch to be measured + * @return The total size in bytes that the serialized record batch would occupy */ - [[nodiscard]] SPARROW_IPC_API flatbuffers::FlatBufferBuilder - get_schema_message_builder(const sparrow::record_batch& record_batch); + [[nodiscard]] SPARROW_IPC_API std::size_t + calculate_record_batch_message_size(const sparrow::record_batch& record_batch); /** - * @brief Serializes a schema message for a record batch into a byte buffer. + * @brief Calculates the total serialized size for a collection of record batches. * - * This function creates a serialized schema message following the Arrow IPC format. - * The resulting buffer contains: - * 1. Continuation bytes at the beginning - * 2. A 4-byte length prefix indicating the size of the schema message - * 3. The actual FlatBuffer schema message bytes - * 4. Padding bytes to align the total size to 8-byte boundaries + * This function computes the complete size that would be produced by serializing + * a schema message followed by all record batch messages in the collection. 
* - * @param record_batch The record batch containing the schema to serialize - * @return std::vector A byte buffer containing the complete serialized schema message + * @tparam R Range type containing sparrow::record_batch objects + * @param record_batches Collection of record batches to be measured + * @return The total size in bytes for the complete serialized output + * @throws std::invalid_argument if record batches have inconsistent schemas */ - [[nodiscard]] SPARROW_IPC_API std::vector - serialize_schema_message(const sparrow::record_batch& record_batch); + template + requires std::same_as, sparrow::record_batch> + [[nodiscard]] std::size_t calculate_total_serialized_size(const R& record_batches) + { + if (record_batches.empty()) + { + return 0; + } - /** - * @brief Recursively fills a vector of FieldNode objects from an arrow_proxy and its children. - * - * This function creates FieldNode objects containing length and null count information - * from the given arrow_proxy and recursively processes all its children, appending - * them to the provided nodes vector in depth-first order. - * - * @param arrow_proxy The arrow proxy object containing array metadata (length, null_count) - * and potential child arrays - * @param nodes Reference to a vector that will be populated with FieldNode objects. - * Each FieldNode contains the length and null count of the corresponding array. - * - * @note The function reserves space in the nodes vector to optimize memory allocation - * when processing children arrays. - * @note The traversal order is depth-first, with parent nodes added before their children. - */ - SPARROW_IPC_API void fill_fieldnodes( - const sparrow::arrow_proxy& arrow_proxy, - std::vector& nodes - ); + if (!utils::check_record_batches_consistency(record_batches)) + { + throw std::invalid_argument("Record batches have inconsistent schemas"); + } - /** - * @brief Creates a vector of Apache Arrow FieldNode objects from a record batch. 
- * - * This function iterates through all columns in the provided record batch and - * generates corresponding FieldNode flatbuffer objects. Each column's arrow proxy - * is used to populate the field nodes vector through the fill_fieldnodes function. - * - * @param record_batch The sparrow record batch containing columns to process - * @return std::vector Vector of FieldNode - * objects representing the structure and metadata of each column - */ - [[nodiscard]] SPARROW_IPC_API std::vector - create_fieldnodes(const sparrow::record_batch& record_batch); + // Calculate schema message size (only once) + auto it = std::ranges::begin(record_batches); + std::size_t total_size = calculate_schema_message_size(*it); - /** - * @brief Recursively fills a vector of FlatBuffer Buffer objects with buffer information from an Arrow - * proxy. - * - * This function traverses an Arrow proxy structure and creates FlatBuffer Buffer entries for each buffer - * found in the proxy and its children. The buffers are processed in a depth-first manner, first handling - * the buffers of the current proxy, then recursively processing all child proxies. 
- * - * @param arrow_proxy The Arrow proxy object containing buffers and potential child proxies to process - * @param flatbuf_buffers Vector of FlatBuffer Buffer objects to be populated with buffer information - * @param offset Reference to the current byte offset, updated as buffers are processed and aligned to - * 8-byte boundaries - * - * @note The offset is automatically aligned to 8-byte boundaries using utils::align_to_8() for each - * buffer - * @note This function modifies both the flatbuf_buffers vector and the offset parameter - */ - SPARROW_IPC_API void fill_buffers( - const sparrow::arrow_proxy& arrow_proxy, - std::vector& flatbuf_buffers, - int64_t& offset - ); + // Calculate record batch message sizes + for (const auto& record_batch : record_batches) + { + total_size += calculate_record_batch_message_size(record_batch); + } - /** - * @brief Extracts buffer information from a record batch for serialization. - * - * This function iterates through all columns in the provided record batch and - * collects their buffer information into a vector of Arrow FlatBuffer Buffer objects. - * The buffers are processed sequentially with cumulative offset tracking. - * - * @param record_batch The sparrow record batch containing columns to extract buffers from - * @return std::vector A vector containing all buffer - * descriptors from the record batch columns, with properly calculated offsets - * - * @note This function relies on the fill_buffers helper function to process individual - * column buffers and maintain offset consistency across all buffers. - */ - [[nodiscard]] SPARROW_IPC_API std::vector - get_buffers(const sparrow::record_batch& record_batch); + return total_size; + } /** - * @brief Fills the body vector with buffer data from an arrow proxy and its children. + * @brief Fills the body vector with serialized data from an arrow proxy and its children. * * This function recursively processes an arrow proxy by: * 1. 
Iterating through all buffers in the proxy and appending their data to the body vector @@ -282,9 +128,9 @@ namespace sparrow_ipc * format compliance. * * @param arrow_proxy The arrow proxy containing buffers and potential child proxies to serialize - * @param body Reference to the vector where the serialized buffer data will be appended + * @param stream The output stream where the serialized body data will be written */ - SPARROW_IPC_API void fill_body(const sparrow::arrow_proxy& arrow_proxy, std::vector& body); + SPARROW_IPC_API void fill_body(const sparrow::arrow_proxy& arrow_proxy, any_output_stream& stream); /** * @brief Generates a serialized body from a record batch. @@ -294,9 +140,9 @@ namespace sparrow_ipc * single byte vector that forms the body of the serialized data. * * @param record_batch The record batch containing columns to be serialized - * @return std::vector A byte vector containing the serialized body data + * @param stream The output stream where the serialized body will be written */ - [[nodiscard]] SPARROW_IPC_API std::vector generate_body(const sparrow::record_batch& record_batch); + SPARROW_IPC_API void generate_body(const sparrow::record_batch& record_batch, any_output_stream& stream); /** * @brief Calculates the total size of the body section for an Arrow array. @@ -322,60 +168,5 @@ namespace sparrow_ipc */ [[nodiscard]] SPARROW_IPC_API int64_t calculate_body_size(const sparrow::record_batch& record_batch); - /** - * @brief Creates a FlatBuffer message containing a serialized Apache Arrow RecordBatch. - * - * This function builds a complete Arrow IPC message by serializing a record batch - * along with its metadata (field nodes and buffer information) into a FlatBuffer - * format that conforms to the Arrow IPC specification. 
- * - * @param record_batch The source record batch containing the data to be serialized - * @param nodes Vector of field nodes describing the structure and null counts of columns - * @param buffers Vector of buffer descriptors containing offset and length information - * for the data buffers - * - * @return A FlatBufferBuilder containing the complete serialized message ready for - * transmission or storage. The builder is finished and ready to be accessed - * via GetBufferPointer() and GetSize(). - * - * @note The returned message uses Arrow IPC format version V5 - * @note Compression and variadic buffer counts are not currently implemented (set to 0) - * @note The body size is automatically calculated based on the record batch contents - */ - [[nodiscard]] SPARROW_IPC_API flatbuffers::FlatBufferBuilder get_record_batch_message_builder( - const sparrow::record_batch& record_batch, - const std::vector& nodes, - const std::vector& buffers - ); - - /** - * @brief Serializes a record batch into a binary format following the Arrow IPC specification. - * - * This function converts a sparrow record batch into a serialized byte vector that includes: - * - A continuation marker - * - The record batch message length (4 bytes) - * - The flatbuffer-encoded record batch metadata - * - Padding to align to 8-byte boundaries - * - The record batch body containing the actual data buffers - * - * @param record_batch The sparrow record batch to serialize - * @return std::vector A byte vector containing the serialized record batch - * in Arrow IPC format, ready for transmission or storage - * - * @note The output follows Arrow IPC message format with proper alignment and - * includes both metadata and data portions of the record batch - */ - [[nodiscard]] SPARROW_IPC_API std::vector - serialize_record_batch(const sparrow::record_batch& record_batch); - - /** - * @brief Adds padding bytes to a buffer to ensure 8-byte alignment. 
- * - * This function appends zero bytes to the end of the provided buffer until - * its size is a multiple of 8. This is often required for proper memory - * alignment in binary formats such as Apache Arrow IPC. - * - * @param buffer The byte vector to which padding will be added - */ - void add_padding(std::vector& buffer); + SPARROW_IPC_API std::vector get_column_dtypes(const sparrow::record_batch& rb); } diff --git a/include/sparrow_ipc/serializer.hpp b/include/sparrow_ipc/serializer.hpp new file mode 100644 index 0000000..9a8c1e0 --- /dev/null +++ b/include/sparrow_ipc/serializer.hpp @@ -0,0 +1,216 @@ +#include +#include + +#include + +#include "sparrow_ipc/any_output_stream.hpp" +#include "sparrow_ipc/serialize_utils.hpp" + +namespace sparrow_ipc +{ + /** + * @brief A class for serializing Apache Arrow record batches to an output stream. + * + * The serializer class provides functionality to serialize single or multiple record batches + * into a binary format suitable for storage or transmission. It ensures schema consistency + * across multiple record batches and optimizes memory allocation by pre-calculating required + * buffer sizes. + * + * @details The serializer supports two main usage patterns: + * 1. Construction with a collection of record batches for batch serialization + * 2. Construction with a single record batch followed by incremental appends + * + * The class validates that all record batches have consistent schemas and throws + * std::invalid_argument if inconsistencies are detected or if an empty collection + * is provided. + * + * Memory efficiency is achieved through: + * - Pre-calculation of total serialization size + * - Stream reservation to minimize memory reallocations + * - Lazy evaluation of size calculations using lambda functions + */ + class SPARROW_IPC_API serializer + { + public: + + /** + * @brief Constructs a serializer object with a reference to a stream. 
+ * + * @tparam TStream The type of the stream to be used for serialization. + * @param stream Reference to the stream object that will be used for serialization operations. + * The serializer stores a pointer to this stream for later use. + */ + template + serializer(TStream& stream) + : m_stream(stream) + { + } + + /** + * @brief Destructor for the serializer. + * + * Ensures proper cleanup by calling end() if the serializer has not been + * explicitly ended. This guarantees that any pending data is flushed and + * resources are properly released before the object is destroyed. + */ + ~serializer(); + + /** + * Writes a record batch to the serializer. + * + * @param rb The record batch to write to the serializer + */ + void write(const sparrow::record_batch& rb); + + /** + * @brief Writes a collection of record batches to the stream. + * + * This method efficiently adds multiple record batches to the serialization stream + * by first calculating the total required size and reserving memory space to minimize + * reallocations during the append operations. + * + * @tparam R The type of the record batch collection (must be iterable) + * @param record_batches A collection of record batches to append to the stream + * + * The method performs the following operations: + * 1. Calculates the total size needed for all record batches + * 2. Reserves the required memory space in the stream + * 3. Iterates through each record batch and adds it to the stream + */ + template + requires std::same_as, sparrow::record_batch> + void write(const R& record_batches) + { + if (m_ended) + { + throw std::runtime_error("Cannot append to a serializer that has been ended"); + } + + const auto reserve_function = [&record_batches, this]() + { + return std::accumulate( + record_batches.begin(), + record_batches.end(), + m_stream.size(), + [this](size_t acc, const sparrow::record_batch& rb) + { + return acc + calculate_record_batch_message_size(rb); + } + ) + + (m_schema_received ? 
0 : calculate_schema_message_size(*record_batches.begin())); + }; + + m_stream.reserve(reserve_function); + + if (!m_schema_received) + { + m_schema_received = true; + m_dtypes = get_column_dtypes(*record_batches.begin()); + serialize_schema_message(*record_batches.begin(), m_stream); + } + + for (const auto& rb : record_batches) + { + if (get_column_dtypes(rb) != m_dtypes) + { + throw std::invalid_argument("Record batch schema does not match serializer schema"); + } + serialize_record_batch(rb, m_stream); + } + } + + /** + * @brief Appends a record batch using the stream insertion operator. + * + * This operator provides a convenient stream-like interface for appending + * record batches to the serializer. It delegates to the write() method + * and returns a reference to the serializer to enable method chaining. + * + * @param rb The record batch to append to the serializer + * @return A reference to this serializer for method chaining + * @throws std::invalid_argument if the record batch schema doesn't match + * @throws std::runtime_error if the serializer has been ended + * + * @example + * serializer ser(stream); + * ser << batch1 << batch2 << batch3; + */ + serializer& operator<<(const sparrow::record_batch& rb) + { + write(rb); + return *this; + } + + /** + * @brief Appends a range of record batches using the stream insertion operator. + * + * This operator provides a convenient stream-like interface for appending + * multiple record batches to the serializer at once. It delegates to the + * write() method and returns a reference to the serializer to enable method chaining. 
+ * + * @tparam R The type of the record batch collection (must be an input range) + * @param record_batches A range of record batches to append to the serializer + * @return A reference to this serializer for method chaining + * @throws std::invalid_argument if any record batch schema doesn't match + * @throws std::runtime_error if the serializer has been ended + * + * @example + * serializer ser(stream); + * std::vector batches = {batch1, batch2, batch3}; + * ser << batches << another_batch; + */ + template + requires std::same_as, sparrow::record_batch> + serializer& operator<<(const R& record_batches) + { + write(record_batches); + return *this; + } + + /** + * @brief Stream manipulator operator for functions like end_stream. + * + * This operator enables the use of manipulator functions (similar to std::endl) + * with the serializer. It accepts a function pointer that takes and returns + * a reference to a serializer. + * + * @param manip A function pointer to a manipulator function + * @return A reference to this serializer for method chaining + * + * @example + * serializer ser(stream); + * ser << batch1 << batch2 << end_stream; + */ + serializer& operator<<(serializer& (*manip)(serializer&)) + { + return manip(*this); + } + + /** + * @brief Finalizes the serialization process by writing end-of-stream marker. + * + * This method writes an end-of-stream marker to the output stream and flushes + * any buffered data. It can be called multiple times safely as it tracks + * whether the stream has already been ended to prevent duplicate operations. + * + * @note This method is idempotent - calling it multiple times has no additional effect. + * @post After calling this method, m_ended will be set to true. 
+ */ + void end(); + + private: + + static std::vector get_column_dtypes(const sparrow::record_batch& rb); + + bool m_schema_received{false}; + std::vector m_dtypes; + any_output_stream m_stream; + bool m_ended{false}; + }; + + inline serializer& end_stream(serializer& serializer) + { + serializer.end(); + return serializer; + } +} \ No newline at end of file diff --git a/include/sparrow_ipc/utils.hpp b/include/sparrow_ipc/utils.hpp index 0c80f9c..63f1fb8 100644 --- a/include/sparrow_ipc/utils.hpp +++ b/include/sparrow_ipc/utils.hpp @@ -3,22 +3,15 @@ #include #include #include -#include #include -#include "Schema_generated.h" #include "sparrow_ipc/config/config.hpp" namespace sparrow_ipc::utils { // Aligns a value to the next multiple of 8, as required by the Arrow IPC format for message bodies - SPARROW_IPC_API int64_t align_to_8(const int64_t n); - - // Creates a Flatbuffers type from a format string - // This function maps a sparrow data type to the corresponding Flatbuffers type - SPARROW_IPC_API std::pair> - get_flatbuffer_type(flatbuffers::FlatBufferBuilder& builder, std::string_view format_str); + SPARROW_IPC_API size_t align_to_8(const size_t n); /** * @brief Checks if all record batches in a collection have consistent structure. 
@@ -39,7 +32,7 @@ namespace sparrow_ipc::utils requires std::same_as, sparrow::record_batch> bool check_record_batches_consistency(const R& record_batches) { - if (record_batches.empty()) + if (record_batches.empty() || record_batches.size() == 1) { return true; } @@ -67,5 +60,8 @@ namespace sparrow_ipc::utils return true; } + // Parse the format string + // The format string is expected to be "w:size", "+w:size", "d:precision,scale", etc + std::optional parse_format(std::string_view format_str, std::string_view sep); // size_t calculate_output_serialized_size(const sparrow::record_batch& record_batch); } diff --git a/src/any_output_stream.cpp b/src/any_output_stream.cpp new file mode 100644 index 0000000..b33879f --- /dev/null +++ b/src/any_output_stream.cpp @@ -0,0 +1,35 @@ +#include "sparrow_ipc/any_output_stream.hpp" + +namespace sparrow_ipc +{ + void any_output_stream::write(std::span span) + { + m_impl->write(span); + } + + void any_output_stream::write(uint8_t value, std::size_t count) + { + m_impl->write(value, count); + } + + void any_output_stream::add_padding() + { + m_impl->add_padding(); + } + + void any_output_stream::reserve(std::size_t size) + { + m_impl->reserve(size); + } + + void any_output_stream::reserve(const std::function& calculate_reserve_size) + { + m_impl->reserve(calculate_reserve_size); + } + + size_t any_output_stream::size() const + { + return m_impl->size(); + } + +} // namespace sparrow_ipc diff --git a/src/chunk_memory_serializer.cpp b/src/chunk_memory_serializer.cpp new file mode 100644 index 0000000..cbdfb4a --- /dev/null +++ b/src/chunk_memory_serializer.cpp @@ -0,0 +1,29 @@ +#include "sparrow_ipc/chunk_memory_serializer.hpp" + +#include "sparrow_ipc/any_output_stream.hpp" +#include "sparrow_ipc/serialize.hpp" +#include "sparrow_ipc/serialize_utils.hpp" + +namespace sparrow_ipc +{ + chunk_serializer::chunk_serializer(chunked_memory_output_stream>>& stream) + : m_pstream(&stream) + { + } + + void chunk_serializer::write(const 
sparrow::record_batch& rb) + { + write(std::ranges::single_view(rb)); + } + + void chunk_serializer::end() + { + if (m_ended) + { + return; + } + std::vector buffer(end_of_stream.begin(), end_of_stream.end()); + m_pstream->write(std::move(buffer)); + m_ended = true; + } +} diff --git a/src/flatbuffer_utils.cpp b/src/flatbuffer_utils.cpp new file mode 100644 index 0000000..91d8306 --- /dev/null +++ b/src/flatbuffer_utils.cpp @@ -0,0 +1,586 @@ +#include "sparrow_ipc/flatbuffer_utils.hpp" + +#include "sparrow_ipc/serialize_utils.hpp" +#include "sparrow_ipc/utils.hpp" + +namespace sparrow_ipc +{ + std::pair> + get_flatbuffer_type(flatbuffers::FlatBufferBuilder& builder, std::string_view format_str) + { + const auto type = sparrow::format_to_data_type(format_str); + switch (type) + { + case sparrow::data_type::NA: + { + const auto null_type = org::apache::arrow::flatbuf::CreateNull(builder); + return {org::apache::arrow::flatbuf::Type::Null, null_type.Union()}; + } + case sparrow::data_type::BOOL: + { + const auto bool_type = org::apache::arrow::flatbuf::CreateBool(builder); + return {org::apache::arrow::flatbuf::Type::Bool, bool_type.Union()}; + } + case sparrow::data_type::UINT8: + { + const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 8, false); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::INT8: + { + const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 8, true); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::UINT16: + { + const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 16, false); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::INT16: + { + const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 16, true); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::UINT32: + { + const auto int_type 
= org::apache::arrow::flatbuf::CreateInt(builder, 32, false); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::INT32: + { + const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 32, true); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::UINT64: + { + const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 64, false); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::INT64: + { + const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 64, true); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + case sparrow::data_type::HALF_FLOAT: + { + const auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( + builder, + org::apache::arrow::flatbuf::Precision::HALF + ); + return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; + } + case sparrow::data_type::FLOAT: + { + const auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( + builder, + org::apache::arrow::flatbuf::Precision::SINGLE + ); + return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; + } + case sparrow::data_type::DOUBLE: + { + const auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( + builder, + org::apache::arrow::flatbuf::Precision::DOUBLE + ); + return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; + } + case sparrow::data_type::STRING: + { + const auto string_type = org::apache::arrow::flatbuf::CreateUtf8(builder); + return {org::apache::arrow::flatbuf::Type::Utf8, string_type.Union()}; + } + case sparrow::data_type::LARGE_STRING: + { + const auto large_string_type = org::apache::arrow::flatbuf::CreateLargeUtf8(builder); + return {org::apache::arrow::flatbuf::Type::LargeUtf8, large_string_type.Union()}; + } + case sparrow::data_type::BINARY: + { + const auto binary_type = 
org::apache::arrow::flatbuf::CreateBinary(builder); + return {org::apache::arrow::flatbuf::Type::Binary, binary_type.Union()}; + } + case sparrow::data_type::LARGE_BINARY: + { + const auto large_binary_type = org::apache::arrow::flatbuf::CreateLargeBinary(builder); + return {org::apache::arrow::flatbuf::Type::LargeBinary, large_binary_type.Union()}; + } + case sparrow::data_type::STRING_VIEW: + { + const auto string_view_type = org::apache::arrow::flatbuf::CreateUtf8View(builder); + return {org::apache::arrow::flatbuf::Type::Utf8View, string_view_type.Union()}; + } + case sparrow::data_type::BINARY_VIEW: + { + const auto binary_view_type = org::apache::arrow::flatbuf::CreateBinaryView(builder); + return {org::apache::arrow::flatbuf::Type::BinaryView, binary_view_type.Union()}; + } + case sparrow::data_type::DATE_DAYS: + { + const auto date_type = org::apache::arrow::flatbuf::CreateDate( + builder, + org::apache::arrow::flatbuf::DateUnit::DAY + ); + return {org::apache::arrow::flatbuf::Type::Date, date_type.Union()}; + } + case sparrow::data_type::DATE_MILLISECONDS: + { + const auto date_type = org::apache::arrow::flatbuf::CreateDate( + builder, + org::apache::arrow::flatbuf::DateUnit::MILLISECOND + ); + return {org::apache::arrow::flatbuf::Type::Date, date_type.Union()}; + } + case sparrow::data_type::TIMESTAMP_SECONDS: + { + const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( + builder, + org::apache::arrow::flatbuf::TimeUnit::SECOND + ); + return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; + } + case sparrow::data_type::TIMESTAMP_MILLISECONDS: + { + const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( + builder, + org::apache::arrow::flatbuf::TimeUnit::MILLISECOND + ); + return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; + } + case sparrow::data_type::TIMESTAMP_MICROSECONDS: + { + const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( + builder, + 
org::apache::arrow::flatbuf::TimeUnit::MICROSECOND + ); + return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; + } + case sparrow::data_type::TIMESTAMP_NANOSECONDS: + { + const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( + builder, + org::apache::arrow::flatbuf::TimeUnit::NANOSECOND + ); + return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; + } + case sparrow::data_type::DURATION_SECONDS: + { + const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( + builder, + org::apache::arrow::flatbuf::TimeUnit::SECOND + ); + return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; + } + case sparrow::data_type::DURATION_MILLISECONDS: + { + const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( + builder, + org::apache::arrow::flatbuf::TimeUnit::MILLISECOND + ); + return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; + } + case sparrow::data_type::DURATION_MICROSECONDS: + { + const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( + builder, + org::apache::arrow::flatbuf::TimeUnit::MICROSECOND + ); + return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; + } + case sparrow::data_type::DURATION_NANOSECONDS: + { + const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( + builder, + org::apache::arrow::flatbuf::TimeUnit::NANOSECOND + ); + return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; + } + case sparrow::data_type::INTERVAL_MONTHS: + { + const auto interval_type = org::apache::arrow::flatbuf::CreateInterval( + builder, + org::apache::arrow::flatbuf::IntervalUnit::YEAR_MONTH + ); + return {org::apache::arrow::flatbuf::Type::Interval, interval_type.Union()}; + } + case sparrow::data_type::INTERVAL_DAYS_TIME: + { + const auto interval_type = org::apache::arrow::flatbuf::CreateInterval( + builder, + org::apache::arrow::flatbuf::IntervalUnit::DAY_TIME 
+ ); + return {org::apache::arrow::flatbuf::Type::Interval, interval_type.Union()}; + } + case sparrow::data_type::INTERVAL_MONTHS_DAYS_NANOSECONDS: + { + const auto interval_type = org::apache::arrow::flatbuf::CreateInterval( + builder, + org::apache::arrow::flatbuf::IntervalUnit::MONTH_DAY_NANO + ); + return {org::apache::arrow::flatbuf::Type::Interval, interval_type.Union()}; + } + case sparrow::data_type::TIME_SECONDS: + { + const auto time_type = org::apache::arrow::flatbuf::CreateTime( + builder, + org::apache::arrow::flatbuf::TimeUnit::SECOND, + 32 + ); + return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; + } + case sparrow::data_type::TIME_MILLISECONDS: + { + const auto time_type = org::apache::arrow::flatbuf::CreateTime( + builder, + org::apache::arrow::flatbuf::TimeUnit::MILLISECOND, + 32 + ); + return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; + } + case sparrow::data_type::TIME_MICROSECONDS: + { + const auto time_type = org::apache::arrow::flatbuf::CreateTime( + builder, + org::apache::arrow::flatbuf::TimeUnit::MICROSECOND, + 64 + ); + return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; + } + case sparrow::data_type::TIME_NANOSECONDS: + { + const auto time_type = org::apache::arrow::flatbuf::CreateTime( + builder, + org::apache::arrow::flatbuf::TimeUnit::NANOSECOND, + 64 + ); + return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; + } + case sparrow::data_type::LIST: + { + const auto list_type = org::apache::arrow::flatbuf::CreateList(builder); + return {org::apache::arrow::flatbuf::Type::List, list_type.Union()}; + } + case sparrow::data_type::LARGE_LIST: + { + const auto large_list_type = org::apache::arrow::flatbuf::CreateLargeList(builder); + return {org::apache::arrow::flatbuf::Type::LargeList, large_list_type.Union()}; + } + case sparrow::data_type::LIST_VIEW: + { + const auto list_view_type = org::apache::arrow::flatbuf::CreateListView(builder); + return 
{org::apache::arrow::flatbuf::Type::ListView, list_view_type.Union()}; + } + case sparrow::data_type::LARGE_LIST_VIEW: + { + const auto large_list_view_type = org::apache::arrow::flatbuf::CreateLargeListView(builder); + return {org::apache::arrow::flatbuf::Type::LargeListView, large_list_view_type.Union()}; + } + case sparrow::data_type::FIXED_SIZED_LIST: + { + // FixedSizeList requires listSize. We need to parse the format_str. + // Format: "+w:size" + const auto list_size = utils::parse_format(format_str, ":"); + if (!list_size.has_value()) + { + throw std::runtime_error( + "Failed to parse FixedSizeList size from format string: " + std::string(format_str) + ); + } + + const auto fixed_size_list_type = org::apache::arrow::flatbuf::CreateFixedSizeList( + builder, + list_size.value() + ); + return {org::apache::arrow::flatbuf::Type::FixedSizeList, fixed_size_list_type.Union()}; + } + case sparrow::data_type::STRUCT: + { + const auto struct_type = org::apache::arrow::flatbuf::CreateStruct_(builder); + return {org::apache::arrow::flatbuf::Type::Struct_, struct_type.Union()}; + } + case sparrow::data_type::MAP: + { + // not sorted keys + const auto map_type = org::apache::arrow::flatbuf::CreateMap(builder, false); + return {org::apache::arrow::flatbuf::Type::Map, map_type.Union()}; + } + case sparrow::data_type::DENSE_UNION: + { + const auto union_type = org::apache::arrow::flatbuf::CreateUnion( + builder, + org::apache::arrow::flatbuf::UnionMode::Dense, + 0 + ); + return {org::apache::arrow::flatbuf::Type::Union, union_type.Union()}; + } + case sparrow::data_type::SPARSE_UNION: + { + const auto union_type = org::apache::arrow::flatbuf::CreateUnion( + builder, + org::apache::arrow::flatbuf::UnionMode::Sparse, + 0 + ); + return {org::apache::arrow::flatbuf::Type::Union, union_type.Union()}; + } + case sparrow::data_type::RUN_ENCODED: + { + const auto run_end_encoded_type = org::apache::arrow::flatbuf::CreateRunEndEncoded(builder); + return 
{org::apache::arrow::flatbuf::Type::RunEndEncoded, run_end_encoded_type.Union()}; + } + case sparrow::data_type::DECIMAL32: + { + return get_flatbuffer_decimal_type(builder, format_str, 32); + } + case sparrow::data_type::DECIMAL64: + { + return get_flatbuffer_decimal_type(builder, format_str, 64); + } + case sparrow::data_type::DECIMAL128: + { + return get_flatbuffer_decimal_type(builder, format_str, 128); + } + case sparrow::data_type::DECIMAL256: + { + return get_flatbuffer_decimal_type(builder, format_str, 256); + } + case sparrow::data_type::FIXED_WIDTH_BINARY: + { + // FixedSizeBinary requires byteWidth. We need to parse the format_str. + // Format: "w:size" + const auto byte_width = utils::parse_format(format_str, ":"); + if (!byte_width.has_value()) + { + throw std::runtime_error( + "Failed to parse FixedWidthBinary size from format string: " + std::string(format_str) + ); + } + + const auto fixed_width_binary_type = org::apache::arrow::flatbuf::CreateFixedSizeBinary( + builder, + byte_width.value() + ); + return {org::apache::arrow::flatbuf::Type::FixedSizeBinary, fixed_width_binary_type.Union()}; + } + default: + { + throw std::runtime_error("Unsupported data type for serialization"); + } + } + } + + // Creates a Flatbuffers Decimal type from a format string + // The format string is expected to be in the format "d:precision,scale" + std::pair> get_flatbuffer_decimal_type( + flatbuffers::FlatBufferBuilder& builder, + std::string_view format_str, + const int32_t bitWidth + ) + { + // Decimal requires precision and scale. We need to parse the format_str. 
+ // Format: "d:precision,scale" + const auto scale = utils::parse_format(format_str, ","); + if (!scale.has_value()) + { + throw std::runtime_error( + "Failed to parse Decimal " + std::to_string(bitWidth) + + " scale from format string: " + std::string(format_str) + ); + } + const size_t comma_pos = format_str.find(','); + const auto precision = utils::parse_format(format_str.substr(0, comma_pos), ":"); + if (!precision.has_value()) + { + throw std::runtime_error( + "Failed to parse Decimal " + std::to_string(bitWidth) + + " precision from format string: " + std::string(format_str) + ); + } + const auto decimal_type = org::apache::arrow::flatbuf::CreateDecimal( + builder, + precision.value(), + scale.value(), + bitWidth + ); + return {org::apache::arrow::flatbuf::Type::Decimal, decimal_type.Union()}; + } + + flatbuffers::Offset>> + create_metadata(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema) + { + if (arrow_schema.metadata == nullptr) + { + return 0; + } + + const auto metadata_view = sparrow::key_value_view(arrow_schema.metadata); + std::vector> kv_offsets; + kv_offsets.reserve(metadata_view.size()); + for (const auto& [key, value] : metadata_view) + { + const auto key_offset = builder.CreateString(std::string(key)); + const auto value_offset = builder.CreateString(std::string(value)); + kv_offsets.push_back(org::apache::arrow::flatbuf::CreateKeyValue(builder, key_offset, value_offset)); + } + return builder.CreateVector(kv_offsets); + } + + ::flatbuffers::Offset + create_field(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema) + { + flatbuffers::Offset fb_name_offset = (arrow_schema.name == nullptr) + ? 
0 + : builder.CreateString(arrow_schema.name); + const auto [type_enum, type_offset] = get_flatbuffer_type(builder, arrow_schema.format); + auto fb_metadata_offset = create_metadata(builder, arrow_schema); + const auto children = create_children(builder, arrow_schema); + const auto fb_field = org::apache::arrow::flatbuf::CreateField( + builder, + fb_name_offset, + (arrow_schema.flags & static_cast(sparrow::ArrowFlag::NULLABLE)) != 0, + type_enum, + type_offset, + 0, // TODO: support dictionary + children, + fb_metadata_offset + ); + return fb_field; + } + + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> + create_children(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema) + { + std::vector> children_vec; + children_vec.reserve(arrow_schema.n_children); + for (size_t i = 0; i < arrow_schema.n_children; ++i) + { + if (arrow_schema.children[i] == nullptr) + { + throw std::invalid_argument("ArrowSchema has null child pointer"); + } + const auto& child = *arrow_schema.children[i]; + flatbuffers::Offset field = create_field(builder, child); + children_vec.emplace_back(field); + } + return children_vec.empty() ? 0 : builder.CreateVector(children_vec); + } + + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> + create_children(flatbuffers::FlatBufferBuilder& builder, sparrow::record_batch::column_range columns) + { + std::vector> children_vec; + children_vec.reserve(columns.size()); + for (const auto& column : columns) + { + const auto& arrow_schema = sparrow::detail::array_access::get_arrow_proxy(column).schema(); + flatbuffers::Offset field = create_field(builder, arrow_schema); + children_vec.emplace_back(field); + } + return children_vec.empty() ? 
0 : builder.CreateVector(children_vec); + } + + flatbuffers::FlatBufferBuilder get_schema_message_builder(const sparrow::record_batch& record_batch) + { + flatbuffers::FlatBufferBuilder schema_builder; + const auto fields_vec = create_children(schema_builder, record_batch.columns()); + const auto schema_offset = org::apache::arrow::flatbuf::CreateSchema( + schema_builder, + org::apache::arrow::flatbuf::Endianness::Little, // TODO: make configurable + fields_vec + ); + const auto schema_message_offset = org::apache::arrow::flatbuf::CreateMessage( + schema_builder, + org::apache::arrow::flatbuf::MetadataVersion::V5, + org::apache::arrow::flatbuf::MessageHeader::Schema, + schema_offset.Union(), + 0, // body length is 0 for schema messages + 0 // custom metadata + ); + schema_builder.Finish(schema_message_offset); + return schema_builder; + } + + void fill_fieldnodes( + const sparrow::arrow_proxy& arrow_proxy, + std::vector& nodes + ) + { + nodes.emplace_back(arrow_proxy.length(), arrow_proxy.null_count()); + nodes.reserve(nodes.size() + arrow_proxy.n_children()); + for (const auto& child : arrow_proxy.children()) + { + fill_fieldnodes(child, nodes); + } + } + + std::vector + create_fieldnodes(const sparrow::record_batch& record_batch) + { + std::vector nodes; + nodes.reserve(record_batch.columns().size()); + for (const auto& column : record_batch.columns()) + { + fill_fieldnodes(sparrow::detail::array_access::get_arrow_proxy(column), nodes); + } + return nodes; + } + + void fill_buffers( + const sparrow::arrow_proxy& arrow_proxy, + std::vector& flatbuf_buffers, + int64_t& offset + ) + { + const auto& buffers = arrow_proxy.buffers(); + for (const auto& buffer : buffers) + { + int64_t size = static_cast(buffer.size()); + flatbuf_buffers.emplace_back(offset, size); + offset += utils::align_to_8(size); + } + for (const auto& child : arrow_proxy.children()) + { + fill_buffers(child, flatbuf_buffers, offset); + } + } + + std::vector get_buffers(const sparrow::record_batch& 
record_batch) + { + std::vector buffers; + std::int64_t offset = 0; + for (const auto& column : record_batch.columns()) + { + const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column); + fill_buffers(arrow_proxy, buffers, offset); + } + return buffers; + } + + flatbuffers::FlatBufferBuilder get_record_batch_message_builder(const sparrow::record_batch& record_batch) + { + const std::vector nodes = create_fieldnodes(record_batch); + const std::vector buffers = get_buffers(record_batch); + flatbuffers::FlatBufferBuilder record_batch_builder; + auto nodes_offset = record_batch_builder.CreateVectorOfStructs(nodes); + auto buffers_offset = record_batch_builder.CreateVectorOfStructs(buffers); + const auto record_batch_offset = org::apache::arrow::flatbuf::CreateRecordBatch( + record_batch_builder, + static_cast(record_batch.nb_rows()), + nodes_offset, + buffers_offset, + 0, // TODO: Compression + 0 // TODO :variadic buffer Counts + ); + + const int64_t body_size = calculate_body_size(record_batch); + const auto record_batch_message_offset = org::apache::arrow::flatbuf::CreateMessage( + record_batch_builder, + org::apache::arrow::flatbuf::MetadataVersion::V5, + org::apache::arrow::flatbuf::MessageHeader::RecordBatch, + record_batch_offset.Union(), + body_size, // body length + 0 // custom metadata + ); + record_batch_builder.Finish(record_batch_message_offset); + return record_batch_builder; + } +} diff --git a/src/serialize.cpp b/src/serialize.cpp new file mode 100644 index 0000000..a4e797d --- /dev/null +++ b/src/serialize.cpp @@ -0,0 +1,31 @@ +#include "sparrow_ipc/serialize.hpp" + +#include "sparrow_ipc/flatbuffer_utils.hpp" + +namespace sparrow_ipc +{ + void common_serialize( + const sparrow::record_batch& record_batch, + const flatbuffers::FlatBufferBuilder& builder, + any_output_stream& stream + ) + { + stream.write(continuation); + const flatbuffers::uoffset_t size = builder.GetSize(); + const std::span size_span(reinterpret_cast(&size), 
sizeof(uint32_t)); + stream.write(size_span); + stream.write(std::span(builder.GetBufferPointer(), size)); + stream.add_padding(); + } + + void serialize_schema_message(const sparrow::record_batch& record_batch, any_output_stream& stream) + { + common_serialize(record_batch, get_schema_message_builder(record_batch), stream); + } + + void serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream) + { + common_serialize(record_batch, get_record_batch_message_builder(record_batch), stream); + generate_body(record_batch, stream); + } +} \ No newline at end of file diff --git a/src/serialize_utils.cpp b/src/serialize_utils.cpp index ac1e026..8545927 100644 --- a/src/serialize_utils.cpp +++ b/src/serialize_utils.cpp @@ -1,208 +1,30 @@ -#include - +#include "sparrow_ipc/flatbuffer_utils.hpp" #include "sparrow_ipc/magic_values.hpp" #include "sparrow_ipc/serialize.hpp" #include "sparrow_ipc/utils.hpp" namespace sparrow_ipc { - - flatbuffers::Offset>> - create_metadata(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema) - { - if (arrow_schema.metadata == nullptr) - { - return 0; - } - - const auto metadata_view = sparrow::key_value_view(arrow_schema.metadata); - std::vector> kv_offsets; - kv_offsets.reserve(metadata_view.size()); - for (const auto& [key, value] : metadata_view) - { - const auto key_offset = builder.CreateString(std::string(key)); - const auto value_offset = builder.CreateString(std::string(value)); - kv_offsets.push_back(org::apache::arrow::flatbuf::CreateKeyValue(builder, key_offset, value_offset)); - } - return builder.CreateVector(kv_offsets); - } - - ::flatbuffers::Offset - create_field(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema) - { - flatbuffers::Offset fb_name_offset = (arrow_schema.name == nullptr) - ? 
0 - : builder.CreateString(arrow_schema.name); - const auto [type_enum, type_offset] = utils::get_flatbuffer_type(builder, arrow_schema.format); - auto fb_metadata_offset = create_metadata(builder, arrow_schema); - const auto children = create_children(builder, arrow_schema); - const auto fb_field = org::apache::arrow::flatbuf::CreateField( - builder, - fb_name_offset, - (arrow_schema.flags & static_cast(sparrow::ArrowFlag::NULLABLE)) != 0, - type_enum, - type_offset, - 0, // TODO: support dictionary - children, - fb_metadata_offset - ); - return fb_field; - } - - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> - create_children(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema) - { - std::vector> children_vec; - children_vec.reserve(arrow_schema.n_children); - for (size_t i = 0; i < arrow_schema.n_children; ++i) - { - if (arrow_schema.children[i] == nullptr) - { - throw std::invalid_argument("ArrowSchema has null child pointer"); - } - const auto& child = *arrow_schema.children[i]; - flatbuffers::Offset field = create_field(builder, child); - children_vec.emplace_back(field); - } - return children_vec.empty() ? 0 : builder.CreateVector(children_vec); - } - - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> - create_children(flatbuffers::FlatBufferBuilder& builder, sparrow::record_batch::column_range columns) - { - std::vector> children_vec; - children_vec.reserve(columns.size()); - for (const auto& column : columns) - { - const auto& arrow_schema = sparrow::detail::array_access::get_arrow_proxy(column).schema(); - flatbuffers::Offset field = create_field(builder, arrow_schema); - children_vec.emplace_back(field); - } - return children_vec.empty() ? 
0 : builder.CreateVector(children_vec); - } - - flatbuffers::FlatBufferBuilder get_schema_message_builder(const sparrow::record_batch& record_batch) - { - flatbuffers::FlatBufferBuilder schema_builder; - const auto fields_vec = create_children(schema_builder, record_batch.columns()); - const auto schema_offset = org::apache::arrow::flatbuf::CreateSchema( - schema_builder, - org::apache::arrow::flatbuf::Endianness::Little, // TODO: make configurable - fields_vec - ); - const auto schema_message_offset = org::apache::arrow::flatbuf::CreateMessage( - schema_builder, - org::apache::arrow::flatbuf::MetadataVersion::V5, - org::apache::arrow::flatbuf::MessageHeader::Schema, - schema_offset.Union(), - 0, // body length is 0 for schema messages - 0 // custom metadata - ); - schema_builder.Finish(schema_message_offset); - return schema_builder; - } - - std::vector serialize_schema_message(const sparrow::record_batch& record_batch) - { - std::vector schema_buffer; - schema_buffer.insert(schema_buffer.end(), continuation.begin(), continuation.end()); - flatbuffers::FlatBufferBuilder schema_builder = get_schema_message_builder(record_batch); - const flatbuffers::uoffset_t schema_len = schema_builder.GetSize(); - schema_buffer.reserve(schema_buffer.size() + sizeof(uint32_t) + schema_len); - // Write the 4-byte length prefix after the continuation bytes - schema_buffer.insert( - schema_buffer.end(), - reinterpret_cast(&schema_len), - reinterpret_cast(&schema_len) + sizeof(uint32_t) - ); - // Append the actual message bytes - schema_buffer.insert( - schema_buffer.end(), - schema_builder.GetBufferPointer(), - schema_builder.GetBufferPointer() + schema_len - ); - add_padding(schema_buffer); - return schema_buffer; - } - - void fill_fieldnodes( - const sparrow::arrow_proxy& arrow_proxy, - std::vector& nodes - ) - { - nodes.emplace_back(arrow_proxy.length(), arrow_proxy.null_count()); - nodes.reserve(nodes.size() + arrow_proxy.n_children()); - for (const auto& child : 
arrow_proxy.children()) - { - fill_fieldnodes(child, nodes); - } - } - - std::vector - create_fieldnodes(const sparrow::record_batch& record_batch) - { - std::vector nodes; - nodes.reserve(record_batch.columns().size()); - for (const auto& column : record_batch.columns()) - { - fill_fieldnodes(sparrow::detail::array_access::get_arrow_proxy(column), nodes); - } - return nodes; - } - - void fill_buffers( - const sparrow::arrow_proxy& arrow_proxy, - std::vector& flatbuf_buffers, - int64_t& offset - ) - { - const auto& buffers = arrow_proxy.buffers(); - for (const auto& buffer : buffers) - { - int64_t size = static_cast(buffer.size()); - flatbuf_buffers.emplace_back(offset, size); - offset += utils::align_to_8(size); - } - for (const auto& child : arrow_proxy.children()) - { - fill_buffers(child, flatbuf_buffers, offset); - } - } - - std::vector get_buffers(const sparrow::record_batch& record_batch) - { - std::vector buffers; - std::int64_t offset = 0; - for (const auto& column : record_batch.columns()) - { - const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column); - fill_buffers(arrow_proxy, buffers, offset); - } - return buffers; - } - - void fill_body(const sparrow::arrow_proxy& arrow_proxy, std::vector& body) + void fill_body(const sparrow::arrow_proxy& arrow_proxy, any_output_stream& stream) { for (const auto& buffer : arrow_proxy.buffers()) { - body.insert(body.end(), buffer.begin(), buffer.end()); - add_padding(body); + stream.write(buffer); + stream.add_padding(); } for (const auto& child : arrow_proxy.children()) { - fill_body(child, body); + fill_body(child, stream); } } - std::vector generate_body(const sparrow::record_batch& record_batch) + void generate_body(const sparrow::record_batch& record_batch, any_output_stream& stream) { - std::vector body; for (const auto& column : record_batch.columns()) { const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column); - fill_body(arrow_proxy, body); + 
fill_body(arrow_proxy, stream); } - return body; } int64_t calculate_body_size(const sparrow::arrow_proxy& arrow_proxy) @@ -210,7 +32,7 @@ namespace sparrow_ipc int64_t total_size = 0; for (const auto& buffer : arrow_proxy.buffers()) { - total_size += utils::align_to_8(static_cast(buffer.size())); + total_size += utils::align_to_8(buffer.size()); } for (const auto& child : arrow_proxy.children()) { @@ -224,7 +46,7 @@ namespace sparrow_ipc return std::accumulate( record_batch.columns().begin(), record_batch.columns().end(), - 0, + int64_t{0}, [](int64_t acc, const sparrow::array& arr) { const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(arr); @@ -233,73 +55,54 @@ namespace sparrow_ipc ); } - flatbuffers::FlatBufferBuilder get_record_batch_message_builder( - const sparrow::record_batch& record_batch, - const std::vector& nodes, - const std::vector& buffers - ) + std::size_t calculate_schema_message_size(const sparrow::record_batch& record_batch) { - flatbuffers::FlatBufferBuilder record_batch_builder; - - auto nodes_offset = record_batch_builder.CreateVectorOfStructs(nodes); - auto buffers_offset = record_batch_builder.CreateVectorOfStructs(buffers); - const auto record_batch_offset = org::apache::arrow::flatbuf::CreateRecordBatch( - record_batch_builder, - static_cast(record_batch.nb_rows()), - nodes_offset, - buffers_offset, - 0, // TODO: Compression - 0 // TODO :variadic buffer Counts - ); + // Build the schema message to get its exact size + flatbuffers::FlatBufferBuilder schema_builder = get_schema_message_builder(record_batch); + const flatbuffers::uoffset_t schema_len = schema_builder.GetSize(); - const int64_t body_size = calculate_body_size(record_batch); - const auto record_batch_message_offset = org::apache::arrow::flatbuf::CreateMessage( - record_batch_builder, - org::apache::arrow::flatbuf::MetadataVersion::V5, - org::apache::arrow::flatbuf::MessageHeader::RecordBatch, - record_batch_offset.Union(), - body_size, // body length - 0 
// custom metadata - ); - record_batch_builder.Finish(record_batch_message_offset); - return record_batch_builder; + // Calculate total size: + // - Continuation bytes (4) + // - Message length prefix (4) + // - FlatBuffer schema message data + // - Padding to 8-byte alignment + std::size_t total_size = continuation.size() + sizeof(uint32_t) + schema_len; + return utils::align_to_8(total_size); } - std::vector serialize_record_batch(const sparrow::record_batch& record_batch) + std::size_t calculate_record_batch_message_size(const sparrow::record_batch& record_batch) { - std::vector nodes = create_fieldnodes(record_batch); - std::vector flatbuf_buffers = get_buffers(record_batch); - flatbuffers::FlatBufferBuilder record_batch_builder = get_record_batch_message_builder( - record_batch, - nodes, - flatbuf_buffers - ); - std::vector output; - output.insert(output.end(), continuation.begin(), continuation.end()); + // Build the record batch message to get its exact metadata size + flatbuffers::FlatBufferBuilder record_batch_builder = get_record_batch_message_builder(record_batch); const flatbuffers::uoffset_t record_batch_len = record_batch_builder.GetSize(); - output.insert( - output.end(), - reinterpret_cast(&record_batch_len), - reinterpret_cast(&record_batch_len) + sizeof(record_batch_len) - ); - output.insert( - output.end(), - record_batch_builder.GetBufferPointer(), - record_batch_builder.GetBufferPointer() + record_batch_len - ); - add_padding(output); - std::vector body = generate_body(record_batch); - output.insert(output.end(), std::make_move_iterator(body.begin()), std::make_move_iterator(body.end())); - return output; + + // Calculate body size (already includes 8-byte alignment for each buffer) + const int64_t body_size = calculate_body_size(record_batch); + + // Calculate total size: + // - Continuation bytes (4) + // - Message length prefix (4) + // - FlatBuffer record batch metadata + // - Padding after metadata to 8-byte alignment + // - Body data 
(already aligned) + std::size_t metadata_size = continuation.size() + sizeof(uint32_t) + record_batch_len; + metadata_size = utils::align_to_8(metadata_size); + + return metadata_size + static_cast(body_size); } - void add_padding(std::vector& buffer) + std::vector get_column_dtypes(const sparrow::record_batch& rb) { - buffer.insert( - buffer.end(), - utils::align_to_8(static_cast(buffer.size())) - static_cast(buffer.size()), - 0 + std::vector dtypes; + dtypes.reserve(rb.nb_columns()); + std::ranges::transform( + rb.columns(), + std::back_inserter(dtypes), + [](const auto& col) + { + return col.data_type(); + } ); + return dtypes; } - -} \ No newline at end of file +} diff --git a/src/serializer.cpp b/src/serializer.cpp new file mode 100644 index 0000000..60b4c85 --- /dev/null +++ b/src/serializer.cpp @@ -0,0 +1,42 @@ +#include "sparrow_ipc/serializer.hpp" + +#include +#include +#include "sparrow_ipc/magic_values.hpp" + +namespace sparrow_ipc +{ + serializer::~serializer() + { + if (!m_ended) + { + end(); + } + } + + void serializer::write(const sparrow::record_batch& rb) + { + write(std::ranges::single_view(rb)); + } + + std::vector serializer::get_column_dtypes(const sparrow::record_batch& rb) + { + std::vector dtypes; + dtypes.reserve(rb.nb_columns()); + for (const auto& col : rb.columns()) + { + dtypes.push_back(col.data_type()); + } + return dtypes; + } + + void serializer::end() + { + if (m_ended) + { + return; + } + m_stream.write(end_of_stream); + m_ended = true; + } +} \ No newline at end of file diff --git a/src/utils.cpp b/src/utils.cpp index 3d7b5e7..2fc2490 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -1,448 +1,34 @@ #include "sparrow_ipc/utils.hpp" -#include -#include -#include - -#include "sparrow.hpp" - -namespace sparrow_ipc +namespace sparrow_ipc::utils { - namespace + std::optional parse_format(std::string_view format_str, std::string_view sep) { - // Parse the format string - // The format string is expected to be "w:size", "+w:size", 
"d:precision,scale", etc - std::optional parse_format(std::string_view format_str, std::string_view sep) + // Find the position of the delimiter + const auto sep_pos = format_str.find(sep); + if (sep_pos == std::string_view::npos) { - // Find the position of the delimiter - const auto sep_pos = format_str.find(sep); - if (sep_pos == std::string_view::npos) - { - return std::nullopt; - } - - std::string_view substr_str(format_str.data() + sep_pos + 1, format_str.size() - sep_pos - 1); + return std::nullopt; + } - int32_t substr_size = 0; - const auto [ptr, ec] = std::from_chars( - substr_str.data(), - substr_str.data() + substr_str.size(), - substr_size - ); + std::string_view substr_str(format_str.data() + sep_pos + 1, format_str.size() - sep_pos - 1); - if (ec != std::errc() || ptr != substr_str.data() + substr_str.size()) - { - return std::nullopt; - } - return substr_size; - } + int32_t substr_size = 0; + const auto [ptr, ec] = std::from_chars( + substr_str.data(), + substr_str.data() + substr_str.size(), + substr_size + ); - // Creates a Flatbuffers Decimal type from a format string - // The format string is expected to be in the format "d:precision,scale" - std::pair> get_flatbuffer_decimal_type( - flatbuffers::FlatBufferBuilder& builder, - std::string_view format_str, - const int32_t bitWidth - ) + if (ec != std::errc() || ptr != substr_str.data() + substr_str.size()) { - // Decimal requires precision and scale. We need to parse the format_str. 
- // Format: "d:precision,scale" - const auto scale = parse_format(format_str, ","); - if (!scale.has_value()) - { - throw std::runtime_error( - "Failed to parse Decimal " + std::to_string(bitWidth) - + " scale from format string: " + std::string(format_str) - ); - } - const size_t comma_pos = format_str.find(','); - const auto precision = parse_format(format_str.substr(0, comma_pos), ":"); - if (!precision.has_value()) - { - throw std::runtime_error( - "Failed to parse Decimal " + std::to_string(bitWidth) - + " precision from format string: " + std::string(format_str) - ); - } - const auto decimal_type = org::apache::arrow::flatbuf::CreateDecimal( - builder, - precision.value(), - scale.value(), - bitWidth - ); - return {org::apache::arrow::flatbuf::Type::Decimal, decimal_type.Union()}; + return std::nullopt; } + return substr_size; } - namespace utils + size_t align_to_8(const size_t n) { - int64_t align_to_8(const int64_t n) - { - return (n + 7) & -8; - } - - std::pair> - get_flatbuffer_type(flatbuffers::FlatBufferBuilder& builder, std::string_view format_str) - { - const auto type = sparrow::format_to_data_type(format_str); - switch (type) - { - case sparrow::data_type::NA: - { - const auto null_type = org::apache::arrow::flatbuf::CreateNull(builder); - return {org::apache::arrow::flatbuf::Type::Null, null_type.Union()}; - } - case sparrow::data_type::BOOL: - { - const auto bool_type = org::apache::arrow::flatbuf::CreateBool(builder); - return {org::apache::arrow::flatbuf::Type::Bool, bool_type.Union()}; - } - case sparrow::data_type::UINT8: - { - const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 8, false); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::INT8: - { - const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 8, true); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::UINT16: - { - const auto int_type = 
org::apache::arrow::flatbuf::CreateInt(builder, 16, false); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::INT16: - { - const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 16, true); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::UINT32: - { - const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 32, false); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::INT32: - { - const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 32, true); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::UINT64: - { - const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 64, false); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::INT64: - { - const auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 64, true); - return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; - } - case sparrow::data_type::HALF_FLOAT: - { - const auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( - builder, - org::apache::arrow::flatbuf::Precision::HALF - ); - return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; - } - case sparrow::data_type::FLOAT: - { - const auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( - builder, - org::apache::arrow::flatbuf::Precision::SINGLE - ); - return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; - } - case sparrow::data_type::DOUBLE: - { - const auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( - builder, - org::apache::arrow::flatbuf::Precision::DOUBLE - ); - return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; - } - case sparrow::data_type::STRING: - { - const auto string_type = 
org::apache::arrow::flatbuf::CreateUtf8(builder); - return {org::apache::arrow::flatbuf::Type::Utf8, string_type.Union()}; - } - case sparrow::data_type::LARGE_STRING: - { - const auto large_string_type = org::apache::arrow::flatbuf::CreateLargeUtf8(builder); - return {org::apache::arrow::flatbuf::Type::LargeUtf8, large_string_type.Union()}; - } - case sparrow::data_type::BINARY: - { - const auto binary_type = org::apache::arrow::flatbuf::CreateBinary(builder); - return {org::apache::arrow::flatbuf::Type::Binary, binary_type.Union()}; - } - case sparrow::data_type::LARGE_BINARY: - { - const auto large_binary_type = org::apache::arrow::flatbuf::CreateLargeBinary(builder); - return {org::apache::arrow::flatbuf::Type::LargeBinary, large_binary_type.Union()}; - } - case sparrow::data_type::STRING_VIEW: - { - const auto string_view_type = org::apache::arrow::flatbuf::CreateUtf8View(builder); - return {org::apache::arrow::flatbuf::Type::Utf8View, string_view_type.Union()}; - } - case sparrow::data_type::BINARY_VIEW: - { - const auto binary_view_type = org::apache::arrow::flatbuf::CreateBinaryView(builder); - return {org::apache::arrow::flatbuf::Type::BinaryView, binary_view_type.Union()}; - } - case sparrow::data_type::DATE_DAYS: - { - const auto date_type = org::apache::arrow::flatbuf::CreateDate( - builder, - org::apache::arrow::flatbuf::DateUnit::DAY - ); - return {org::apache::arrow::flatbuf::Type::Date, date_type.Union()}; - } - case sparrow::data_type::DATE_MILLISECONDS: - { - const auto date_type = org::apache::arrow::flatbuf::CreateDate( - builder, - org::apache::arrow::flatbuf::DateUnit::MILLISECOND - ); - return {org::apache::arrow::flatbuf::Type::Date, date_type.Union()}; - } - case sparrow::data_type::TIMESTAMP_SECONDS: - { - const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( - builder, - org::apache::arrow::flatbuf::TimeUnit::SECOND - ); - return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; - } - case 
sparrow::data_type::TIMESTAMP_MILLISECONDS: - { - const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( - builder, - org::apache::arrow::flatbuf::TimeUnit::MILLISECOND - ); - return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; - } - case sparrow::data_type::TIMESTAMP_MICROSECONDS: - { - const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( - builder, - org::apache::arrow::flatbuf::TimeUnit::MICROSECOND - ); - return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; - } - case sparrow::data_type::TIMESTAMP_NANOSECONDS: - { - const auto timestamp_type = org::apache::arrow::flatbuf::CreateTimestamp( - builder, - org::apache::arrow::flatbuf::TimeUnit::NANOSECOND - ); - return {org::apache::arrow::flatbuf::Type::Timestamp, timestamp_type.Union()}; - } - case sparrow::data_type::DURATION_SECONDS: - { - const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( - builder, - org::apache::arrow::flatbuf::TimeUnit::SECOND - ); - return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; - } - case sparrow::data_type::DURATION_MILLISECONDS: - { - const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( - builder, - org::apache::arrow::flatbuf::TimeUnit::MILLISECOND - ); - return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; - } - case sparrow::data_type::DURATION_MICROSECONDS: - { - const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( - builder, - org::apache::arrow::flatbuf::TimeUnit::MICROSECOND - ); - return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; - } - case sparrow::data_type::DURATION_NANOSECONDS: - { - const auto duration_type = org::apache::arrow::flatbuf::CreateDuration( - builder, - org::apache::arrow::flatbuf::TimeUnit::NANOSECOND - ); - return {org::apache::arrow::flatbuf::Type::Duration, duration_type.Union()}; - } - case sparrow::data_type::INTERVAL_MONTHS: - 
{ - const auto interval_type = org::apache::arrow::flatbuf::CreateInterval( - builder, - org::apache::arrow::flatbuf::IntervalUnit::YEAR_MONTH - ); - return {org::apache::arrow::flatbuf::Type::Interval, interval_type.Union()}; - } - case sparrow::data_type::INTERVAL_DAYS_TIME: - { - const auto interval_type = org::apache::arrow::flatbuf::CreateInterval( - builder, - org::apache::arrow::flatbuf::IntervalUnit::DAY_TIME - ); - return {org::apache::arrow::flatbuf::Type::Interval, interval_type.Union()}; - } - case sparrow::data_type::INTERVAL_MONTHS_DAYS_NANOSECONDS: - { - const auto interval_type = org::apache::arrow::flatbuf::CreateInterval( - builder, - org::apache::arrow::flatbuf::IntervalUnit::MONTH_DAY_NANO - ); - return {org::apache::arrow::flatbuf::Type::Interval, interval_type.Union()}; - } - case sparrow::data_type::TIME_SECONDS: - { - const auto time_type = org::apache::arrow::flatbuf::CreateTime( - builder, - org::apache::arrow::flatbuf::TimeUnit::SECOND, - 32 - ); - return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; - } - case sparrow::data_type::TIME_MILLISECONDS: - { - const auto time_type = org::apache::arrow::flatbuf::CreateTime( - builder, - org::apache::arrow::flatbuf::TimeUnit::MILLISECOND, - 32 - ); - return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; - } - case sparrow::data_type::TIME_MICROSECONDS: - { - const auto time_type = org::apache::arrow::flatbuf::CreateTime( - builder, - org::apache::arrow::flatbuf::TimeUnit::MICROSECOND, - 64 - ); - return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; - } - case sparrow::data_type::TIME_NANOSECONDS: - { - const auto time_type = org::apache::arrow::flatbuf::CreateTime( - builder, - org::apache::arrow::flatbuf::TimeUnit::NANOSECOND, - 64 - ); - return {org::apache::arrow::flatbuf::Type::Time, time_type.Union()}; - } - case sparrow::data_type::LIST: - { - const auto list_type = org::apache::arrow::flatbuf::CreateList(builder); - return 
{org::apache::arrow::flatbuf::Type::List, list_type.Union()}; - } - case sparrow::data_type::LARGE_LIST: - { - const auto large_list_type = org::apache::arrow::flatbuf::CreateLargeList(builder); - return {org::apache::arrow::flatbuf::Type::LargeList, large_list_type.Union()}; - } - case sparrow::data_type::LIST_VIEW: - { - const auto list_view_type = org::apache::arrow::flatbuf::CreateListView(builder); - return {org::apache::arrow::flatbuf::Type::ListView, list_view_type.Union()}; - } - case sparrow::data_type::LARGE_LIST_VIEW: - { - const auto large_list_view_type = org::apache::arrow::flatbuf::CreateLargeListView(builder); - return {org::apache::arrow::flatbuf::Type::LargeListView, large_list_view_type.Union()}; - } - case sparrow::data_type::FIXED_SIZED_LIST: - { - // FixedSizeList requires listSize. We need to parse the format_str. - // Format: "+w:size" - const auto list_size = parse_format(format_str, ":"); - if (!list_size.has_value()) - { - throw std::runtime_error( - "Failed to parse FixedSizeList size from format string: " + std::string(format_str) - ); - } - - const auto fixed_size_list_type = org::apache::arrow::flatbuf::CreateFixedSizeList( - builder, - list_size.value() - ); - return {org::apache::arrow::flatbuf::Type::FixedSizeList, fixed_size_list_type.Union()}; - } - case sparrow::data_type::STRUCT: - { - const auto struct_type = org::apache::arrow::flatbuf::CreateStruct_(builder); - return {org::apache::arrow::flatbuf::Type::Struct_, struct_type.Union()}; - } - case sparrow::data_type::MAP: - { - // not sorted keys - const auto map_type = org::apache::arrow::flatbuf::CreateMap(builder, false); - return {org::apache::arrow::flatbuf::Type::Map, map_type.Union()}; - } - case sparrow::data_type::DENSE_UNION: - { - const auto union_type = org::apache::arrow::flatbuf::CreateUnion( - builder, - org::apache::arrow::flatbuf::UnionMode::Dense, - 0 - ); - return {org::apache::arrow::flatbuf::Type::Union, union_type.Union()}; - } - case 
sparrow::data_type::SPARSE_UNION: - { - const auto union_type = org::apache::arrow::flatbuf::CreateUnion( - builder, - org::apache::arrow::flatbuf::UnionMode::Sparse, - 0 - ); - return {org::apache::arrow::flatbuf::Type::Union, union_type.Union()}; - } - case sparrow::data_type::RUN_ENCODED: - { - const auto run_end_encoded_type = org::apache::arrow::flatbuf::CreateRunEndEncoded(builder); - return {org::apache::arrow::flatbuf::Type::RunEndEncoded, run_end_encoded_type.Union()}; - } - case sparrow::data_type::DECIMAL32: - { - return get_flatbuffer_decimal_type(builder, format_str, 32); - } - case sparrow::data_type::DECIMAL64: - { - return get_flatbuffer_decimal_type(builder, format_str, 64); - } - case sparrow::data_type::DECIMAL128: - { - return get_flatbuffer_decimal_type(builder, format_str, 128); - } - case sparrow::data_type::DECIMAL256: - { - return get_flatbuffer_decimal_type(builder, format_str, 256); - } - case sparrow::data_type::FIXED_WIDTH_BINARY: - { - // FixedSizeBinary requires byteWidth. We need to parse the format_str. 
- // Format: "w:size" - const auto byte_width = parse_format(format_str, ":"); - if (!byte_width.has_value()) - { - throw std::runtime_error( - "Failed to parse FixedWidthBinary size from format string: " - + std::string(format_str) - ); - } - - const auto fixed_width_binary_type = org::apache::arrow::flatbuf::CreateFixedSizeBinary( - builder, - byte_width.value() - ); - return {org::apache::arrow::flatbuf::Type::FixedSizeBinary, fixed_width_binary_type.Union()}; - } - default: - { - throw std::runtime_error("Unsupported data type for serialization"); - } - } - } + return (n + 7) & -8; } } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 4e49c5d..11c2f9f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -5,10 +5,16 @@ set(test_target "test_sparrow_ipc_lib") set(SPARROW_IPC_TESTS_SRC include/sparrow_ipc_tests_helpers.hpp main.cpp + test_any_output_stream.cpp test_arrow_array.cpp test_arrow_schema.cpp + test_chunk_memory_output_stream.cpp + test_chunk_memory_serializer.cpp test_de_serialization_with_files.cpp + $<$>:test_flatbuffer_utils.cpp> + test_memory_output_streams.cpp test_serialize_utils.cpp + test_serializer.cpp test_utils.cpp ) diff --git a/tests/include/sparrow_ipc_tests_helpers.hpp b/tests/include/sparrow_ipc_tests_helpers.hpp index ad6db6e..79cc84b 100644 --- a/tests/include/sparrow_ipc_tests_helpers.hpp +++ b/tests/include/sparrow_ipc_tests_helpers.hpp @@ -1,7 +1,9 @@ #pragma once -#include "doctest/doctest.h" -#include "sparrow.hpp" +#include + +#include + namespace sparrow_ipc { @@ -32,7 +34,7 @@ namespace sparrow_ipc } // Helper function to create a simple ArrowSchema for testing - ArrowSchema + inline ArrowSchema create_test_arrow_schema(const char* format, const char* name = "test_field", bool nullable = true) { ArrowSchema schema{}; @@ -49,7 +51,8 @@ namespace sparrow_ipc } // Helper function to create ArrowSchema with metadata - ArrowSchema create_test_arrow_schema_with_metadata(const char* format, const char* name = 
"test_field") + inline ArrowSchema + create_test_arrow_schema_with_metadata(const char* format, const char* name = "test_field") { auto schema = create_test_arrow_schema(format, name); @@ -59,7 +62,7 @@ namespace sparrow_ipc } // Helper function to create a simple record batch for testing - sp::record_batch create_test_record_batch() + inline sp::record_batch create_test_record_batch() { // Create a simple record batch with integer and string columns using initializer syntax return sp::record_batch( diff --git a/tests/test_any_output_stream.cpp b/tests/test_any_output_stream.cpp new file mode 100644 index 0000000..4e25d4c --- /dev/null +++ b/tests/test_any_output_stream.cpp @@ -0,0 +1,263 @@ +#include +#include +#include +#include + +#include "doctest/doctest.h" + +#include "sparrow_ipc/any_output_stream.hpp" +#include "sparrow_ipc/memory_output_stream.hpp" + +TEST_SUITE("any_output_stream") +{ + TEST_CASE("Construction and basic write") + { + SUBCASE("With memory_output_stream") + { + std::vector buffer; + sparrow_ipc::memory_output_stream> mem_stream(buffer); + sparrow_ipc::any_output_stream stream(mem_stream); + + const std::vector data = {1, 2, 3, 4, 5}; + stream.write(std::span(data)); + + CHECK_EQ(buffer.size(), 5); + CHECK_EQ(buffer, data); + } + + SUBCASE("With custom stream (vector wrapper)") + { + // Custom stream that just wraps a vector + struct custom_stream + { + std::vector& buffer; + + custom_stream& write(const char* s, std::streamsize count) + { + buffer.insert(buffer.end(), s, s + count); + return *this; + } + + custom_stream& write(std::span data) + { + buffer.insert(buffer.end(), data.begin(), data.end()); + return *this; + } + + custom_stream& put(uint8_t value) + { + buffer.push_back(value); + return *this; + } + + std::size_t size() const { return buffer.size(); } + }; + + std::vector buffer; + custom_stream custom{buffer}; + sparrow_ipc::any_output_stream stream(custom); + + const std::vector data = {10, 20, 30}; + 
stream.write(std::span(data)); + + CHECK_EQ(buffer.size(), 3); + CHECK_EQ(buffer[0], 10); + CHECK_EQ(buffer[1], 20); + CHECK_EQ(buffer[2], 30); + } + } + + TEST_CASE("Write single byte") + { + std::vector buffer; + sparrow_ipc::memory_output_stream mem_stream(buffer); + sparrow_ipc::any_output_stream stream(mem_stream); + + stream.write(uint8_t{42}); + + CHECK_EQ(buffer.size(), 1); + CHECK_EQ(buffer[0], 42); + } + + TEST_CASE("Write repeated bytes") + { + std::vector buffer; + sparrow_ipc::memory_output_stream mem_stream(buffer); + sparrow_ipc::any_output_stream stream(mem_stream); + + stream.write(uint8_t{0}, 5); + + CHECK_EQ(buffer.size(), 5); + CHECK(std::all_of(buffer.begin(), buffer.end(), [](uint8_t b) { return b == 0; })); + } + + TEST_CASE("Add padding") + { + std::vector buffer; + sparrow_ipc::memory_output_stream mem_stream(buffer); + sparrow_ipc::any_output_stream stream(mem_stream); + + // Write 5 bytes + stream.write(std::vector{1, 2, 3, 4, 5}); + + // Add padding to align to 8-byte boundary + stream.add_padding(); + + // Should pad to 8 bytes (5 data + 3 padding) + CHECK_EQ(buffer.size(), 8); + CHECK_EQ(buffer[5], 0); + CHECK_EQ(buffer[6], 0); + CHECK_EQ(buffer[7], 0); + } + + TEST_CASE("Reserve") + { + SUBCASE("Direct reserve") + { + std::vector buffer; + sparrow_ipc::memory_output_stream mem_stream(buffer); + sparrow_ipc::any_output_stream stream(mem_stream); + + stream.reserve(100); + + CHECK_GE(buffer.capacity(), 100); + } + + SUBCASE("Lazy reserve with function") + { + std::vector buffer; + sparrow_ipc::memory_output_stream mem_stream(buffer); + sparrow_ipc::any_output_stream stream(mem_stream); + + stream.reserve([]() { return 200; }); + + CHECK_GE(buffer.capacity(), 200); + } + } + + TEST_CASE("Size tracking") + { + std::vector buffer; + sparrow_ipc::memory_output_stream mem_stream(buffer); + sparrow_ipc::any_output_stream stream(mem_stream); + + CHECK_EQ(stream.size(), 0); + + stream.write(std::vector{1, 2, 3}); + CHECK_EQ(stream.size(), 3); + 
+ stream.write(uint8_t{4}); + CHECK_EQ(stream.size(), 4); + } + + TEST_CASE("Type recovery with get()") + { + std::vector buffer; + sparrow_ipc::memory_output_stream mem_stream(buffer); + sparrow_ipc::any_output_stream stream(mem_stream); + + SUBCASE("Correct type") + { + auto& recovered = stream.get>>(); + recovered.write(std::span{std::vector{1, 2, 3}}); + CHECK_EQ(buffer.size(), 3); + } + + SUBCASE("Wrong type throws") + { + CHECK_THROWS_AS( + stream.get(), + std::bad_cast + ); + } + } + + TEST_CASE("Move semantics") + { + std::vector buffer; + sparrow_ipc::memory_output_stream mem_stream(buffer); + sparrow_ipc::any_output_stream stream1(mem_stream); + + stream1.write(std::vector{1, 2, 3}); + + // Move construction + sparrow_ipc::any_output_stream stream2(std::move(stream1)); + CHECK_EQ(stream2.size(), 3); + + // Move assignment + sparrow_ipc::any_output_stream stream3(mem_stream); + stream3 = std::move(stream2); + CHECK_EQ(stream3.size(), 3); + } + + TEST_CASE("Polymorphic usage") + { + auto write_data = [](sparrow_ipc::any_output_stream& stream) + { + const std::vector data = {10, 20, 30, 40, 50}; + stream.write(std::span(data)); + stream.add_padding(); // Pad to 8 bytes + }; + + SUBCASE("With memory stream") + { + std::vector buffer; + sparrow_ipc::memory_output_stream mem_stream(buffer); + sparrow_ipc::any_output_stream stream(mem_stream); + + write_data(stream); + + CHECK_EQ(buffer.size(), 8); // 5 data + 3 padding + } + + SUBCASE("With ostringstream") + { + std::ostringstream oss; + sparrow_ipc::any_output_stream stream(oss); + + write_data(stream); + + CHECK_GE(oss.str().size(), 5); + } + } + + TEST_CASE("Edge cases") + { + SUBCASE("Empty write") + { + std::vector buffer; + sparrow_ipc::memory_output_stream mem_stream(buffer); + sparrow_ipc::any_output_stream stream(mem_stream); + + std::vector empty; + stream.write(std::span(empty)); + + CHECK_EQ(buffer.size(), 0); + } + + SUBCASE("Already aligned padding") + { + std::vector buffer; + 
sparrow_ipc::memory_output_stream mem_stream(buffer); + sparrow_ipc::any_output_stream stream(mem_stream); + + // Write exactly 8 bytes + stream.write(std::vector{1, 2, 3, 4, 5, 6, 7, 8}); + stream.add_padding(); + + // Should not add any padding + CHECK_EQ(buffer.size(), 8); + } + + SUBCASE("Zero byte write repeated") + { + std::vector buffer; + sparrow_ipc::memory_output_stream mem_stream(buffer); + sparrow_ipc::any_output_stream stream(mem_stream); + + stream.write(uint8_t{0}, 0); + + CHECK_EQ(buffer.size(), 0); + } + } +} diff --git a/tests/test_chunk_memory_output_stream.cpp b/tests/test_chunk_memory_output_stream.cpp new file mode 100644 index 0000000..b687273 --- /dev/null +++ b/tests/test_chunk_memory_output_stream.cpp @@ -0,0 +1,504 @@ +#include +#include +#include +#include +#include +#include + +#include + +#include "doctest/doctest.h" + +namespace sparrow_ipc +{ + TEST_SUITE("chunked_memory_output_stream") + { + TEST_CASE("basic construction") + { + SUBCASE("Construction with empty vector of vectors") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + CHECK_EQ(stream.size(), 0); + CHECK_EQ(chunks.size(), 0); + } + + SUBCASE("Construction with existing chunks") + { + std::vector> chunks = {{1, 2, 3}, {4, 5, 6, 7}, {8, 9}}; + chunked_memory_output_stream stream(chunks); + + CHECK_EQ(stream.size(), 9); + CHECK_EQ(chunks.size(), 3); + } + } + + TEST_CASE("write operations with span") + { + SUBCASE("Write single byte span") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + uint8_t data[] = {42}; + std::span span(data, 1); + + stream.write(span); + + CHECK_EQ(stream.size(), 1); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 1); + CHECK_EQ(chunks[0][0], 42); + } + + SUBCASE("Write multiple bytes span") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + uint8_t data[] = {1, 2, 3, 4, 5}; + std::span span(data, 5); + + stream.write(span); + + CHECK_EQ(stream.size(), 5); 
+ CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 5); + for (size_t i = 0; i < 5; ++i) + { + CHECK_EQ(chunks[0][i], i + 1); + } + } + + SUBCASE("Write empty span") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + std::span empty_span; + + stream.write(empty_span); + + CHECK_EQ(stream.size(), 0); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 0); + } + + SUBCASE("Multiple span writes create multiple chunks") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + uint8_t data1[] = {10, 20}; + uint8_t data2[] = {30, 40, 50}; + uint8_t data3[] = {60}; + + stream.write(std::span(data1, 2)); + stream.write(std::span(data2, 3)); + stream.write(std::span(data3, 1)); + + CHECK_EQ(stream.size(), 6); + CHECK_EQ(chunks.size(), 3); + + CHECK_EQ(chunks[0].size(), 2); + CHECK_EQ(chunks[0][0], 10); + CHECK_EQ(chunks[0][1], 20); + + CHECK_EQ(chunks[1].size(), 3); + CHECK_EQ(chunks[1][0], 30); + CHECK_EQ(chunks[1][1], 40); + CHECK_EQ(chunks[1][2], 50); + + CHECK_EQ(chunks[2].size(), 1); + CHECK_EQ(chunks[2][0], 60); + } + } + + TEST_CASE("write operations with move") + { + SUBCASE("Write moved vector") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + std::vector buffer = {1, 2, 3, 4, 5}; + stream.write(std::move(buffer)); + + CHECK_EQ(stream.size(), 5); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 5); + for (size_t i = 0; i < 5; ++i) + { + CHECK_EQ(chunks[0][i], i + 1); + } + } + + SUBCASE("Write multiple moved vectors") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + std::vector buffer1 = {10, 20, 30}; + std::vector buffer2 = {40, 50}; + std::vector buffer3 = {60, 70, 80, 90}; + + stream.write(std::move(buffer1)); + stream.write(std::move(buffer2)); + stream.write(std::move(buffer3)); + + CHECK_EQ(stream.size(), 9); + CHECK_EQ(chunks.size(), 3); + + CHECK_EQ(chunks[0].size(), 3); + CHECK_EQ(chunks[1].size(), 2); + 
CHECK_EQ(chunks[2].size(), 4); + } + + SUBCASE("Write empty moved vector") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + std::vector empty_buffer; + stream.write(std::move(empty_buffer)); + + CHECK_EQ(stream.size(), 0); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 0); + } + } + + TEST_CASE("write operations with repeated value") + { + SUBCASE("Write value multiple times") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + stream.write(static_cast(255), 5); + + CHECK_EQ(stream.size(), 5); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 5); + for (size_t i = 0; i < 5; ++i) + { + CHECK_EQ(chunks[0][i], 255); + } + } + + SUBCASE("Write value zero times") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + stream.write(static_cast(42), 0); + + CHECK_EQ(stream.size(), 0); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 0); + } + + SUBCASE("Multiple repeated value writes") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + stream.write(static_cast(100), 3); + stream.write(static_cast(200), 2); + stream.write(static_cast(50), 4); + + CHECK_EQ(stream.size(), 9); + CHECK_EQ(chunks.size(), 3); + + CHECK_EQ(chunks[0].size(), 3); + for (size_t i = 0; i < 3; ++i) + { + CHECK_EQ(chunks[0][i], 100); + } + + CHECK_EQ(chunks[1].size(), 2); + for (size_t i = 0; i < 2; ++i) + { + CHECK_EQ(chunks[1][i], 200); + } + + CHECK_EQ(chunks[2].size(), 4); + for (size_t i = 0; i < 4; ++i) + { + CHECK_EQ(chunks[2][i], 50); + } + } + } + + TEST_CASE("mixed write operations") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + // Write span + uint8_t data[] = {1, 2, 3}; + stream.write(std::span(data, 3)); + + // Write repeated value + stream.write(static_cast(42), 2); + + // Write moved vector + std::vector buffer = {10, 20, 30, 40}; + stream.write(std::move(buffer)); + + CHECK_EQ(stream.size(), 9); + CHECK_EQ(chunks.size(), 
3); + + CHECK_EQ(chunks[0].size(), 3); + CHECK_EQ(chunks[0][0], 1); + CHECK_EQ(chunks[0][1], 2); + CHECK_EQ(chunks[0][2], 3); + + CHECK_EQ(chunks[1].size(), 2); + CHECK_EQ(chunks[1][0], 42); + CHECK_EQ(chunks[1][1], 42); + + CHECK_EQ(chunks[2].size(), 4); + CHECK_EQ(chunks[2][0], 10); + CHECK_EQ(chunks[2][1], 20); + CHECK_EQ(chunks[2][2], 30); + CHECK_EQ(chunks[2][3], 40); + } + + TEST_CASE("reserve functionality") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + // Reserve space + stream.reserve(100); + + // Chunks vector should have reserved capacity but size should remain 0 + CHECK_GE(chunks.capacity(), 100); + CHECK_EQ(stream.size(), 0); + CHECK_EQ(chunks.size(), 0); + + // Writing should work normally after reserve + uint8_t data[] = {1, 2, 3}; + std::span span(data, 3); + stream.write(span); + + CHECK_EQ(stream.size(), 3); + CHECK_EQ(chunks.size(), 1); + } + + TEST_CASE("size calculation") + { + SUBCASE("Size with empty chunks") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + CHECK_EQ(stream.size(), 0); + } + + SUBCASE("Size with pre-existing chunks") + { + std::vector> chunks = {{1, 2, 3}, {4, 5}, {6, 7, 8, 9}}; + chunked_memory_output_stream stream(chunks); + + CHECK_EQ(stream.size(), 9); + } + + SUBCASE("Size updates after writes") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + CHECK_EQ(stream.size(), 0); + + uint8_t data[] = {1, 2, 3}; + stream.write(std::span(data, 3)); + CHECK_EQ(stream.size(), 3); + + stream.write(static_cast(42), 5); + CHECK_EQ(stream.size(), 8); + + std::vector buffer = {10, 20}; + stream.write(std::move(buffer)); + CHECK_EQ(stream.size(), 10); + } + + SUBCASE("Size with chunks of varying sizes") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + stream.write(static_cast(1), 1); + stream.write(static_cast(2), 10); + stream.write(static_cast(3), 100); + stream.write(static_cast(4), 1000); + + CHECK_EQ(stream.size(), 
1111); + CHECK_EQ(chunks.size(), 4); + } + } + + TEST_CASE("large data handling") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + SUBCASE("Single large chunk") + { + const size_t large_size = 10000; + std::vector large_data(large_size); + std::iota(large_data.begin(), large_data.end(), 0); + + stream.write(std::move(large_data)); + + CHECK_EQ(stream.size(), large_size); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), large_size); + + // Verify data integrity + for (size_t i = 0; i < large_size; ++i) + { + CHECK_EQ(chunks[0][i], static_cast(i)); + } + } + + SUBCASE("Many small chunks") + { + const size_t num_chunks = 1000; + const size_t chunk_size = 10; + + for (size_t i = 0; i < num_chunks; ++i) + { + uint8_t value = static_cast(i); + stream.write(value, chunk_size); + } + + CHECK_EQ(stream.size(), num_chunks * chunk_size); + CHECK_EQ(chunks.size(), num_chunks); + + for (size_t i = 0; i < num_chunks; ++i) + { + CHECK_EQ(chunks[i].size(), chunk_size); + for (size_t j = 0; j < chunk_size; ++j) + { + CHECK_EQ(chunks[i][j], static_cast(i)); + } + } + } + + SUBCASE("Large repeated value write") + { + const size_t count = 50000; + stream.write(static_cast(123), count); + + CHECK_EQ(stream.size(), count); + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), count); + + for (size_t i = 0; i < count; ++i) + { + CHECK_EQ(chunks[0][i], 123); + } + } + } + + TEST_CASE("edge cases") + { + SUBCASE("Maximum value writes") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + stream.write(std::numeric_limits::max(), 255); + + CHECK_EQ(stream.size(), 255); + CHECK_EQ(chunks.size(), 1); + for (size_t i = 0; i < 255; ++i) + { + CHECK_EQ(chunks[0][i], std::numeric_limits::max()); + } + } + + SUBCASE("Zero byte writes") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + stream.write(static_cast(0), 100); + + CHECK_EQ(stream.size(), 100); + CHECK_EQ(chunks.size(), 1); + for (size_t i = 0; 
i < 100; ++i) + { + CHECK_EQ(chunks[0][i], 0); + } + } + + SUBCASE("Interleaved empty and non-empty writes") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + stream.write(static_cast(1), 5); + stream.write(static_cast(2), 0); + stream.write(static_cast(3), 3); + std::vector empty; + stream.write(std::move(empty)); + stream.write(static_cast(4), 2); + + CHECK_EQ(stream.size(), 10); + CHECK_EQ(chunks.size(), 5); + + CHECK_EQ(chunks[0].size(), 5); + CHECK_EQ(chunks[1].size(), 0); + CHECK_EQ(chunks[2].size(), 3); + CHECK_EQ(chunks[3].size(), 0); + CHECK_EQ(chunks[4].size(), 2); + } + } + + TEST_CASE("reference semantics") + { + SUBCASE("Stream modifies original chunks vector") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + uint8_t data[] = {1, 2, 3}; + stream.write(std::span(data, 3)); + + // Verify that the original vector is modified + CHECK_EQ(chunks.size(), 1); + CHECK_EQ(chunks[0].size(), 3); + CHECK_EQ(chunks[0][0], 1); + CHECK_EQ(chunks[0][1], 2); + CHECK_EQ(chunks[0][2], 3); + } + + SUBCASE("Multiple streams to same chunks vector") + { + std::vector> chunks; + + { + chunked_memory_output_stream stream1(chunks); + uint8_t data1[] = {10, 20}; + stream1.write(std::span(data1, 2)); + } + + { + chunked_memory_output_stream stream2(chunks); + uint8_t data2[] = {30, 40}; + stream2.write(std::span(data2, 2)); + } + + CHECK_EQ(chunks.size(), 2); + CHECK_EQ(chunks[0][0], 10); + CHECK_EQ(chunks[0][1], 20); + CHECK_EQ(chunks[1][0], 30); + CHECK_EQ(chunks[1][1], 40); + } + } + } +} diff --git a/tests/test_chunk_memory_serializer.cpp b/tests/test_chunk_memory_serializer.cpp new file mode 100644 index 0000000..1230b97 --- /dev/null +++ b/tests/test_chunk_memory_serializer.cpp @@ -0,0 +1,562 @@ +#include + +#include +#include + +#include "sparrow_ipc/chunk_memory_output_stream.hpp" +#include "sparrow_ipc/chunk_memory_serializer.hpp" +#include "sparrow_ipc_tests_helpers.hpp" + +namespace sparrow_ipc +{ + namespace sp = 
sparrow; + + TEST_SUITE("chunk_serializer") + { + TEST_CASE("construction with single record batch") + { + SUBCASE("Valid record batch") + { + auto rb = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer << rb; + + // After construction with single record batch, should have schema + record batch + CHECK_EQ(chunks.size(), 2); + CHECK_GT(chunks[0].size(), 0); // Schema message + CHECK_GT(chunks[1].size(), 0); // Record batch message + CHECK_GT(stream.size(), 0); + } + + SUBCASE("Empty record batch") + { + auto empty_batch = sp::record_batch({}); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer << empty_batch; + + CHECK_EQ(chunks.size(), 2); + CHECK_GT(chunks[0].size(), 0); + } + } + + TEST_CASE("construction with range of record batches") + { + SUBCASE("Valid record batches") + { + auto array1 = sp::primitive_array({1, 2, 3}); + auto array2 = sp::primitive_array({1.0, 2.0, 3.0}); + auto rb1 = sp::record_batch( + {{"col1", sp::array(std::move(array1))}, {"col2", sp::array(std::move(array2))}} + ); + + auto array3 = sp::primitive_array({4, 5, 6}); + auto array4 = sp::primitive_array({4.0, 5.0, 6.0}); + auto rb2 = sp::record_batch( + {{"col1", sp::array(std::move(array3))}, {"col2", sp::array(std::move(array4))}} + ); + + std::vector record_batches = {rb1, rb2}; + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer << record_batches; + + // Should have schema + 2 record batches = 3 chunks + CHECK_EQ(chunks.size(), 3); + CHECK_GT(chunks[0].size(), 0); // Schema + CHECK_GT(chunks[1].size(), 0); // First record batch + CHECK_GT(chunks[2].size(), 0); // Second record batch + } + + + SUBCASE("Reserve is called correctly") + { + auto rb = create_test_record_batch(); + std::vector record_batches = {rb}; + std::vector> chunks; + 
chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer << record_batches; + + // Verify that chunks were reserved (capacity should be >= size) + CHECK_GE(chunks.capacity(), chunks.size()); + } + } + + TEST_CASE("write single record batch") + { + SUBCASE("Write after construction with single batch") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer << rb1; + CHECK_EQ(chunks.size(), 2); // Schema + rb1 + + // Create compatible record batch + auto rb2 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({6, 7, 8}))}, + {"string_col", sp::array(sp::string_array(std::vector{"foo", "bar", "baz"}))}} + ); + + serializer.write(rb2); + + CHECK_EQ(chunks.size(), 3); // Schema + rb1 + rb2 + CHECK_GT(chunks[2].size(), 0); + } + + SUBCASE("Multiple appends") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer << rb1; + + for (int i = 0; i < 3; ++i) + { + auto rb = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({i}))}, + {"string_col", sp::array(sp::string_array(std::vector{"test"}))}} + ); + serializer.write(rb); + } + + CHECK_EQ(chunks.size(), 5); // Schema + 1 initial + 3 appended + } + } + + TEST_CASE("write range of record batches") + { + SUBCASE("Write range after construction") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer.write(rb1); + CHECK_EQ(chunks.size(), 2); + + auto array1 = sp::primitive_array({10, 20}); + auto array2 = sp::string_array(std::vector{"a", "b"}); + auto rb2 = sp::record_batch( + {{"int_col", sp::array(std::move(array1))}, + {"string_col", sp::array(std::move(array2))}} + ); + + auto array3 = sp::primitive_array({30, 40}); + 
auto array4 = sp::string_array(std::vector{"c", "d"}); + auto rb3 = sp::record_batch( + {{"int_col", sp::array(std::move(array3))}, + {"string_col", sp::array(std::move(array4))}} + ); + + std::vector new_batches = {rb2, rb3}; + serializer.write(new_batches); + + CHECK_EQ(chunks.size(), 4); // Schema + rb1 + rb2 + rb3 + CHECK_GT(chunks[2].size(), 0); + CHECK_GT(chunks[3].size(), 0); + } + + SUBCASE("Reserve is called during range append") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer.write(rb1); + + auto rb2 = create_test_record_batch(); + auto rb3 = create_test_record_batch(); + std::vector new_batches = {rb2, rb3}; + + size_t old_capacity = chunks.capacity(); + serializer.write(new_batches); + + // Reserve should have been called + CHECK_GE(chunks.capacity(), chunks.size()); + } + + SUBCASE("Empty range append does nothing") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer.write(rb1); + size_t initial_size = chunks.size(); + + std::vector empty_batches; + serializer.write(empty_batches); + + CHECK_EQ(chunks.size(), initial_size); + } + } + + TEST_CASE("end serialization") + { + SUBCASE("End after construction") + { + auto rb = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer.write(rb); + size_t initial_size = chunks.size(); + + serializer.end(); + + // End should add end-of-stream marker + CHECK_GT(chunks.size(), initial_size); + } + + SUBCASE("Cannot append after end") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer.write(rb1); + serializer.end(); + + auto rb2 = create_test_record_batch(); + 
CHECK_THROWS_AS(serializer.write(rb2), std::runtime_error); + } + + SUBCASE("Cannot append range after end") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer.write(rb1); + serializer.end(); + + std::vector new_batches = {create_test_record_batch()}; + CHECK_THROWS_AS(serializer.write(new_batches), std::runtime_error); + } + } + + TEST_CASE("stream size tracking") + { + SUBCASE("Size increases with each operation") + { + auto rb = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + size_t size_before = stream.size(); + chunk_serializer serializer(stream); + serializer.write(rb); + size_t size_after_construction = stream.size(); + + CHECK_GT(size_after_construction, size_before); + + serializer.write(rb); + size_t size_after_append = stream.size(); + + CHECK_GT(size_after_append, size_after_construction); + } + } + + TEST_CASE("large number of record batches") + { + SUBCASE("Handle many record batches efficiently") + { + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + std::vector batches; + const int num_batches = 100; + + for (int i = 0; i < num_batches; ++i) + { + auto array = sp::primitive_array({i, i+1, i+2}); + batches.push_back(sp::record_batch({{"col", sp::array(std::move(array))}})); + } + + chunk_serializer serializer(stream); + serializer.write(batches); + + // Should have schema + all batches + CHECK_EQ(chunks.size(), num_batches + 1); + CHECK_GT(stream.size(), 0); + + // Verify each chunk has data + for (const auto& chunk : chunks) + { + CHECK_GT(chunk.size(), 0); + } + } + } + + TEST_CASE("different column types") + { + SUBCASE("Multiple primitive types") + { + auto int_array = sp::primitive_array({1, 2, 3}); + auto double_array = sp::primitive_array({1.5, 2.5, 3.5}); + auto float_array = sp::primitive_array({1.0f, 2.0f, 3.0f}); + + auto rb = sp::record_batch( + 
{{"int_col", sp::array(std::move(int_array))}, + {"double_col", sp::array(std::move(double_array))}, + {"float_col", sp::array(std::move(float_array))}} + ); + + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer.write(rb); + + CHECK_EQ(chunks.size(), 2); // Schema + record batch + CHECK_GT(chunks[0].size(), 0); + CHECK_GT(chunks[1].size(), 0); + } + } + + TEST_CASE("workflow example") + { + SUBCASE("Typical usage pattern") + { + // Create initial record batch + auto rb1 = create_test_record_batch(); + + // Setup chunked stream + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + // Create serializer with initial batch + chunk_serializer serializer(stream); + serializer.write(rb1); + CHECK_EQ(chunks.size(), 2); + + // Append more batches + auto rb2 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({10, 20}))}, + {"string_col", sp::array(sp::string_array(std::vector{"x", "y"}))}} + ); + serializer.write(rb2); + CHECK_EQ(chunks.size(), 3); + + // Append range of batches + std::vector more_batches; + for (int i = 0; i < 3; ++i) + { + auto rb = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({i}))}, + {"string_col", sp::array(sp::string_array(std::vector{"test"}))}} + ); + more_batches.push_back(rb); + } + serializer.write(more_batches); + CHECK_EQ(chunks.size(), 6); + + // End serialization + serializer.end(); + CHECK_GT(chunks.size(), 6); + + // Verify all chunks have data + for (const auto& chunk : chunks) + { + CHECK_GT(chunk.size(), 0); + } + } + } + + TEST_CASE("operator<< with single record batch") + { + SUBCASE("Single batch append using <<") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer.write(rb1); + CHECK_EQ(chunks.size(), 2); // Schema + rb1 + + auto rb2 = sp::record_batch( + {{"int_col", 
sp::array(sp::primitive_array({6, 7, 8}))}, + {"string_col", sp::array(sp::string_array(std::vector{"foo", "bar", "baz"}))}} + ); + + serializer << rb2; + + CHECK_EQ(chunks.size(), 3); // Schema + rb1 + rb2 + CHECK_GT(chunks[2].size(), 0); + } + + SUBCASE("Chaining multiple single batch appends") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer.write(rb1); + + auto rb2 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({10, 20}))}, + {"string_col", sp::array(sp::string_array(std::vector{"a", "b"}))}} + ); + + auto rb3 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({30, 40}))}, + {"string_col", sp::array(sp::string_array(std::vector{"c", "d"}))}} + ); + + auto rb4 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({50, 60}))}, + {"string_col", sp::array(sp::string_array(std::vector{"e", "f"}))}} + ); + + serializer << rb2 << rb3 << rb4; + + CHECK_EQ(chunks.size(), 5); // Schema + 4 record batches + CHECK_GT(chunks[2].size(), 0); + CHECK_GT(chunks[3].size(), 0); + CHECK_GT(chunks[4].size(), 0); + } + + SUBCASE("Cannot use << after end") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer.write(rb1); + serializer.end(); + + auto rb2 = create_test_record_batch(); + CHECK_THROWS_AS(serializer << rb2, std::runtime_error); + } + } + + TEST_CASE("operator<< with range of record batches") + { + SUBCASE("Range append using <<") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer.write(rb1); + CHECK_EQ(chunks.size(), 2); + + std::vector batches; + for (int i = 0; i < 3; ++i) + { + auto rb = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({i * 10}))}, + {"string_col", 
sp::array(sp::string_array(std::vector{"test"}))}} + ); + batches.push_back(rb); + } + + serializer << batches; + + CHECK_EQ(chunks.size(), 5); // Schema + rb1 + 3 batches + for (size_t i = 0; i < chunks.size(); ++i) + { + CHECK_GT(chunks[i].size(), 0); + } + } + + SUBCASE("Chaining range and single batch appends") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer.write(rb1); + + std::vector batches; + for (int i = 0; i < 2; ++i) + { + auto rb = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({i}))}, + {"string_col", sp::array(sp::string_array(std::vector{"x"}))}} + ); + batches.push_back(rb); + } + + auto rb2 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({99}))}, + {"string_col", sp::array(sp::string_array(std::vector{"final"}))}} + ); + + serializer << batches << rb2; + + CHECK_EQ(chunks.size(), 5); // Schema + rb1 + 2 from range + rb2 + } + + SUBCASE("Mixed chaining with multiple ranges") + { + auto rb1 = create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer.write(rb1); + + std::vector batches1; + batches1.push_back(sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({10}))}, + {"string_col", sp::array(sp::string_array(std::vector{"a"}))}} + )); + + std::vector batches2; + batches2.push_back(sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({20}))}, + {"string_col", sp::array(sp::string_array(std::vector{"b"}))}} + )); + + auto rb2 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({30}))}, + {"string_col", sp::array(sp::string_array(std::vector{"c"}))}} + ); + + serializer << batches1 << rb2 << batches2; + + CHECK_EQ(chunks.size(), 5); // Schema + rb1 + 1 from batches1 + rb2 + 1 from batches2 + } + + SUBCASE("Cannot use << with range after end") + { + auto rb1 = 
create_test_record_batch(); + std::vector> chunks; + chunked_memory_output_stream stream(chunks); + + chunk_serializer serializer(stream); + serializer.write(rb1); + serializer.end(); + + std::vector batches = {create_test_record_batch()}; + CHECK_THROWS_AS(serializer << batches, std::runtime_error); + } + } + } +} diff --git a/tests/test_de_serialization_with_files.cpp b/tests/test_de_serialization_with_files.cpp index 8fe825b..6c799d1 100644 --- a/tests/test_de_serialization_with_files.cpp +++ b/tests/test_de_serialization_with_files.cpp @@ -14,7 +14,8 @@ #include "doctest/doctest.h" #include "sparrow.hpp" #include "sparrow_ipc/deserialize.hpp" -#include "sparrow_ipc/serialize.hpp" +#include "sparrow_ipc/memory_output_stream.hpp" +#include "sparrow_ipc/serializer.hpp" const std::filesystem::path arrow_testing_data_dir = ARROW_TESTING_DATA_DIR; @@ -162,7 +163,10 @@ TEST_SUITE("Integration tests") std::span(stream_data) ); - const auto serialized_data = sparrow_ipc::serialize(record_batches_from_json); + std::vector serialized_data; + sparrow_ipc::memory_output_stream stream(serialized_data); + sparrow_ipc::serializer serializer(stream); + serializer << record_batches_from_json << sparrow_ipc::end_stream; const auto deserialized_serialized_data = sparrow_ipc::deserialize_stream( std::span(serialized_data) ); diff --git a/tests/test_flatbuffer_utils.cpp b/tests/test_flatbuffer_utils.cpp new file mode 100644 index 0000000..fd48410 --- /dev/null +++ b/tests/test_flatbuffer_utils.cpp @@ -0,0 +1,535 @@ +#include +#include + +#include "sparrow_ipc_tests_helpers.hpp" + +namespace sparrow_ipc +{ + TEST_SUITE("flatbuffer_utils") + { + TEST_CASE("create_metadata") + { + flatbuffers::FlatBufferBuilder builder; + + SUBCASE("No metadata (nullptr)") + { + auto schema = create_test_arrow_schema("i"); + auto metadata_offset = create_metadata(builder, schema); + CHECK_EQ(metadata_offset.o, 0); + } + + SUBCASE("With metadata - basic test") + { + auto schema = 
create_test_arrow_schema_with_metadata("i"); + auto metadata_offset = create_metadata(builder, schema); + // For now just check that it doesn't crash + // TODO: Add proper metadata testing when sparrow metadata is properly handled + } + } + + TEST_CASE("create_field") + { + flatbuffers::FlatBufferBuilder builder; + + SUBCASE("Basic field creation") + { + auto schema = create_test_arrow_schema("i", "int_field", true); + auto field_offset = create_field(builder, schema); + CHECK_NE(field_offset.o, 0); + } + + SUBCASE("Field with null name") + { + auto schema = create_test_arrow_schema("i", nullptr, false); + auto field_offset = create_field(builder, schema); + CHECK_NE(field_offset.o, 0); + } + + SUBCASE("Non-nullable field") + { + auto schema = create_test_arrow_schema("i", "int_field", false); + auto field_offset = create_field(builder, schema); + CHECK_NE(field_offset.o, 0); + } + } + + TEST_CASE("create_children from ArrowSchema") + { + flatbuffers::FlatBufferBuilder builder; + + SUBCASE("No children") + { + auto schema = create_test_arrow_schema("i"); + auto children_offset = create_children(builder, schema); + CHECK_EQ(children_offset.o, 0); + } + + SUBCASE("With children") + { + auto parent_schema = create_test_arrow_schema("+s"); + auto child1 = new ArrowSchema(create_test_arrow_schema("i", "child1")); + auto child2 = new ArrowSchema(create_test_arrow_schema("u", "child2")); + + ArrowSchema* children[] = {child1, child2}; + parent_schema.children = children; + parent_schema.n_children = 2; + + auto children_offset = create_children(builder, parent_schema); + CHECK_NE(children_offset.o, 0); + + // Clean up + delete child1; + delete child2; + } + + SUBCASE("Null child pointer throws exception") + { + auto parent_schema = create_test_arrow_schema("+s"); + ArrowSchema* children[] = {nullptr}; + parent_schema.children = children; + parent_schema.n_children = 1; + + CHECK_THROWS_AS( + const auto children_offset = create_children(builder, parent_schema), + 
std::invalid_argument + ); + } + } + + TEST_CASE("create_children from record_batch columns") + { + flatbuffers::FlatBufferBuilder builder; + + SUBCASE("With valid record batch") + { + auto record_batch = create_test_record_batch(); + auto children_offset = create_children(builder, record_batch.columns()); + CHECK_NE(children_offset.o, 0); + } + + SUBCASE("Empty record batch") + { + auto empty_batch = sp::record_batch({}); + + auto children_offset = create_children(builder, empty_batch.columns()); + CHECK_EQ(children_offset.o, 0); + } + } + + TEST_CASE("get_schema_message_builder") + { + SUBCASE("Valid record batch") + { + auto record_batch = create_test_record_batch(); + auto builder = get_schema_message_builder(record_batch); + + CHECK_GT(builder.GetSize(), 0); + CHECK_NE(builder.GetBufferPointer(), nullptr); + } + } + + TEST_CASE("fill_fieldnodes") + { + SUBCASE("Single array without children") + { + auto array = sp::primitive_array({1, 2, 3, 4, 5}); + auto proxy = sp::detail::array_access::get_arrow_proxy(array); + + std::vector nodes; + fill_fieldnodes(proxy, nodes); + + CHECK_EQ(nodes.size(), 1); + CHECK_EQ(nodes[0].length(), 5); + CHECK_EQ(nodes[0].null_count(), 0); + } + + SUBCASE("Array with null values") + { + // For now, just test with a simple array without explicit nulls + // Creating arrays with null values requires more complex sparrow setup + auto array = sp::primitive_array({1, 2, 3}); + auto proxy = sp::detail::array_access::get_arrow_proxy(array); + + std::vector nodes; + fill_fieldnodes(proxy, nodes); + + CHECK_EQ(nodes.size(), 1); + CHECK_EQ(nodes[0].length(), 3); + CHECK_EQ(nodes[0].null_count(), 0); + } + } + + TEST_CASE("create_fieldnodes") + { + SUBCASE("Record batch with multiple columns") + { + auto record_batch = create_test_record_batch(); + auto nodes = create_fieldnodes(record_batch); + + CHECK_EQ(nodes.size(), 2); // Two columns + + // Check the first column (integer array) + CHECK_EQ(nodes[0].length(), 5); + 
CHECK_EQ(nodes[0].null_count(), 0); + + // Check the second column (string array) + CHECK_EQ(nodes[1].length(), 5); + CHECK_EQ(nodes[1].null_count(), 0); + } + } + + TEST_CASE("fill_buffers") + { + SUBCASE("Simple primitive array") + { + auto array = sp::primitive_array({1, 2, 3, 4, 5}); + auto proxy = sp::detail::array_access::get_arrow_proxy(array); + + std::vector buffers; + int64_t offset = 0; + fill_buffers(proxy, buffers, offset); + + CHECK_GT(buffers.size(), 0); + CHECK_GT(offset, 0); + + // Verify offsets are aligned + for (const auto& buffer : buffers) + { + CHECK_EQ(buffer.offset() % 8, 0); + } + } + } + + TEST_CASE("get_buffers") + { + SUBCASE("Record batch with multiple columns") + { + auto record_batch = create_test_record_batch(); + auto buffers = get_buffers(record_batch); + CHECK_GT(buffers.size(), 0); + // Verify all offsets are properly calculated and aligned + for (size_t i = 1; i < buffers.size(); ++i) + { + CHECK_GE(buffers[i].offset(), buffers[i - 1].offset() + buffers[i - 1].length()); + } + } + } + + TEST_CASE("get_flatbuffer_type") + { + flatbuffers::FlatBufferBuilder builder; + SUBCASE("Null and Boolean types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::NA)).first, + org::apache::arrow::flatbuf::Type::Null + ); + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::BOOL)).first, + org::apache::arrow::flatbuf::Type::Bool + ); + } + + SUBCASE("Integer types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT8)).first, + org::apache::arrow::flatbuf::Type::Int + ); // INT8 + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT8)).first, + org::apache::arrow::flatbuf::Type::Int + ); // UINT8 + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT16)).first, + org::apache::arrow::flatbuf::Type::Int + ); // INT16 + CHECK_EQ( + 
get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT16)).first, + org::apache::arrow::flatbuf::Type::Int + ); // UINT16 + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT32)).first, + org::apache::arrow::flatbuf::Type::Int + ); // INT32 + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT32)).first, + org::apache::arrow::flatbuf::Type::Int + ); // UINT32 + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT64)).first, + org::apache::arrow::flatbuf::Type::Int + ); // INT64 + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT64)).first, + org::apache::arrow::flatbuf::Type::Int + ); // UINT64 + } + + SUBCASE("Floating Point types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::HALF_FLOAT)).first, + org::apache::arrow::flatbuf::Type::FloatingPoint + ); // HALF_FLOAT + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::FLOAT)).first, + org::apache::arrow::flatbuf::Type::FloatingPoint + ); // FLOAT + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::DOUBLE)).first, + org::apache::arrow::flatbuf::Type::FloatingPoint + ); // DOUBLE + } + + SUBCASE("String and Binary types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::STRING)).first, + org::apache::arrow::flatbuf::Type::Utf8 + ); // STRING + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::LARGE_STRING)) + .first, + org::apache::arrow::flatbuf::Type::LargeUtf8 + ); // LARGE_STRING + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::BINARY)).first, + org::apache::arrow::flatbuf::Type::Binary + ); // BINARY + CHECK_EQ( + get_flatbuffer_type(builder, 
sparrow::data_type_to_format(sparrow::data_type::LARGE_BINARY)) + .first, + org::apache::arrow::flatbuf::Type::LargeBinary + ); // LARGE_BINARY + CHECK_EQ( + get_flatbuffer_type(builder, "vu").first, + org::apache::arrow::flatbuf::Type::Utf8View + ); // STRING_VIEW + CHECK_EQ( + get_flatbuffer_type(builder, "vz").first, + org::apache::arrow::flatbuf::Type::BinaryView + ); // BINARY_VIEW + } + + SUBCASE("Date types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::DATE_DAYS)).first, + org::apache::arrow::flatbuf::Type::Date + ); // DATE_DAYS + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::DATE_MILLISECONDS)) + .first, + org::apache::arrow::flatbuf::Type::Date + ); // DATE_MILLISECONDS + } + + SUBCASE("Timestamp types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_SECONDS)) + .first, + org::apache::arrow::flatbuf::Type::Timestamp + ); // TIMESTAMP_SECONDS + CHECK_EQ( + get_flatbuffer_type( + builder, + sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_MILLISECONDS) + ) + .first, + org::apache::arrow::flatbuf::Type::Timestamp + ); // TIMESTAMP_MILLISECONDS + CHECK_EQ( + get_flatbuffer_type( + builder, + sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_MICROSECONDS) + ) + .first, + org::apache::arrow::flatbuf::Type::Timestamp + ); // TIMESTAMP_MICROSECONDS + CHECK_EQ( + get_flatbuffer_type( + builder, + sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_NANOSECONDS) + ) + .first, + org::apache::arrow::flatbuf::Type::Timestamp + ); // TIMESTAMP_NANOSECONDS + } + + SUBCASE("Duration types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::DURATION_SECONDS)) + .first, + org::apache::arrow::flatbuf::Type::Duration + ); // DURATION_SECONDS + CHECK_EQ( + get_flatbuffer_type( + builder, + 
sparrow::data_type_to_format(sparrow::data_type::DURATION_MILLISECONDS) + ) + .first, + org::apache::arrow::flatbuf::Type::Duration + ); // DURATION_MILLISECONDS + CHECK_EQ( + get_flatbuffer_type( + builder, + sparrow::data_type_to_format(sparrow::data_type::DURATION_MICROSECONDS) + ) + .first, + org::apache::arrow::flatbuf::Type::Duration + ); // DURATION_MICROSECONDS + CHECK_EQ( + get_flatbuffer_type( + builder, + sparrow::data_type_to_format(sparrow::data_type::DURATION_NANOSECONDS) + ) + .first, + org::apache::arrow::flatbuf::Type::Duration + ); // DURATION_NANOSECONDS + } + + SUBCASE("Interval types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INTERVAL_MONTHS)) + .first, + org::apache::arrow::flatbuf::Type::Interval + ); // INTERVAL_MONTHS + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INTERVAL_DAYS_TIME)) + .first, + org::apache::arrow::flatbuf::Type::Interval + ); // INTERVAL_DAYS_TIME + CHECK_EQ( + get_flatbuffer_type( + builder, + sparrow::data_type_to_format(sparrow::data_type::INTERVAL_MONTHS_DAYS_NANOSECONDS) + ) + .first, + org::apache::arrow::flatbuf::Type::Interval + ); // INTERVAL_MONTHS_DAYS_NANOSECONDS + } + + SUBCASE("Time types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::TIME_SECONDS)) + .first, + org::apache::arrow::flatbuf::Type::Time + ); // TIME_SECONDS + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::TIME_MILLISECONDS)) + .first, + org::apache::arrow::flatbuf::Type::Time + ); // TIME_MILLISECONDS + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::TIME_MICROSECONDS)) + .first, + org::apache::arrow::flatbuf::Type::Time + ); // TIME_MICROSECONDS + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::TIME_NANOSECONDS)) + .first, + org::apache::arrow::flatbuf::Type::Time + ); 
// TIME_NANOSECONDS + } + + SUBCASE("List types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::LIST)).first, + org::apache::arrow::flatbuf::Type::List + ); // LIST + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::LARGE_LIST)).first, + org::apache::arrow::flatbuf::Type::LargeList + ); // LARGE_LIST + CHECK_EQ( + get_flatbuffer_type(builder, "+vl").first, + org::apache::arrow::flatbuf::Type::ListView + ); // LIST_VIEW + CHECK_EQ( + get_flatbuffer_type(builder, "+vL").first, + org::apache::arrow::flatbuf::Type::LargeListView + ); // LARGE_LIST_VIEW + CHECK_EQ( + get_flatbuffer_type(builder, "+w:16").first, + org::apache::arrow::flatbuf::Type::FixedSizeList + ); // FIXED_SIZED_LIST + CHECK_THROWS(get_flatbuffer_type(builder, "+w:")); // Invalid FixedSizeList format + } + + SUBCASE("Struct and Map types") + { + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::STRUCT)).first, + org::apache::arrow::flatbuf::Type::Struct_ + ); // STRUCT + CHECK_EQ( + get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::MAP)).first, + org::apache::arrow::flatbuf::Type::Map + ); // MAP + } + + SUBCASE("Union types") + { + CHECK_EQ( + get_flatbuffer_type(builder, "+ud:").first, + org::apache::arrow::flatbuf::Type::Union + ); // DENSE_UNION + CHECK_EQ( + get_flatbuffer_type(builder, "+us:").first, + org::apache::arrow::flatbuf::Type::Union + ); // SPARSE_UNION + } + + SUBCASE("Run-End Encoded type") + { + CHECK_EQ( + get_flatbuffer_type(builder, "+r").first, + org::apache::arrow::flatbuf::Type::RunEndEncoded + ); // RUN_ENCODED + } + + SUBCASE("Decimal types") + { + CHECK_EQ( + get_flatbuffer_type(builder, "d:10,5").first, + org::apache::arrow::flatbuf::Type::Decimal + ); // DECIMAL (general) + CHECK_THROWS(get_flatbuffer_type(builder, "d:10")); // Invalid Decimal format + } + + SUBCASE("Fixed Width Binary type") + { + CHECK_EQ( + 
get_flatbuffer_type(builder, "w:32").first, + org::apache::arrow::flatbuf::Type::FixedSizeBinary + ); // FIXED_WIDTH_BINARY + CHECK_THROWS(static_cast(get_flatbuffer_type(builder, "w:"))); // Invalid FixedSizeBinary format + } + + SUBCASE("Unsupported type returns Null") + { + CHECK_EQ( + get_flatbuffer_type(builder, "unsupported_format").first, + org::apache::arrow::flatbuf::Type::Null + ); + } + } + + TEST_CASE("get_record_batch_message_builder") + { + SUBCASE("Valid record batch with field nodes and buffers") + { + auto record_batch = create_test_record_batch(); + auto builder = get_record_batch_message_builder(record_batch); + CHECK_GT(builder.GetSize(), 0); + CHECK_NE(builder.GetBufferPointer(), nullptr); + } + } + } +} \ No newline at end of file diff --git a/tests/test_memory_output_streams.cpp b/tests/test_memory_output_streams.cpp new file mode 100644 index 0000000..bd43b45 --- /dev/null +++ b/tests/test_memory_output_streams.cpp @@ -0,0 +1,247 @@ +#include +#include +#include +#include +#include +#include + +#include + +#include "doctest/doctest.h" + +namespace sparrow_ipc +{ + TEST_SUITE("memory_output_stream") + { + TEST_CASE("basic construction") + { + SUBCASE("Construction with std::vector") + { + std::vector buffer; + memory_output_stream stream(buffer); + CHECK_EQ(stream.size(), 0); + } + + SUBCASE("Construction with non-empty buffer") + { + std::vector buffer = {1, 2, 3, 4, 5}; + memory_output_stream stream(buffer); + CHECK_EQ(stream.size(), 5); + } + } + + TEST_CASE("write operations") + { + SUBCASE("Write single byte span") + { + std::vector buffer; + memory_output_stream stream(buffer); + + uint8_t data[] = {42}; + std::span span(data, 1); + + stream.write(span); + + CHECK_EQ(stream.size(), 1); + CHECK_EQ(buffer.size(), 1); + CHECK_EQ(buffer[0], 42); + } + + SUBCASE("Write multiple bytes span") + { + std::vector buffer; + memory_output_stream stream(buffer); + + uint8_t data[] = {1, 2, 3, 4, 5}; + std::span span(data, 5); + + stream.write(span); 
+ + CHECK_EQ(stream.size(), 5); + CHECK_EQ(buffer.size(), 5); + for (size_t i = 0; i < 5; ++i) + { + CHECK_EQ(buffer[i], i + 1); + } + } + + SUBCASE("Write empty span") + { + std::vector buffer; + memory_output_stream stream(buffer); + + std::span empty_span; + + stream.write(empty_span); + + CHECK_EQ(stream.size(), 0); + CHECK_EQ(buffer.size(), 0); + } + + SUBCASE("Write single byte (convenience method)") + { + std::vector buffer; + memory_output_stream stream(buffer); + + uint8_t single_byte = 123; + stream.write(std::span{&single_byte, 1}); + + CHECK_EQ(stream.size(), 1); + CHECK_EQ(buffer.size(), 1); + CHECK_EQ(buffer[0], 123); + } + + SUBCASE("Write value multiple times") + { + std::vector buffer; + memory_output_stream stream(buffer); + + stream.write(static_cast(255), 3); + + + CHECK_EQ(stream.size(), 3); + CHECK_EQ(buffer.size(), 3); + CHECK_EQ(buffer[0], 255); + CHECK_EQ(buffer[1], 255); + CHECK_EQ(buffer[2], 255); + } + + SUBCASE("Write value zero times") + { + std::vector buffer; + memory_output_stream stream(buffer); + + stream.write(static_cast(42), 0); + + CHECK_EQ(stream.size(), 0); + CHECK_EQ(buffer.size(), 0); + } + } + + TEST_CASE("sequential writes") + { + std::vector buffer; + memory_output_stream stream(buffer); + + // First write + uint8_t data1[] = {10, 20, 30}; + std::span span1(data1, 3); + stream.write(span1); + + CHECK_EQ(stream.size(), 3); + + // Second write + uint8_t data2[] = {40, 50}; + std::span span2(data2, 2); + stream.write(span2); + + + CHECK_EQ(stream.size(), 5); + + // Third write with repeated value + stream.write(static_cast(60), 2); + + CHECK_EQ(stream.size(), 7); + + // Verify final buffer content + std::vector expected = {10, 20, 30, 40, 50, 60, 60}; + CHECK_EQ(buffer, expected); + } + + TEST_CASE("reserve functionality") + { + std::vector buffer; + memory_output_stream stream(buffer); + + // Reserve space + stream.reserve(100); + + // Buffer should have reserved capacity but size should remain 0 + 
CHECK_GE(buffer.capacity(), 100); + CHECK_EQ(stream.size(), 0); + CHECK_EQ(buffer.size(), 0); + + // Writing should work normally after reserve + uint8_t data[] = {1, 2, 3}; + std::span span(data, 3); + stream.write(span); + + CHECK_EQ(stream.size(), 3); + CHECK_EQ(buffer.size(), 3); + } + + + TEST_CASE("large data handling") + { + std::vector buffer; + memory_output_stream stream(buffer); + + // Write a large amount of data + const size_t large_size = 10000; + std::vector large_data(large_size); + std::iota(large_data.begin(), large_data.end(), 0); // Fill with 0, 1, 2, ... + + std::span span(large_data); + stream.write(span); + + CHECK_EQ(stream.size(), large_size); + CHECK_EQ(buffer.size(), large_size); + + // Verify data integrity + for (size_t i = 0; i < large_size; ++i) + { + CHECK_EQ(buffer[i], static_cast(i)); + } + } + + TEST_CASE("edge cases") + { + SUBCASE("Maximum value repeated writes") + { + std::vector buffer; + memory_output_stream stream(buffer); + + stream.write(std::numeric_limits::max(), 255); + + CHECK_EQ(stream.size(), 255); + for (size_t i = 0; i < 255; ++i) + { + CHECK_EQ(buffer[i], std::numeric_limits::max()); + } + } + + SUBCASE("Zero byte repeated writes") + { + std::vector buffer; + memory_output_stream stream(buffer); + + stream.write(static_cast(0), 100); + + CHECK_EQ(stream.size(), 100); + for (size_t i = 0; i < 100; ++i) + { + CHECK_EQ(buffer[i], 0); + } + } + } + + TEST_CASE("different container types") + { + SUBCASE("With pre-filled vector") + { + std::vector buffer = {100, 200}; + memory_output_stream stream(buffer); + + CHECK_EQ(stream.size(), 2); + + uint8_t data[] = {1, 2, 3}; + std::span span(data, 3); + stream.write(span); + + CHECK_EQ(stream.size(), 5); + std::vector expected = {100, 200, 1, 2, 3}; + CHECK_EQ(buffer, expected); + } + } + } +} \ No newline at end of file diff --git a/tests/test_serialize_utils.cpp b/tests/test_serialize_utils.cpp index 2997843..ea4011e 100644 --- a/tests/test_serialize_utils.cpp +++ 
b/tests/test_serialize_utils.cpp @@ -1,7 +1,11 @@ +#include + #include #include +#include "sparrow_ipc/any_output_stream.hpp" #include "sparrow_ipc/magic_values.hpp" +#include "sparrow_ipc/memory_output_stream.hpp" #include "sparrow_ipc/serialize_utils.hpp" #include "sparrow_ipc/utils.hpp" #include "sparrow_ipc_tests_helpers.hpp" @@ -10,355 +14,274 @@ namespace sparrow_ipc { namespace sp = sparrow; - TEST_CASE("create_metadata") + TEST_SUITE("serialize_utils") { - flatbuffers::FlatBufferBuilder builder; - - SUBCASE("No metadata (nullptr)") - { - auto schema = create_test_arrow_schema("i"); - auto metadata_offset = create_metadata(builder, schema); - CHECK_EQ(metadata_offset.o, 0); - } - - SUBCASE("With metadata - basic test") + TEST_CASE("serialize_schema_message") { - auto schema = create_test_arrow_schema_with_metadata("i"); - auto metadata_offset = create_metadata(builder, schema); - // For now just check that it doesn't crash - // TODO: Add proper metadata testing when sparrow metadata is properly handled - } - } - - TEST_CASE("create_field") - { - flatbuffers::FlatBufferBuilder builder; - - SUBCASE("Basic field creation") - { - auto schema = create_test_arrow_schema("i", "int_field", true); - auto field_offset = create_field(builder, schema); - CHECK_NE(field_offset.o, 0); - } - - SUBCASE("Field with null name") - { - auto schema = create_test_arrow_schema("i", nullptr, false); - auto field_offset = create_field(builder, schema); - CHECK_NE(field_offset.o, 0); + SUBCASE("Valid record batch") + { + std::vector serialized; + memory_output_stream stream(serialized); + auto record_batch = create_test_record_batch(); + any_output_stream astream(stream); + serialize_schema_message(record_batch, astream); + + CHECK_GT(serialized.size(), 0); + + // Check that it starts with continuation bytes + CHECK_EQ(serialized.size() >= continuation.size(), true); + for (size_t i = 0; i < continuation.size(); ++i) + { + CHECK_EQ(serialized[i], continuation[i]); + } + + // Check 
that the total size is aligned to 8 bytes + CHECK_EQ(serialized.size() % 8, 0); + } } - SUBCASE("Non-nullable field") + TEST_CASE("fill_body") { - auto schema = create_test_arrow_schema("i", "int_field", false); - auto field_offset = create_field(builder, schema); - CHECK_NE(field_offset.o, 0); + SUBCASE("Simple primitive array") + { + auto array = sp::primitive_array({1, 2, 3, 4, 5}); + auto proxy = sp::detail::array_access::get_arrow_proxy(array); + std::vector body; + sparrow_ipc::memory_output_stream stream(body); + sparrow_ipc::any_output_stream astream(stream); + fill_body(proxy, astream); + CHECK_GT(body.size(), 0); + // Body size should be aligned + CHECK_EQ(body.size() % 8, 0); + } } - } - TEST_CASE("create_children from ArrowSchema") - { - flatbuffers::FlatBufferBuilder builder; - - SUBCASE("No children") + TEST_CASE("generate_body") { - auto schema = create_test_arrow_schema("i"); - auto children_offset = create_children(builder, schema); - CHECK_EQ(children_offset.o, 0); + SUBCASE("Record batch with multiple columns") + { + auto record_batch = create_test_record_batch(); + std::vector serialized; + memory_output_stream stream(serialized); + any_output_stream astream(stream); + generate_body(record_batch, astream); + CHECK_GT(serialized.size(), 0); + CHECK_EQ(serialized.size() % 8, 0); + } } - SUBCASE("With children") + TEST_CASE("calculate_body_size") { - auto parent_schema = create_test_arrow_schema("+s"); - auto child1 = new ArrowSchema(create_test_arrow_schema("i", "child1")); - auto child2 = new ArrowSchema(create_test_arrow_schema("u", "child2")); - - ArrowSchema* children[] = {child1, child2}; - parent_schema.children = children; - parent_schema.n_children = 2; - - auto children_offset = create_children(builder, parent_schema); - CHECK_NE(children_offset.o, 0); - - // Clean up - delete child1; - delete child2; - } + SUBCASE("Single array") + { + auto array = sp::primitive_array({1, 2, 3, 4, 5}); + auto proxy = 
sp::detail::array_access::get_arrow_proxy(array); - SUBCASE("Null child pointer throws exception") - { - auto parent_schema = create_test_arrow_schema("+s"); - ArrowSchema* children[] = {nullptr}; - parent_schema.children = children; - parent_schema.n_children = 1; + auto size = calculate_body_size(proxy); + CHECK_GT(size, 0); + CHECK_EQ(size % 8, 0); + } - CHECK_THROWS_AS(create_children(builder, parent_schema), std::invalid_argument); + SUBCASE("Record batch") + { + auto record_batch = create_test_record_batch(); + auto size = calculate_body_size(record_batch); + CHECK_GT(size, 0); + CHECK_EQ(size % 8, 0); + std::vector serialized; + memory_output_stream stream(serialized); + any_output_stream astream(stream); + generate_body(record_batch, astream); + CHECK_EQ(size, static_cast(serialized.size())); + } } - } - TEST_CASE("create_children from record_batch columns") - { - flatbuffers::FlatBufferBuilder builder; - - SUBCASE("With valid record batch") + TEST_CASE("calculate_schema_message_size") { - auto record_batch = create_test_record_batch(); - auto children_offset = create_children(builder, record_batch.columns()); - CHECK_NE(children_offset.o, 0); - } + SUBCASE("Single column record batch") + { + auto array = sp::primitive_array({1, 2, 3, 4, 5}); + auto record_batch = sp::record_batch({{"column1", sp::array(std::move(array))}}); - SUBCASE("Empty record batch") - { - auto empty_batch = sp::record_batch({}); + const auto estimated_size = calculate_schema_message_size(record_batch); + CHECK_GT(estimated_size, 0); + CHECK_EQ(estimated_size % 8, 0); - auto children_offset = create_children(builder, empty_batch.columns()); - CHECK_EQ(children_offset.o, 0); - } - } + // Verify by actual serialization + std::vector serialized; + memory_output_stream stream(serialized); + any_output_stream astream(stream); + serialize_schema_message(record_batch, astream ); - TEST_CASE("get_schema_message_builder") - { - SUBCASE("Valid record batch") - { - auto record_batch = 
create_test_record_batch(); - auto builder = get_schema_message_builder(record_batch); + CHECK_EQ(estimated_size, serialized.size()); + } - CHECK_GT(builder.GetSize(), 0); - CHECK_NE(builder.GetBufferPointer(), nullptr); - } - } + SUBCASE("Multi-column record batch") + { + auto record_batch = create_test_record_batch(); - TEST_CASE("serialize_schema_message") - { - SUBCASE("Valid record batch") - { - auto record_batch = create_test_record_batch(); - auto serialized = serialize_schema_message(record_batch); + auto estimated_size = calculate_schema_message_size(record_batch); + CHECK_GT(estimated_size, 0); + CHECK_EQ(estimated_size % 8, 0); - CHECK_GT(serialized.size(), 0); + std::vector serialized; + memory_output_stream stream(serialized); + any_output_stream astream(stream); + serialize_schema_message(record_batch, astream); - // Check that it starts with continuation bytes - CHECK_EQ(serialized.size() >= continuation.size(), true); - for (size_t i = 0; i < continuation.size(); ++i) - { - CHECK_EQ(serialized[i], continuation[i]); + CHECK_EQ(estimated_size, serialized.size()); } - - // Check that the total size is aligned to 8 bytes - CHECK_EQ(serialized.size() % 8, 0); } - } - TEST_CASE("fill_fieldnodes") - { - SUBCASE("Single array without children") + TEST_CASE("calculate_record_batch_message_size") { - auto array = sp::primitive_array({1, 2, 3, 4, 5}); - auto proxy = sp::detail::array_access::get_arrow_proxy(array); - - std::vector nodes; - fill_fieldnodes(proxy, nodes); - - CHECK_EQ(nodes.size(), 1); - CHECK_EQ(nodes[0].length(), 5); - CHECK_EQ(nodes[0].null_count(), 0); - } + SUBCASE("Single column record batch") + { + auto array = sp::primitive_array({1, 2, 3, 4, 5}); + auto record_batch = sp::record_batch({{"column1", sp::array(std::move(array))}}); - SUBCASE("Array with null values") - { - // For now, just test with a simple array without explicit nulls - // Creating arrays with null values requires more complex sparrow setup - auto array = 
sp::primitive_array({1, 2, 3}); - auto proxy = sp::detail::array_access::get_arrow_proxy(array); + auto estimated_size = calculate_record_batch_message_size(record_batch); + CHECK_GT(estimated_size, 0); + CHECK_EQ(estimated_size % 8, 0); - std::vector nodes; - fill_fieldnodes(proxy, nodes); + std::vector serialized; + memory_output_stream stream(serialized); + any_output_stream astream(stream); + serialize_record_batch(record_batch, astream); - CHECK_EQ(nodes.size(), 1); - CHECK_EQ(nodes[0].length(), 3); - CHECK_EQ(nodes[0].null_count(), 0); - } - } + CHECK_EQ(estimated_size, serialized.size()); + } - TEST_CASE("create_fieldnodes") - { - SUBCASE("Record batch with multiple columns") - { - auto record_batch = create_test_record_batch(); - auto nodes = create_fieldnodes(record_batch); + SUBCASE("Multi-column record batch") + { + auto record_batch = create_test_record_batch(); - CHECK_EQ(nodes.size(), 2); // Two columns + auto estimated_size = calculate_record_batch_message_size(record_batch); + CHECK_GT(estimated_size, 0); + CHECK_EQ(estimated_size % 8, 0); - // Check the first column (integer array) - CHECK_EQ(nodes[0].length(), 5); - CHECK_EQ(nodes[0].null_count(), 0); + // Verify by actual serialization + std::vector serialized; + memory_output_stream stream(serialized); + any_output_stream astream(stream); + serialize_record_batch(record_batch, astream); - // Check the second column (string array) - CHECK_EQ(nodes[1].length(), 5); - CHECK_EQ(nodes[1].null_count(), 0); + CHECK_EQ(estimated_size, serialized.size()); + } } - } - TEST_CASE("fill_buffers") - { - SUBCASE("Simple primitive array") + TEST_CASE("calculate_total_serialized_size") { - auto array = sp::primitive_array({1, 2, 3, 4, 5}); - auto proxy = sp::detail::array_access::get_arrow_proxy(array); + SUBCASE("Single record batch") + { + auto record_batch = create_test_record_batch(); + std::vector batches = {record_batch}; - std::vector buffers; - int64_t offset = 0; - fill_buffers(proxy, buffers, offset); 
+ auto estimated_size = calculate_total_serialized_size(batches); + CHECK_GT(estimated_size, 0); - CHECK_GT(buffers.size(), 0); - CHECK_GT(offset, 0); + // Should equal schema size + record batch size + auto schema_size = calculate_schema_message_size(record_batch); + auto batch_size = calculate_record_batch_message_size(record_batch); + CHECK_EQ(estimated_size, schema_size + batch_size); + } - // Verify offsets are aligned - for (const auto& buffer : buffers) + SUBCASE("Multiple record batches") { - CHECK_EQ(buffer.offset() % 8, 0); + auto array1 = sp::primitive_array({1, 2, 3}); + auto array2 = sp::primitive_array({1.0, 2.0, 3.0}); + auto record_batch1 = sp::record_batch( + {{"col1", sp::array(std::move(array1))}, {"col2", sp::array(std::move(array2))}} + ); + + auto array3 = sp::primitive_array({4, 5, 6}); + auto array4 = sp::primitive_array({4.0, 5.0, 6.0}); + auto record_batch2 = sp::record_batch( + {{"col1", sp::array(std::move(array3))}, {"col2", sp::array(std::move(array4))}} + ); + + std::vector batches = {record_batch1, record_batch2}; + + auto estimated_size = calculate_total_serialized_size(batches); + CHECK_GT(estimated_size, 0); + + // Should equal schema size + sum of record batch sizes + auto schema_size = calculate_schema_message_size(batches[0]); + auto batch1_size = calculate_record_batch_message_size(batches[0]); + auto batch2_size = calculate_record_batch_message_size(batches[1]); + CHECK_EQ(estimated_size, schema_size + batch1_size + batch2_size); } - } - } - TEST_CASE("get_buffers") - { - SUBCASE("Record batch with multiple columns") - { - auto record_batch = create_test_record_batch(); - auto buffers = get_buffers(record_batch); - CHECK_GT(buffers.size(), 0); - // Verify all offsets are properly calculated and aligned - for (size_t i = 1; i < buffers.size(); ++i) + SUBCASE("Empty collection") { - CHECK_GE(buffers[i].offset(), buffers[i - 1].offset() + buffers[i - 1].length()); + std::vector empty_batches; + auto estimated_size = 
calculate_total_serialized_size(empty_batches); + CHECK_EQ(estimated_size, 0); } - } - } - - TEST_CASE("fill_body") - { - SUBCASE("Simple primitive array") - { - auto array = sp::primitive_array({1, 2, 3, 4, 5}); - auto proxy = sp::detail::array_access::get_arrow_proxy(array); - std::vector body; - fill_body(proxy, body); - CHECK_GT(body.size(), 0); - // Body size should be aligned - CHECK_EQ(body.size() % 8, 0); - } - } - - TEST_CASE("generate_body") - { - SUBCASE("Record batch with multiple columns") - { - auto record_batch = create_test_record_batch(); - auto body = generate_body(record_batch); - CHECK_GT(body.size(), 0); - CHECK_EQ(body.size() % 8, 0); - } - } - - TEST_CASE("calculate_body_size") - { - SUBCASE("Single array") - { - auto array = sp::primitive_array({1, 2, 3, 4, 5}); - auto proxy = sp::detail::array_access::get_arrow_proxy(array); + SUBCASE("Inconsistent schemas throw exception") + { + auto array1 = sp::primitive_array({1, 2, 3}); + auto record_batch1 = sp::record_batch({{"col1", sp::array(std::move(array1))}}); - auto size = calculate_body_size(proxy); - CHECK_GT(size, 0); - CHECK_EQ(size % 8, 0); - } + auto array2 = sp::primitive_array({1.0, 2.0, 3.0}); + auto record_batch2 = sp::record_batch( + {{"col2", sp::array(std::move(array2))}} // Different column name + ); - SUBCASE("Record batch") - { - auto record_batch = create_test_record_batch(); - auto size = calculate_body_size(record_batch); - CHECK_GT(size, 0); - CHECK_EQ(size % 8, 0); - auto body = generate_body(record_batch); - CHECK_EQ(size, static_cast(body.size())); - } - } + std::vector batches = {record_batch1, record_batch2}; - TEST_CASE("get_record_batch_message_builder") - { - SUBCASE("Valid record batch with field nodes and buffers") - { - auto record_batch = create_test_record_batch(); - auto nodes = create_fieldnodes(record_batch); - auto buffers = get_buffers(record_batch); - auto builder = get_record_batch_message_builder(record_batch, nodes, buffers); - 
CHECK_GT(builder.GetSize(), 0); - CHECK_NE(builder.GetBufferPointer(), nullptr); + CHECK_THROWS_AS(auto size = calculate_total_serialized_size(batches), std::invalid_argument); + } } - } - TEST_CASE("serialize_record_batch") - { - SUBCASE("Valid record batch") + TEST_CASE("serialize_record_batch") { - auto record_batch = create_test_record_batch(); - auto serialized = serialize_record_batch(record_batch); - CHECK_GT(serialized.size(), 0); - - // Check that it starts with continuation bytes - CHECK_GE(serialized.size(), continuation.size()); - for (size_t i = 0; i < continuation.size(); ++i) + SUBCASE("Valid record batch") { - CHECK_EQ(serialized[i], continuation[i]); + auto record_batch = create_test_record_batch(); + std::vector serialized; + memory_output_stream stream(serialized); + any_output_stream astream(stream); + serialize_record_batch(record_batch, astream); + CHECK_GT(serialized.size(), 0); + + // Check that it starts with continuation bytes + CHECK_GE(serialized.size(), continuation.size()); + for (size_t i = 0; i < continuation.size(); ++i) + { + CHECK_EQ(serialized[i], continuation[i]); + } + + // Check that the metadata part is aligned to 8 bytes + // Find the end of metadata (before body starts) + size_t continuation_size = continuation.size(); + size_t length_prefix_size = sizeof(uint32_t); + + CHECK_GT(serialized.size(), continuation_size + length_prefix_size); + + // Extract message length + uint32_t message_length; + std::memcpy(&message_length, serialized.data() + continuation_size, sizeof(uint32_t)); + + size_t metadata_end = continuation_size + length_prefix_size + message_length; + size_t aligned_metadata_end = utils::align_to_8(static_cast(metadata_end)); + + // Verify alignment + CHECK_EQ(aligned_metadata_end % 8, 0); + CHECK_LE(aligned_metadata_end, serialized.size()); } - // Check that the metadata part is aligned to 8 bytes - // Find the end of metadata (before body starts) - size_t continuation_size = continuation.size(); - size_t 
length_prefix_size = sizeof(uint32_t); - - CHECK_GT(serialized.size(), continuation_size + length_prefix_size); - - // Extract message length - uint32_t message_length; - std::memcpy(&message_length, serialized.data() + continuation_size, sizeof(uint32_t)); - - size_t metadata_end = continuation_size + length_prefix_size + message_length; - size_t aligned_metadata_end = utils::align_to_8(static_cast(metadata_end)); - - // Verify alignment - CHECK_EQ(aligned_metadata_end % 8, 0); - CHECK_LE(aligned_metadata_end, serialized.size()); - } - - SUBCASE("Empty record batch") - { - auto empty_batch = sp::record_batch({}); - auto serialized = serialize_record_batch(empty_batch); - CHECK_GT(serialized.size(), 0); - CHECK_GE(serialized.size(), continuation.size()); - } - } - - TEST_CASE("Integration test - schema and record batch serialization") - { - SUBCASE("Serialize schema and record batch for same data") - { - auto record_batch = create_test_record_batch(); - - auto schema_serialized = serialize_schema_message(record_batch); - auto record_batch_serialized = serialize_record_batch(record_batch); - - CHECK_GT(schema_serialized.size(), 0); - CHECK_GT(record_batch_serialized.size(), 0); - - // Both should start with continuation bytes - CHECK_GE(schema_serialized.size(), continuation.size()); - CHECK_GE(record_batch_serialized.size(), continuation.size()); - - // Both should be properly aligned - CHECK_EQ(schema_serialized.size() % 8, 0); + SUBCASE("Empty record batch") + { + auto empty_batch = sp::record_batch({}); + std::vector serialized; + memory_output_stream stream(serialized); + any_output_stream astream(stream); + serialize_record_batch(empty_batch, astream); + CHECK_GT(serialized.size(), 0); + CHECK_GE(serialized.size(), continuation.size()); + } } } } \ No newline at end of file diff --git a/tests/test_serializer.cpp b/tests/test_serializer.cpp new file mode 100644 index 0000000..c35bcaa --- /dev/null +++ b/tests/test_serializer.cpp @@ -0,0 +1,581 @@ +#include 
+#include + +#include +#include + +#include "sparrow_ipc/memory_output_stream.hpp" +#include "sparrow_ipc/serializer.hpp" +#include "sparrow_ipc_tests_helpers.hpp" + +namespace sparrow_ipc +{ + namespace sp = sparrow; + + // Stream wrapper types for testing + struct memory_stream_wrapper + { + using buffer_type = std::vector; + buffer_type buffer; + memory_output_stream stream{buffer}; + + auto& get_stream() { return stream; } + size_t size() const { return buffer.size(); } + }; + + struct ostringstream_wrapper + { + std::ostringstream oss; + + auto& get_stream() { return oss; } + size_t size() { return static_cast(oss.tellp()); } + }; + + TEST_SUITE("serializer") + { + TEST_CASE_TEMPLATE("construction and write single record batch", StreamWrapper, memory_stream_wrapper, ostringstream_wrapper) + { + SUBCASE("Valid record batch") + { + auto rb = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser.write(rb); + + // After writing first record batch, should have schema + record batch + CHECK_GT(wrapper.size(), 0); + } + + SUBCASE("Empty record batch") + { + auto empty_batch = sp::record_batch({}); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser.write(empty_batch); + + CHECK_GT(wrapper.size(), 0); + } + } + + TEST_CASE_TEMPLATE("construction and write range of record batches", StreamWrapper, memory_stream_wrapper, ostringstream_wrapper) + { + SUBCASE("Valid record batches") + { + auto array1 = sp::primitive_array({1, 2, 3}); + auto array2 = sp::primitive_array({1.0, 2.0, 3.0}); + auto rb1 = sp::record_batch( + {{"col1", sp::array(std::move(array1))}, {"col2", sp::array(std::move(array2))}} + ); + + auto array3 = sp::primitive_array({4, 5, 6}); + auto array4 = sp::primitive_array({4.0, 5.0, 6.0}); + auto rb2 = sp::record_batch( + {{"col1", sp::array(std::move(array3))}, {"col2", sp::array(std::move(array4))}} + ); + + std::vector record_batches = {rb1, rb2}; + StreamWrapper wrapper; + serializer 
ser(wrapper.get_stream()); + ser.write(record_batches); + + // Should have schema + 2 record batches + CHECK_GT(wrapper.size(), 0); + } + + SUBCASE("Reserve is called correctly") + { + auto rb = create_test_record_batch(); + std::vector record_batches = {rb}; + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser.write(record_batches); + + // Verify that buffer has been written + CHECK_GT(wrapper.size(), 0); + } + } + + TEST_CASE_TEMPLATE("write single record batch", StreamWrapper, memory_stream_wrapper, ostringstream_wrapper) + { + SUBCASE("Write after construction with single batch") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser.write(rb1); + size_t size_after_construction = wrapper.size(); + + // Create compatible record batch + auto rb2 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({6, 7, 8}))}, + {"string_col", sp::array(sp::string_array(std::vector{"foo", "bar", "baz"}))}} + ); + + ser.write(rb2); + + CHECK_GT(wrapper.size(), size_after_construction); + } + + SUBCASE("Multiple writes") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser.write(rb1); + size_t initial_size = wrapper.size(); + + for (int i = 0; i < 3; ++i) + { + auto rb = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({i}))}, + {"string_col", sp::array(sp::string_array(std::vector{"test"}))}} + ); + ser.write(rb); + } + + CHECK_GT(wrapper.size(), initial_size); + } + + SUBCASE("Mismatched schema throws exception") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser.write(rb1); + + // Create record batch with different schema + auto rb2 = sp::record_batch( + {{"different_col", sp::array(sp::primitive_array({1, 2, 3}))}} + ); + + CHECK_THROWS_AS(ser.write(rb2), std::invalid_argument); + } + } + + TEST_CASE_TEMPLATE("write range of record batches", 
StreamWrapper, memory_stream_wrapper, ostringstream_wrapper) + { + SUBCASE("Write range after construction") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser.write(rb1); + size_t initial_size = wrapper.size(); + + auto array1 = sp::primitive_array({10, 20}); + auto array2 = sp::string_array(std::vector{"a", "b"}); + auto rb2 = sp::record_batch( + {{"int_col", sp::array(std::move(array1))}, + {"string_col", sp::array(std::move(array2))}} + ); + + auto array3 = sp::primitive_array({30, 40}); + auto array4 = sp::string_array(std::vector{"c", "d"}); + auto rb3 = sp::record_batch( + {{"int_col", sp::array(std::move(array3))}, + {"string_col", sp::array(std::move(array4))}} + ); + + std::vector new_batches = {rb2, rb3}; + ser.write(new_batches); + + CHECK_GT(wrapper.size(), initial_size); + } + + SUBCASE("Reserve is called during range write") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser.write(rb1); + + auto rb2 = create_test_record_batch(); + auto rb3 = create_test_record_batch(); + std::vector new_batches = {rb2, rb3}; + + size_t size_before = wrapper.size(); + ser.write(new_batches); + + // Reserve should have been called, buffer should have grown + CHECK_GT(wrapper.size(), size_before); + } + + SUBCASE("Empty range write does nothing") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser.write(rb1); + size_t initial_size = wrapper.size(); + + std::vector empty_batches; + ser.write(empty_batches); + + CHECK_EQ(wrapper.size(), initial_size); + } + + SUBCASE("Mismatched schema in range throws exception") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser.write(rb1); + + auto rb2 = create_test_record_batch(); + auto rb3 = sp::record_batch( + {{"different_col", sp::array(sp::primitive_array({1, 2, 3}))}} + ); + 
+ std::vector new_batches = {rb2, rb3}; + CHECK_THROWS_AS(ser.write(new_batches), std::invalid_argument); + } + } + + TEST_CASE_TEMPLATE("end serialization", StreamWrapper, memory_stream_wrapper, ostringstream_wrapper) + { + SUBCASE("End after construction") + { + auto rb = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser.write(rb); + size_t initial_size = wrapper.size(); + + ser.end(); + + // End should add end-of-stream marker + CHECK_GT(wrapper.size(), initial_size); + } + + SUBCASE("Cannot write after end") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser.write(rb1); + ser.end(); + + auto rb2 = create_test_record_batch(); + CHECK_THROWS_AS(ser.write(rb2), std::runtime_error); + } + + SUBCASE("Cannot write range after end") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser.write(rb1); + ser.end(); + + std::vector new_batches = {create_test_record_batch()}; + CHECK_THROWS_AS(ser.write(new_batches), std::runtime_error); + } + } + + TEST_CASE_TEMPLATE("stream size tracking", StreamWrapper, memory_stream_wrapper, ostringstream_wrapper) + { + SUBCASE("Size increases with each operation") + { + auto rb = create_test_record_batch(); + StreamWrapper wrapper; + size_t size_before = wrapper.size(); + serializer ser(wrapper.get_stream()); + ser.write(rb); + size_t size_after_construction = wrapper.size(); + + CHECK_GT(size_after_construction, size_before); + + ser.write(rb); + size_t size_after_write = wrapper.size(); + + CHECK_GT(size_after_write, size_after_construction); + } + } + + TEST_CASE_TEMPLATE("large number of record batches", StreamWrapper, memory_stream_wrapper, ostringstream_wrapper) + { + SUBCASE("Handle many record batches efficiently") + { + StreamWrapper wrapper; + std::vector batches; + const int num_batches = 100; + + for (int i = 0; i < num_batches; ++i) + { + auto array = 
sp::primitive_array({i, i+1, i+2}); + batches.push_back(sp::record_batch({{"col", sp::array(std::move(array))}})); + } + + serializer ser(wrapper.get_stream()); + ser.write(batches); + + // Should have schema + all batches + CHECK_GT(wrapper.size(), 0); + } + } + + TEST_CASE_TEMPLATE("different column types", StreamWrapper, memory_stream_wrapper, ostringstream_wrapper) + { + SUBCASE("Multiple primitive types") + { + auto int_array = sp::primitive_array({1, 2, 3}); + auto double_array = sp::primitive_array({1.5, 2.5, 3.5}); + auto float_array = sp::primitive_array({1.0f, 2.0f, 3.0f}); + + auto rb = sp::record_batch( + {{"int_col", sp::array(std::move(int_array))}, + {"double_col", sp::array(std::move(double_array))}, + {"float_col", sp::array(std::move(float_array))}} + ); + + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser.write(rb); + + CHECK_GT(wrapper.size(), 0); + } + } + + TEST_CASE_TEMPLATE("operator<< with single record batch", StreamWrapper, memory_stream_wrapper, ostringstream_wrapper) + { + SUBCASE("Single batch write using <<") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser << rb1; + size_t size_after_construction = wrapper.size(); + + auto rb2 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({6, 7, 8}))}, + {"string_col", sp::array(sp::string_array(std::vector{"foo", "bar", "baz"}))}} + ); + + ser << rb2; + + CHECK_GT(wrapper.size(), size_after_construction); + } + + SUBCASE("Chaining multiple single batch writes") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser << rb1; + size_t initial_size = wrapper.size(); + + auto rb2 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({10, 20}))}, + {"string_col", sp::array(sp::string_array(std::vector{"a", "b"}))}} + ); + + auto rb3 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({30, 40}))}, + {"string_col", 
sp::array(sp::string_array(std::vector{"c", "d"}))}} + ); + + auto rb4 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({50, 60}))}, + {"string_col", sp::array(sp::string_array(std::vector{"e", "f"}))}} + ); + + ser << rb2 << rb3 << rb4; + + CHECK_GT(wrapper.size(), initial_size); + } + + SUBCASE("Cannot use << after end") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser << rb1; + ser.end(); + + auto rb2 = create_test_record_batch(); + CHECK_THROWS_AS(ser << rb2, std::runtime_error); + } + + SUBCASE("Mismatched schema with << throws exception") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser << rb1; + + auto rb2 = sp::record_batch( + {{"different_col", sp::array(sp::primitive_array({1, 2, 3}))}} + ); + + CHECK_THROWS_AS(ser << rb2, std::invalid_argument); + } + } + + TEST_CASE_TEMPLATE("operator<< with range of record batches", StreamWrapper, memory_stream_wrapper, ostringstream_wrapper) + { + SUBCASE("Range write using <<") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser << rb1; + size_t initial_size = wrapper.size(); + + std::vector batches; + for (int i = 0; i < 3; ++i) + { + auto rb = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({i * 10}))}, + {"string_col", sp::array(sp::string_array(std::vector{"test"}))}} + ); + batches.push_back(rb); + } + + ser << batches; + + CHECK_GT(wrapper.size(), initial_size); + } + + SUBCASE("Chaining range and single batch writes") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser << rb1; + size_t initial_size = wrapper.size(); + + std::vector batches; + for (int i = 0; i < 2; ++i) + { + auto rb = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({i}))}, + {"string_col", sp::array(sp::string_array(std::vector{"x"}))}} + 
); + batches.push_back(rb); + } + + auto rb2 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({99}))}, + {"string_col", sp::array(sp::string_array(std::vector{"final"}))}} + ); + + ser << batches << rb2; + + CHECK_GT(wrapper.size(), initial_size); + } + + SUBCASE("Mixed chaining with multiple ranges") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser << rb1; + size_t initial_size = wrapper.size(); + + std::vector batches1; + batches1.push_back(sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({10}))}, + {"string_col", sp::array(sp::string_array(std::vector{"a"}))}} + )); + + std::vector batches2; + batches2.push_back(sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({20}))}, + {"string_col", sp::array(sp::string_array(std::vector{"b"}))}} + )); + + auto rb2 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({30}))}, + {"string_col", sp::array(sp::string_array(std::vector{"c"}))}} + ); + + ser << batches1 << rb2 << batches2; + + CHECK_GT(wrapper.size(), initial_size); + } + + SUBCASE("Cannot use << with range after end") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser << rb1; + ser.end(); + + std::vector batches = {create_test_record_batch()}; + CHECK_THROWS_AS(ser << batches, std::runtime_error); + } + + SUBCASE("Mismatched schema in range with << throws exception") + { + auto rb1 = create_test_record_batch(); + StreamWrapper wrapper; + serializer ser(wrapper.get_stream()); + ser << rb1; + + auto rb2 = create_test_record_batch(); + auto rb3 = sp::record_batch( + {{"different_col", sp::array(sp::primitive_array({1, 2, 3}))}} + ); + + std::vector batches = {rb2, rb3}; + CHECK_THROWS_AS(ser << batches, std::invalid_argument); + } + } + + TEST_CASE_TEMPLATE("workflow example with << operator", StreamWrapper, memory_stream_wrapper, ostringstream_wrapper) + { + SUBCASE("Typical usage 
pattern with streaming syntax") + { + // Create initial record batch + auto rb1 = create_test_record_batch(); + + // Setup stream + StreamWrapper wrapper; + + // Create serializer and write initial batch + serializer ser(wrapper.get_stream()); + ser << rb1; + size_t size_after_init = wrapper.size(); + CHECK_GT(size_after_init, 0); + + // Stream more batches + auto rb2 = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({10, 20}))}, + {"string_col", sp::array(sp::string_array(std::vector{"x", "y"}))}} + ); + + ser << rb2; + size_t size_after_rb2 = wrapper.size(); + CHECK_GT(size_after_rb2, size_after_init); + + // Stream range of batches + std::vector more_batches; + for (int i = 0; i < 3; ++i) + { + auto rb = sp::record_batch( + {{"int_col", sp::array(sp::primitive_array({i}))}, + {"string_col", sp::array(sp::string_array(std::vector{"test"}))}} + ); + more_batches.push_back(rb); + } + + ser << more_batches; + size_t size_after_range = wrapper.size(); + CHECK_GT(size_after_range, size_after_rb2); + + // Mix single and range in one chain + auto rb3 = create_test_record_batch(); + std::vector final_batches = {create_test_record_batch()}; + + ser << rb3 << final_batches; + size_t size_after_chain = wrapper.size(); + CHECK_GT(size_after_chain, size_after_range); + + // End serialization + ser.end(); + CHECK_GT(wrapper.size(), size_after_chain); + } + } + } +} diff --git a/tests/test_utils.cpp b/tests/test_utils.cpp index ab9f4a0..0619d68 100644 --- a/tests/test_utils.cpp +++ b/tests/test_utils.cpp @@ -1,9 +1,8 @@ #include -#include -#include "sparrow_ipc/arrow_interface/arrow_array_schema_common_release.hpp" #include "sparrow_ipc/utils.hpp" + namespace sparrow_ipc { TEST_CASE("align_to_8") @@ -16,334 +15,4 @@ namespace sparrow_ipc CHECK_EQ(utils::align_to_8(15), 16); CHECK_EQ(utils::align_to_8(16), 16); } - - TEST_CASE("get_flatbuffer_type") - { - flatbuffers::FlatBufferBuilder builder; - SUBCASE("Null and Boolean types") - { - CHECK_EQ( - 
utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::NA)).first, - org::apache::arrow::flatbuf::Type::Null - ); - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::BOOL)).first, - org::apache::arrow::flatbuf::Type::Bool - ); - } - - SUBCASE("Integer types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT8)).first, - org::apache::arrow::flatbuf::Type::Int - ); // INT8 - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT8)).first, - org::apache::arrow::flatbuf::Type::Int - ); // UINT8 - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT16)).first, - org::apache::arrow::flatbuf::Type::Int - ); // INT16 - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT16)).first, - org::apache::arrow::flatbuf::Type::Int - ); // UINT16 - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT32)).first, - org::apache::arrow::flatbuf::Type::Int - ); // INT32 - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT32)).first, - org::apache::arrow::flatbuf::Type::Int - ); // UINT32 - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INT64)).first, - org::apache::arrow::flatbuf::Type::Int - ); // INT64 - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::UINT64)).first, - org::apache::arrow::flatbuf::Type::Int - ); // UINT64 - } - - SUBCASE("Floating Point types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::HALF_FLOAT)) - .first, - org::apache::arrow::flatbuf::Type::FloatingPoint - ); // HALF_FLOAT - CHECK_EQ( - utils::get_flatbuffer_type(builder, 
sparrow::data_type_to_format(sparrow::data_type::FLOAT)).first, - org::apache::arrow::flatbuf::Type::FloatingPoint - ); // FLOAT - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::DOUBLE)).first, - org::apache::arrow::flatbuf::Type::FloatingPoint - ); // DOUBLE - } - - SUBCASE("String and Binary types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::STRING)).first, - org::apache::arrow::flatbuf::Type::Utf8 - ); // STRING - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::LARGE_STRING)) - .first, - org::apache::arrow::flatbuf::Type::LargeUtf8 - ); // LARGE_STRING - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::BINARY)).first, - org::apache::arrow::flatbuf::Type::Binary - ); // BINARY - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::LARGE_BINARY)) - .first, - org::apache::arrow::flatbuf::Type::LargeBinary - ); // LARGE_BINARY - CHECK_EQ( - utils::get_flatbuffer_type(builder, "vu").first, - org::apache::arrow::flatbuf::Type::Utf8View - ); // STRING_VIEW - CHECK_EQ( - utils::get_flatbuffer_type(builder, "vz").first, - org::apache::arrow::flatbuf::Type::BinaryView - ); // BINARY_VIEW - } - - SUBCASE("Date types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::DATE_DAYS)) - .first, - org::apache::arrow::flatbuf::Type::Date - ); // DATE_DAYS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::DATE_MILLISECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Date - ); // DATE_MILLISECONDS - } - - SUBCASE("Timestamp types") - { - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_SECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Timestamp - ); // 
TIMESTAMP_SECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_MILLISECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Timestamp - ); // TIMESTAMP_MILLISECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_MICROSECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Timestamp - ); // TIMESTAMP_MICROSECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::TIMESTAMP_NANOSECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Timestamp - ); // TIMESTAMP_NANOSECONDS - } - - SUBCASE("Duration types") - { - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::DURATION_SECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Duration - ); // DURATION_SECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::DURATION_MILLISECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Duration - ); // DURATION_MILLISECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::DURATION_MICROSECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Duration - ); // DURATION_MICROSECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::DURATION_NANOSECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Duration - ); // DURATION_NANOSECONDS - } - - SUBCASE("Interval types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::INTERVAL_MONTHS)) - .first, - org::apache::arrow::flatbuf::Type::Interval - ); // INTERVAL_MONTHS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::INTERVAL_DAYS_TIME) - ) - .first, - org::apache::arrow::flatbuf::Type::Interval - ); // 
INTERVAL_DAYS_TIME - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::INTERVAL_MONTHS_DAYS_NANOSECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Interval - ); // INTERVAL_MONTHS_DAYS_NANOSECONDS - } - - SUBCASE("Time types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::TIME_SECONDS)) - .first, - org::apache::arrow::flatbuf::Type::Time - ); // TIME_SECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::TIME_MILLISECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Time - ); // TIME_MILLISECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::TIME_MICROSECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Time - ); // TIME_MICROSECONDS - CHECK_EQ( - utils::get_flatbuffer_type( - builder, - sparrow::data_type_to_format(sparrow::data_type::TIME_NANOSECONDS) - ) - .first, - org::apache::arrow::flatbuf::Type::Time - ); // TIME_NANOSECONDS - } - - SUBCASE("List types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::LIST)).first, - org::apache::arrow::flatbuf::Type::List - ); // LIST - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::LARGE_LIST)) - .first, - org::apache::arrow::flatbuf::Type::LargeList - ); // LARGE_LIST - CHECK_EQ( - utils::get_flatbuffer_type(builder, "+vl").first, - org::apache::arrow::flatbuf::Type::ListView - ); // LIST_VIEW - CHECK_EQ( - utils::get_flatbuffer_type(builder, "+vL").first, - org::apache::arrow::flatbuf::Type::LargeListView - ); // LARGE_LIST_VIEW - CHECK_EQ( - utils::get_flatbuffer_type(builder, "+w:16").first, - org::apache::arrow::flatbuf::Type::FixedSizeList - ); // FIXED_SIZED_LIST - CHECK_THROWS(utils::get_flatbuffer_type(builder, "+w:")); // Invalid FixedSizeList format - } - - 
SUBCASE("Struct and Map types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::STRUCT)).first, - org::apache::arrow::flatbuf::Type::Struct_ - ); // STRUCT - CHECK_EQ( - utils::get_flatbuffer_type(builder, sparrow::data_type_to_format(sparrow::data_type::MAP)).first, - org::apache::arrow::flatbuf::Type::Map - ); // MAP - } - - SUBCASE("Union types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, "+ud:").first, - org::apache::arrow::flatbuf::Type::Union - ); // DENSE_UNION - CHECK_EQ( - utils::get_flatbuffer_type(builder, "+us:").first, - org::apache::arrow::flatbuf::Type::Union - ); // SPARSE_UNION - } - - SUBCASE("Run-End Encoded type") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, "+r").first, - org::apache::arrow::flatbuf::Type::RunEndEncoded - ); // RUN_ENCODED - } - - SUBCASE("Decimal types") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, "d:10,5").first, - org::apache::arrow::flatbuf::Type::Decimal - ); // DECIMAL (general) - CHECK_THROWS(utils::get_flatbuffer_type(builder, "d:10")); // Invalid Decimal format - } - - SUBCASE("Fixed Width Binary type") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, "w:32").first, - org::apache::arrow::flatbuf::Type::FixedSizeBinary - ); // FIXED_WIDTH_BINARY - CHECK_THROWS(utils::get_flatbuffer_type(builder, "w:")); // Invalid FixedSizeBinary format - } - - SUBCASE("Unsupported type returns Null") - { - CHECK_EQ( - utils::get_flatbuffer_type(builder, "unsupported_format").first, - org::apache::arrow::flatbuf::Type::Null - ); - } - } }