Skip to content

Commit 46291cb

Browse files
committed
Simplify
1 parent 273fc24 commit 46291cb

File tree

6 files changed

+86
-67
lines changed

6 files changed

+86
-67
lines changed

include/sparrow_ipc/deserialize_primitive_array.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ namespace sparrow_ipc
3838
const auto compression = record_batch.compression();
3939
std::vector<std::vector<std::uint8_t>> decompressed_buffers;
4040

41+
// TODO do not decompress validity buffers?
4142
auto validity_buffer_span = utils::get_and_decompress_buffer(record_batch, body, buffer_index, compression, decompressed_buffers);
4243

4344
const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count(validity_buffer_span, record_batch.length());

include/sparrow_ipc/serialize_utils.hpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,11 @@ namespace sparrow_ipc
128128
*
129129
* @param record_batch The record batch to serialize.
130130
* @param compression_type The compression algorithm to use (e.g., LZ4_FRAME, ZSTD).
131-
* @return A std::pair containing:
132-
* - first: A vector of bytes representing the complete compressed message body.
133-
* - second: A vector of FlatBuffer Buffer objects describing the offset and
134-
* size of each buffer within the compressed body.
131+
* @return A vector of FlatBuffer Buffer objects describing the offset and
132+
* size of each buffer within the compressed body.
135133
*/
136-
[[nodiscard]] SPARROW_IPC_API std::pair<std::vector<uint8_t>, std::vector<org::apache::arrow::flatbuf::Buffer>>
137-
generate_compressed_body_and_buffers(const sparrow::record_batch& record_batch, const org::apache::arrow::flatbuf::CompressionType compression_type);
134+
[[nodiscard]] SPARROW_IPC_API std::vector<org::apache::arrow::flatbuf::Buffer>
135+
generate_compressed_buffers(const sparrow::record_batch& record_batch, const org::apache::arrow::flatbuf::CompressionType compression_type);
138136

139137
/**
140138
* @brief Fills the body vector with serialized data from an arrow proxy and its children.
@@ -150,8 +148,9 @@ namespace sparrow_ipc
150148
*
151149
* @param arrow_proxy The arrow proxy containing buffers and potential child proxies to serialize
152150
* @param stream The output stream where the serialized body data will be written
151+
* @param compression The compression type to use when serializing
153152
*/
154-
SPARROW_IPC_API void fill_body(const sparrow::arrow_proxy& arrow_proxy, any_output_stream& stream);
153+
SPARROW_IPC_API void fill_body(const sparrow::arrow_proxy& arrow_proxy, any_output_stream& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression = std::nullopt);
155154

156155
/**
157156
* @brief Generates a serialized body from a record batch.
@@ -162,8 +161,9 @@ namespace sparrow_ipc
162161
*
163162
* @param record_batch The record batch containing columns to be serialized
164163
* @param stream The output stream where the serialized body will be written
164+
* @param compression The compression type to use when serializing
165165
*/
166-
SPARROW_IPC_API void generate_body(const sparrow::record_batch& record_batch, any_output_stream& stream);
166+
SPARROW_IPC_API void generate_body(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression = std::nullopt);
167167

168168
/**
169169
* @brief Calculates the total size of the body section for an Arrow array.
@@ -173,9 +173,10 @@ namespace sparrow_ipc
173173
* buffer size is aligned to 8-byte boundaries as required by the Arrow format.
174174
*
175175
* @param arrow_proxy The Arrow array proxy containing buffers and child arrays
176+
* @param compression The compression type to use when serializing
176177
* @return int64_t The total aligned size in bytes of all buffers in the array hierarchy
177178
*/
178-
[[nodiscard]] SPARROW_IPC_API int64_t calculate_body_size(const sparrow::arrow_proxy& arrow_proxy);
179+
[[nodiscard]] SPARROW_IPC_API int64_t calculate_body_size(const sparrow::arrow_proxy& arrow_proxy, std::optional<org::apache::arrow::flatbuf::CompressionType> compression = std::nullopt);
179180

