Skip to content

Commit c3dbe23

Browse files
committed
Refactor and add some tests
1 parent 8c4e318 commit c3dbe23

File tree

4 files changed

+156
-71
lines changed

4 files changed

+156
-71
lines changed

include/sparrow_ipc/flatbuffer_utils.hpp

Lines changed: 69 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <sparrow/record_batch.hpp>
77

88
#include "sparrow_ipc/compression.hpp"
9+
#include "sparrow_ipc/utils.hpp"
910

1011
namespace sparrow_ipc
1112
{
@@ -166,6 +167,42 @@ namespace sparrow_ipc
166167
[[nodiscard]] std::vector<org::apache::arrow::flatbuf::FieldNode>
167168
create_fieldnodes(const sparrow::record_batch& record_batch);
168169

170+
namespace details
171+
{
172+
template <typename Func>
173+
void fill_buffers_impl(
174+
const sparrow::arrow_proxy& arrow_proxy,
175+
std::vector<org::apache::arrow::flatbuf::Buffer>& flatbuf_buffers,
176+
int64_t& offset,
177+
Func&& get_buffer_size
178+
)
179+
{
180+
const auto& buffers = arrow_proxy.buffers();
181+
for (const auto& buffer : buffers)
182+
{
183+
int64_t size = get_buffer_size(buffer);
184+
flatbuf_buffers.emplace_back(offset, size);
185+
offset += utils::align_to_8(size);
186+
}
187+
for (const auto& child : arrow_proxy.children())
188+
{
189+
fill_buffers_impl(child, flatbuf_buffers, offset, get_buffer_size);
190+
}
191+
}
192+
193+
template <typename Func>
194+
std::vector<org::apache::arrow::flatbuf::Buffer> get_buffers_impl(const sparrow::record_batch& record_batch, Func&& fill_buffers_func)
195+
{
196+
std::vector<org::apache::arrow::flatbuf::Buffer> buffers;
197+
int64_t offset = 0;
198+
for (const auto& column : record_batch.columns())
199+
{
200+
const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column);
201+
fill_buffers_func(arrow_proxy, buffers, offset);
202+
}
203+
return buffers;
204+
}
205+
} // namespace details
169206

170207
/**
171208
* @brief Recursively fills a vector of FlatBuffer Buffer objects with buffer information from an Arrow
@@ -208,20 +245,39 @@ namespace sparrow_ipc
208245
get_buffers(const sparrow::record_batch& record_batch);
209246

210247
/**
211-
* @brief Generates the compressed message body and buffer metadata for a record batch.
248+
* @brief Recursively populates a vector with compressed buffer metadata from an Arrow proxy.
212249
*
213-
* This function traverses the record batch, compresses each buffer using the specified
214-
* compression algorithm, and constructs the message body. For each compressed buffer,
215-
* it is prefixed by its 8-byte uncompressed size. Padding is added after each
216-
* compressed buffer to ensure 8-byte alignment.
250+
* This function traverses the Arrow proxy and its children, compressing each buffer and recording
251+
* its metadata (offset and size) in the provided vector. The offset is updated to ensure proper
252+
* alignment for each subsequent buffer.
217253
*
218-
* @param record_batch The record batch to serialize.
219-
* @param compression_type The compression algorithm to use (e.g., LZ4_FRAME, ZSTD).
220-
* @return A vector of FlatBuffer Buffer objects describing the offset and
221-
* size of each buffer within the compressed body.
254+
* @param arrow_proxy The Arrow proxy containing the buffers to be compressed.
255+
* @param flatbuf_compressed_buffers A vector to store the resulting compressed buffer metadata.
256+
* @param offset The current offset in the buffer layout, which will be updated by the function.
257+
* @param compression_type The compression algorithm to use.
222258
*/
223-
[[nodiscard]] SPARROW_IPC_API std::vector<org::apache::arrow::flatbuf::Buffer>
224-
generate_compressed_buffers(const sparrow::record_batch& record_batch, const CompressionType compression_type);
259+
void fill_compressed_buffers(
260+
const sparrow::arrow_proxy& arrow_proxy,
261+
std::vector<org::apache::arrow::flatbuf::Buffer>& flatbuf_compressed_buffers,
262+
int64_t& offset,
263+
const CompressionType compression_type
264+
);
265+
266+
/**
267+
* @brief Retrieves metadata describing the layout of compressed buffers within a record batch.
268+
*
269+
* This function processes a record batch to determine the metadata (offset and size)
270+
* for each of its buffers as compressed with the specified algorithm; each buffer is actually compressed in order to measure its compressed size.
271+
* This metadata accounts for each compressed buffer being prefixed by its 8-byte
272+
* uncompressed size and padded to ensure 8-byte alignment.
273+
*
274+
* @param record_batch The record batch whose buffers' compressed metadata is to be retrieved.
275+
* @param compression_type The compression algorithm that would be applied (e.g., LZ4_FRAME, ZSTD).
276+
* @return A vector of FlatBuffer Buffer objects, each describing the offset and
277+
* size of a corresponding compressed buffer within a larger message body.
278+
*/
279+
[[nodiscard]] std::vector<org::apache::arrow::flatbuf::Buffer>
280+
get_compressed_buffers(const sparrow::record_batch& record_batch, const CompressionType compression_type);
225281

226282
/**
227283
* @brief Calculates the total size of the body section for an Arrow array.
@@ -234,7 +290,7 @@ namespace sparrow_ipc
234290
* @param compression The compression type to use when serializing
235291
* @return int64_t The total aligned size in bytes of all buffers in the array hierarchy
236292
*/
237-
[[nodiscard]] SPARROW_IPC_API int64_t calculate_body_size(const sparrow::arrow_proxy& arrow_proxy, std::optional<CompressionType> compression = std::nullopt);
293+
[[nodiscard]] int64_t calculate_body_size(const sparrow::arrow_proxy& arrow_proxy, std::optional<CompressionType> compression = std::nullopt);
238294

239295
/**
240296
* @brief Calculates the total body size of a record batch by summing the body sizes of all its columns.
@@ -247,7 +303,7 @@ namespace sparrow_ipc
247303
* @param compression The compression type to use when serializing
248304
* @return int64_t The total body size in bytes of all columns in the record batch
249305
*/
250-
[[nodiscard]] SPARROW_IPC_API int64_t calculate_body_size(const sparrow::record_batch& record_batch, std::optional<CompressionType> compression = std::nullopt);
306+
[[nodiscard]] int64_t calculate_body_size(const sparrow::record_batch& record_batch, std::optional<CompressionType> compression = std::nullopt);
251307