180181
/**
181182
* @brief Calculates the total body size of a record batch by summing the body sizes of all its columns.
@@ -185,9 +186,10 @@ namespace sparrow_ipc
185186
* the total memory required for the serialized data content of the record batch.
186187
*
187188
* @param record_batch The sparrow record batch containing columns to calculate size for
189+
* @param compression The compression type to use when serializing
188190
* @return int64_t The total body size in bytes of all columns in the record batch
189191
*/
190-
[[nodiscard]] SPARROW_IPC_API int64_t calculate_body_size(const sparrow::record_batch& record_batch);
192+
[[nodiscard]] SPARROW_IPC_API int64_t calculate_body_size(const sparrow::record_batch& record_batch, std::optional<org::apache::arrow::flatbuf::CompressionType> compression = std::nullopt);
191193

192194
SPARROW_IPC_API std::vector<sparrow::data_type> get_column_dtypes(const sparrow::record_batch& rb);
193195
}

src/flatbuffer_utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@ namespace sparrow_ipc
583583
0 // TODO :variadic buffer Counts
584584
);
585585

586-
const int64_t body_size = body_size_override.value_or(calculate_body_size(record_batch));
586+
const int64_t body_size = body_size_override.value_or(calculate_body_size(record_batch, compression));
587587
const auto record_batch_message_offset = org::apache::arrow::flatbuf::CreateMessage(
588588
record_batch_builder,
589589
org::apache::arrow::flatbuf::MetadataVersion::V5,

src/serialize.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,14 @@ namespace sparrow_ipc
2828
if (compression.has_value())
2929
{
3030
// TODO Handle this inside get_record_batch_message_builder
31-
auto [compressed_body, compressed_buffers] = generate_compressed_body_and_buffers(record_batch, compression.value());
32-
common_serialize(get_record_batch_message_builder(record_batch, compression, compressed_body.size(), &compressed_buffers), stream);
33-
// TODO Use something equivalent to generate_body (stream wise, handling children etc)
34-
stream.write(std::span(compressed_body.data(), compressed_body.size()));
31+
auto compressed_buffers = generate_compressed_buffers(record_batch, compression.value());
32+
auto body_size_override = calculate_body_size(record_batch, compression);
33+
common_serialize(get_record_batch_message_builder(record_batch, compression, body_size_override, &compressed_buffers), stream);
3534
}
3635
else
3736
{
38-
common_serialize(get_record_batch_message_builder(record_batch, compression), stream);
39-
generate_body(record_batch, stream);
37+
common_serialize(get_record_batch_message_builder(record_batch, compression, std::nullopt, nullptr), stream);
4038
}
39+
generate_body(record_batch, stream, compression);
4140
}
4241
}

src/serialize_utils.cpp

Lines changed: 42 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,52 +7,69 @@
77

88
namespace sparrow_ipc
99
{
10-
void fill_body(const sparrow::arrow_proxy& arrow_proxy, any_output_stream& stream)
10+
void fill_body(const sparrow::arrow_proxy& arrow_proxy, any_output_stream& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression)
1111
{
12-
for (const auto& buffer : arrow_proxy.buffers())
13-
{
14-
stream.write(buffer);
12+
std::for_each(arrow_proxy.buffers().begin(), arrow_proxy.buffers().end(), [&](const auto& buffer) {
13+
if (compression.has_value())
14+
{
15+
auto compressed_buffer_with_header = compress(compression.value(), std::span<const uint8_t>(buffer.data(), buffer.size()));
16+
stream.write(std::span(compressed_buffer_with_header.data(), compressed_buffer_with_header.size()));
17+
}
18+
else
19+
{
20+
stream.write(buffer);
21+
}
1522
stream.add_padding();
16-
}
17-
for (const auto& child : arrow_proxy.children())
18-
{
19-
fill_body(child, stream);
20-
}
23+
});
24+
25+
std::for_each(arrow_proxy.children().begin(), arrow_proxy.children().end(), [&](const auto& child) {
26+
fill_body(child, stream, compression);
27+
});
2128
}
2229

23-
void generate_body(const sparrow::record_batch& record_batch, any_output_stream& stream)
30+
void generate_body(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression)
2431
{
25-
for (const auto& column : record_batch.columns())
26-
{
32+
std::for_each(record_batch.columns().begin(), record_batch.columns().end(), [&](const auto& column) {
2733
const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column);
28-
fill_body(arrow_proxy, stream);
29-
}
34+
fill_body(arrow_proxy, stream, compression);
35+
});
3036
}
3137

32-
int64_t calculate_body_size(const sparrow::arrow_proxy& arrow_proxy)
38+
int64_t calculate_body_size(const sparrow::arrow_proxy& arrow_proxy, std::optional<org::apache::arrow::flatbuf::CompressionType> compression)
3339
{
3440
int64_t total_size = 0;
35-
for (const auto& buffer : arrow_proxy.buffers())
41+
if (compression.has_value())
42+
{
43+
for (const auto& buffer : arrow_proxy.buffers())
44+
{
45+
total_size += utils::align_to_8(compress(compression.value(), std::span<const uint8_t>(buffer.data(), buffer.size())).size());
46+
}
47+
}
48+
else
3649
{
37-
total_size += utils::align_to_8(buffer.size());
50+
for (const auto& buffer : arrow_proxy.buffers())
51+
{
52+
total_size += utils::align_to_8(buffer.size());
53+
}
3854
}
55+
3956
for (const auto& child : arrow_proxy.children())
4057
{
41-
total_size += calculate_body_size(child);
58+
total_size += calculate_body_size(child, compression);
4259
}
4360
return total_size;
4461
}
4562

46-
int64_t calculate_body_size(const sparrow::record_batch& record_batch)
63+
int64_t calculate_body_size(const sparrow::record_batch& record_batch, std::optional<org::apache::arrow::flatbuf::CompressionType> compression)
4764
{
4865
return std::accumulate(
4966
record_batch.columns().begin(),
5067
record_batch.columns().end(),
5168
int64_t{0},
52-
[](int64_t acc, const sparrow::array& arr)
69+
[&](int64_t acc, const sparrow::array& arr)
5370
{
5471
const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(arr);
55-
return acc + calculate_body_size(arrow_proxy);
72+
return acc + calculate_body_size(arrow_proxy, compression);
5673
}
5774
);
5875
}
@@ -78,18 +95,7 @@ namespace sparrow_ipc
7895
flatbuffers::FlatBufferBuilder record_batch_builder = get_record_batch_message_builder(record_batch, compression);
7996
const flatbuffers::uoffset_t record_batch_len = record_batch_builder.GetSize();
8097

81-
std::size_t actual_body_size = 0;
82-
if (compression.has_value())
83-
{
84-
// If compressed, the body size is the sum of compressed buffer sizes + original size prefixes + padding
85-
auto [compressed_body, compressed_buffers] = generate_compressed_body_and_buffers(record_batch, compression.value());
86-
actual_body_size = compressed_body.size();
87-
}
88-
else
89-
{
90-
// If not compressed, the body size is the sum of uncompressed buffer sizes with padding
91-
actual_body_size = static_cast<std::size_t>(calculate_body_size(record_batch));
92-
}
98+
const std::size_t actual_body_size = static_cast<std::size_t>(calculate_body_size(record_batch, compression));
9399

94100
// Calculate total size:
95101
// - Continuation bytes (4)
@@ -103,10 +109,9 @@ namespace sparrow_ipc
103109
return metadata_size + actual_body_size;
104110
}
105111