252308
/**
253309
* @brief Creates a FlatBuffer message containing a serialized Apache Arrow RecordBatch.

src/flatbuffer_utils.cpp

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
#include "compression_impl.hpp"
44
#include "sparrow_ipc/flatbuffer_utils.hpp"
5-
#include "sparrow_ipc/utils.hpp"
65

76
namespace sparrow_ipc
87
{
@@ -537,49 +536,38 @@ namespace sparrow_ipc
537536
int64_t& offset
538537
)
539538
{
540-
const auto& buffers = arrow_proxy.buffers();
541-
for (const auto& buffer : buffers)
542-
{
543-
int64_t size = static_cast<int64_t>(buffer.size());
544-
flatbuf_buffers.emplace_back(offset, size);
545-
offset += utils::align_to_8(size);
546-
}
547-
for (const auto& child : arrow_proxy.children())
548-
{
549-
fill_buffers(child, flatbuf_buffers, offset);
550-
}
539+
details::fill_buffers_impl(arrow_proxy, flatbuf_buffers, offset, [](const auto& buffer) {
540+
return static_cast<int64_t>(buffer.size());
541+
});
551542
}
552543

553544
std::vector<org::apache::arrow::flatbuf::Buffer> get_buffers(const sparrow::record_batch& record_batch)
554545
{
555-
std::vector<org::apache::arrow::flatbuf::Buffer> buffers;
556-
std::int64_t offset = 0;
557-
for (const auto& column : record_batch.columns())
558-
{
559-
const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column);
560-
fill_buffers(arrow_proxy, buffers, offset);
561-
}
562-
return buffers;
546+
return details::get_buffers_impl(record_batch, [](const sparrow::arrow_proxy& proxy, std::vector<org::apache::arrow::flatbuf::Buffer>& buffers, int64_t& offset) {
547+
fill_buffers(proxy, buffers, offset);
548+
});
563549
}
564550

565-
std::vector<org::apache::arrow::flatbuf::Buffer>
566-
generate_compressed_buffers(const sparrow::record_batch& record_batch, const CompressionType compression_type)
551+
void fill_compressed_buffers(
552+
const sparrow::arrow_proxy& arrow_proxy,
553+
std::vector<org::apache::arrow::flatbuf::Buffer>& flatbuf_compressed_buffers,
554+
int64_t& offset,
555+
const CompressionType compression_type
556+
)
567557
{
568-
std::vector<org::apache::arrow::flatbuf::Buffer> compressed_buffers;
569-
int64_t current_offset = 0;
558+
details::fill_buffers_impl(
559+
arrow_proxy, flatbuf_compressed_buffers, offset, [&](const auto& buffer) {
560+
return compress(compression_type, std::span<const uint8_t>(buffer.data(), buffer.size()))
561+
.size();
562+
});
563+
}
570564

571-
for (const auto& column : record_batch.columns())
572-
{
573-
const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column);
574-
for (const auto& buffer : arrow_proxy.buffers())
575-
{
576-
std::vector<uint8_t> compressed_buffer_with_header = compress(compression_type, std::span<const uint8_t>(buffer.data(), buffer.size()));
577-
const size_t aligned_chunk_size = utils::align_to_8(compressed_buffer_with_header.size());
578-
compressed_buffers.emplace_back(current_offset, aligned_chunk_size);
579-
current_offset += aligned_chunk_size;
580-
}
581-
}
582-
return compressed_buffers;
565+
std::vector<org::apache::arrow::flatbuf::Buffer>
566+
get_compressed_buffers(const sparrow::record_batch& record_batch, const CompressionType compression_type)
567+
{
568+
return details::get_buffers_impl(record_batch, [&](const sparrow::arrow_proxy& proxy, std::vector<org::apache::arrow::flatbuf::Buffer>& buffers, int64_t& offset) {
569+
fill_compressed_buffers(proxy, buffers, offset, compression_type);
570+
});
583571
}
584572

585573
int64_t calculate_body_size(const sparrow::arrow_proxy& arrow_proxy, std::optional<CompressionType> compression)
@@ -628,7 +616,7 @@ namespace sparrow_ipc
628616
std::optional<std::vector<org::apache::arrow::flatbuf::Buffer>> compressed_buffers;
629617
if (compression)
630618
{
631-
compressed_buffers = generate_compressed_buffers(record_batch, compression.value());
619+
compressed_buffers = get_compressed_buffers(record_batch, compression.value());
632620
compression_offset = org::apache::arrow::flatbuf::CreateBodyCompression(record_batch_builder, details::to_fb_compression_type(compression.value()), org::apache::arrow::flatbuf::BodyCompressionMethod::BUFFER);
633621
}
634622
const auto& buffers = compressed_buffers ? *compressed_buffers : get_buffers(record_batch);

tests/test_flatbuffer_utils.cpp

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -178,40 +178,75 @@ namespace sparrow_ipc
178178
}
179179
}
180180

181+
void test_fill_buffers_variant(
182+
const std::function<void(const sparrow::arrow_proxy&, std::vector<org::apache::arrow::flatbuf::Buffer>&, int64_t&)>& fill_func)
183+
{
184+
auto array = sp::primitive_array<int32_t>({1, 2, 3, 4, 5});
185+
auto proxy = sp::detail::array_access::get_arrow_proxy(array);
186+
187+
std::vector<org::apache::arrow::flatbuf::Buffer> buffers;
188+
int64_t offset = 0;
189+
fill_func(proxy, buffers, offset);
190+
191+
CHECK_GT(buffers.size(), 0);
192+
CHECK_GT(offset, 0);
193+
194+
// Verify offsets are aligned
195+
for (const auto& buffer : buffers)
196+
{
197+
CHECK_EQ(buffer.offset() % 8, 0);
198+
}
199+
}
200+
181201
TEST_CASE("fill_buffers")
182202
{
183203
SUBCASE("Simple primitive array")
184204
{
185-
auto array = sp::primitive_array<int32_t>({1, 2, 3, 4, 5});
186-
auto proxy = sp::detail::array_access::get_arrow_proxy(array);
187-
188-
std::vector<org::apache::arrow::flatbuf::Buffer> buffers;
189-
int64_t offset = 0;
190-
fill_buffers(proxy, buffers, offset);
205+
test_fill_buffers_variant([](const sparrow::arrow_proxy& proxy, std::vector<org::apache::arrow::flatbuf::Buffer>& buffers, int64_t& offset) {
206+
fill_buffers(proxy, buffers, offset);
207+
});
208+
}
209+
}
191210

192-
CHECK_GT(buffers.size(), 0);
193-
CHECK_GT(offset, 0);
211+
TEST_CASE("fill_compressed_buffers")
212+
{
213+
SUBCASE("Simple primitive array")
214+
{
215+
test_fill_buffers_variant([](const sparrow::arrow_proxy& proxy, std::vector<org::apache::arrow::flatbuf::Buffer>& buffers, int64_t& offset) {
216+
fill_compressed_buffers(proxy, buffers, offset, CompressionType::LZ4_FRAME);
217+
});
218+
}
219+
}
194220

195-
// Verify offsets are aligned
196-
for (const auto& buffer : buffers)
197-
{
198-
CHECK_EQ(buffer.offset() % 8, 0);
199-
}
221+
void test_get_buffers_variant(const std::function<std::vector<org::apache::arrow::flatbuf::Buffer>(const sparrow::record_batch&)>& get_func)
222+
{
223+
auto record_batch = create_test_record_batch();
224+
auto buffers = get_func(record_batch);
225+
CHECK_GT(buffers.size(), 0);
226+
// Verify buffers are laid out in order without overlapping: each offset must be at or past the end of the previous buffer
227+
for (size_t i = 1; i < buffers.size(); ++i)
228+
{
229+
CHECK_GE(buffers[i].offset(), buffers[i - 1].offset() + buffers[i - 1].length());
200230
}
201231
}
202232

203233
TEST_CASE("get_buffers")
204234
{
205235
SUBCASE("Record batch with multiple columns")
206236
{
207-
auto record_batch = create_test_record_batch();
208-
auto buffers = get_buffers(record_batch);
209-
CHECK_GT(buffers.size(), 0);
210-
// Verify all offsets are properly calculated and aligned
211-
for (size_t i = 1; i < buffers.size(); ++i)
212-
{
213-
CHECK_GE(buffers[i].offset(), buffers[i - 1].offset() + buffers[i - 1].length());
214-
}
237+
test_get_buffers_variant([](const sparrow::record_batch& record_batch) {
238+
return get_buffers(record_batch);
239+
});
240+
}
241+
}
242+
243+
TEST_CASE("get_compressed_buffers")
244+
{
245+
SUBCASE("Record batch with multiple columns")
246+
{
247+
test_get_buffers_variant([](const sparrow::record_batch& record_batch) {
248+
return get_compressed_buffers(record_batch, CompressionType::LZ4_FRAME);
249+
});
215250
}
216251
}
217252

tests/test_serialize_utils.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,18 @@ namespace sparrow_ipc
6969
sparrow_ipc::memory_output_stream stream_compressed(body_compressed);
7070
sparrow_ipc::any_output_stream astream_compressed(stream_compressed);
7171
fill_body(proxy, astream_compressed, CompressionType::LZ4_FRAME);
72+
CHECK_GT(body_compressed.size(), 0);
73+
// Body size should be aligned
74+
CHECK_EQ(body_compressed.size() % 8, 0);
7275

7376
// Uncompressed
7477
std::vector<uint8_t> body_uncompressed;
7578
sparrow_ipc::memory_output_stream stream_uncompressed(body_uncompressed);
7679
sparrow_ipc::any_output_stream astream_uncompressed(stream_uncompressed);
7780
fill_body(proxy, astream_uncompressed, std::nullopt);
81+
CHECK_GT(body_uncompressed.size(), 0);
82+
// Body size should be aligned
83+
CHECK_EQ(body_uncompressed.size() % 8, 0);
7884
// Check that compressed size is smaller than uncompressed size
7985
CHECK_LT(body_compressed.size(), body_uncompressed.size());
8086
}

0 commit comments

Comments
 (0)