106-
std::pair<std::vector<uint8_t>, std::vector<org::apache::arrow::flatbuf::Buffer>>
107-
generate_compressed_body_and_buffers(const sparrow::record_batch& record_batch, const org::apache::arrow::flatbuf::CompressionType compression_type)
112+
std::vector<org::apache::arrow::flatbuf::Buffer>
113+
generate_compressed_buffers(const sparrow::record_batch& record_batch, const org::apache::arrow::flatbuf::CompressionType compression_type)
108114
{
109-
std::vector<uint8_t> compressed_body;
110115
std::vector<org::apache::arrow::flatbuf::Buffer> compressed_buffers;
111116
int64_t current_offset = 0;
112117

@@ -115,24 +120,13 @@ namespace sparrow_ipc
115120
const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column);
116121
for (const auto& buffer : arrow_proxy.buffers())
117122
{
118-
// Compress the buffer. The returned buffer already has the correct size header.
119123
std::vector<uint8_t> compressed_buffer_with_header = compress(compression_type, std::span<const uint8_t>(buffer.data(), buffer.size()));
120-
121124
const size_t aligned_chunk_size = utils::align_to_8(compressed_buffer_with_header.size());
122-
const size_t padding_needed = aligned_chunk_size - compressed_buffer_with_header.size();
123-
124-
// Write compressed data with header
125-
compressed_body.insert(compressed_body.end(), compressed_buffer_with_header.begin(), compressed_buffer_with_header.end());
126-
127-
// Add padding
128-
compressed_body.insert(compressed_body.end(), padding_needed, 0);
129-
130-
// Update compressed buffer metadata
131125
compressed_buffers.emplace_back(current_offset, aligned_chunk_size);
132126
current_offset += aligned_chunk_size;
133127
}
134128
}
135-
return {compressed_body, compressed_buffers};
129+
return compressed_buffers;
136130
}
137131

138132
std::vector<sparrow::data_type> get_column_dtypes(const sparrow::record_batch& rb)

tests/test_serialize_utils.cpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,43 @@ namespace sparrow_ipc
4040
}
4141
}
4242

43+
// TODO after the used fcts are stable regarding compression, add tests for fcts having it as an additional argument
44+
// cf. fill_body example
4345
TEST_CASE("fill_body")
4446
{
45-
SUBCASE("Simple primitive array")
47+
SUBCASE("Simple primitive array (uncompressed)")
4648
{
4749
auto array = sp::primitive_array<int32_t>({1, 2, 3, 4, 5});
4850
auto proxy = sp::detail::array_access::get_arrow_proxy(array);
4951
std::vector<uint8_t> body;
5052
sparrow_ipc::memory_output_stream stream(body);
5153
sparrow_ipc::any_output_stream astream(stream);
52-
fill_body(proxy, astream);
54+
fill_body(proxy, astream, std::nullopt);
5355
CHECK_GT(body.size(), 0);
5456
// Body size should be aligned
5557
CHECK_EQ(body.size() % 8, 0);
5658
}
59+
60+
SUBCASE("Simple primitive array (compressible)")
61+
{
62+
std::vector<int32_t> data(1000, 12345); // Repeating values, should be very compressible
63+
auto array = sp::primitive_array<int32_t>(data);
64+
auto proxy = sp::detail::array_access::get_arrow_proxy(array);
65+
66+
// Compressed
67+
std::vector<uint8_t> body_compressed;
68+
sparrow_ipc::memory_output_stream stream_compressed(body_compressed);
69+
sparrow_ipc::any_output_stream astream_compressed(stream_compressed);
70+
fill_body(proxy, astream_compressed, org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME);
71+
72+
// Uncompressed
73+
std::vector<uint8_t> body_uncompressed;
74+
sparrow_ipc::memory_output_stream stream_uncompressed(body_uncompressed);
75+
sparrow_ipc::any_output_stream astream_uncompressed(stream_uncompressed);
76+
fill_body(proxy, astream_uncompressed, std::nullopt);
77+
// Check that compressed size is smaller than uncompressed size
78+
CHECK_LT(body_compressed.size(), body_uncompressed.size());
79+
}
5780
}
5881

5982
TEST_CASE("generate_body")

0 commit comments

Comments
 (0)