Skip to content

Commit a778fff

Browse files
committed
Rework compression and add tests
1 parent 0d949aa commit a778fff

File tree

5 files changed

+164
-77
lines changed

5 files changed

+164
-77
lines changed

include/sparrow_ipc/compression.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <cstdint>
44
#include <span>
5+
#include <variant>
56
#include <vector>
67

78
#include "Message_generated.h"
@@ -20,6 +21,8 @@ namespace sparrow_ipc
2021

2122
// CompressionType to_compression_type(org::apache::arrow::flatbuf::CompressionType compression_type);
2223

23-
SPARROW_IPC_API std::vector<std::uint8_t> compress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const std::uint8_t> data);
24-
SPARROW_IPC_API std::vector<std::uint8_t> decompress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const std::uint8_t> data);
24+
constexpr auto CompressionHeaderSize = sizeof(std::int64_t);
25+
26+
[[nodiscard]] SPARROW_IPC_API std::vector<std::uint8_t> compress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const std::uint8_t> data);
27+
[[nodiscard]] SPARROW_IPC_API std::variant<std::vector<std::uint8_t>, std::span<const std::uint8_t>> decompress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const std::uint8_t> data);
2528
}

src/compression.cpp

Lines changed: 85 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -20,90 +20,135 @@ namespace sparrow_ipc
2020
// }
2121
// }
2222

23-
std::vector<std::uint8_t> lz4_compress(std::span<const std::uint8_t> data)
23+
namespace
2424
{
25-
const std::int64_t uncompressed_size = data.size();
26-
const size_t max_compressed_size = LZ4F_compressFrameBound(uncompressed_size, nullptr);
27-
std::vector<std::uint8_t> compressed_data(max_compressed_size);
28-
const size_t compressed_size = LZ4F_compressFrame(compressed_data.data(), max_compressed_size, data.data(), uncompressed_size, nullptr);
29-
if (LZ4F_isError(compressed_size))
25+
std::vector<std::uint8_t> lz4_compress(std::span<const std::uint8_t> data)
3026
{
31-
throw std::runtime_error("Failed to compress data with LZ4 frame format");
27+
const std::int64_t uncompressed_size = data.size();
28+
const size_t max_compressed_size = LZ4F_compressFrameBound(uncompressed_size, nullptr);
29+
std::vector<std::uint8_t> compressed_data(max_compressed_size);
30+
const size_t compressed_size = LZ4F_compressFrame(compressed_data.data(), max_compressed_size, data.data(), uncompressed_size, nullptr);
31+
if (LZ4F_isError(compressed_size))
32+
{
33+
throw std::runtime_error("Failed to compress data with LZ4 frame format");
34+
}
35+
compressed_data.resize(compressed_size);
36+
return compressed_data;
3237
}
33-
compressed_data.resize(compressed_size);
34-
return compressed_data;
35-
}
3638

37-
std::vector<std::uint8_t> lz4_decompress(std::span<const std::uint8_t> data)
38-
{
39-
if (data.size() < 8)
39+
std::vector<std::uint8_t> lz4_decompress(std::span<const std::uint8_t> data, const std::int64_t decompressed_size)
4040
{
41-
throw std::runtime_error("Invalid compressed data: missing decompressed size");
41+
std::vector<std::uint8_t> decompressed_data(decompressed_size);
42+
LZ4F_dctx* dctx = nullptr;
43+
LZ4F_createDecompressionContext(&dctx, LZ4F_VERSION);
44+
size_t compressed_size_in_out = data.size();
45+
size_t decompressed_size_in_out = decompressed_size;
46+
size_t result = LZ4F_decompress(dctx, decompressed_data.data(), &decompressed_size_in_out, data.data(), &compressed_size_in_out, nullptr);
47+
if (LZ4F_isError(result) || (decompressed_size_in_out != (size_t)decompressed_size))
48+
{
49+
throw std::runtime_error("Failed to decompress data with LZ4 frame format");
50+
}
51+
LZ4F_freeDecompressionContext(dctx);
52+
return decompressed_data;
4253
}
43-
const std::int64_t decompressed_size = *reinterpret_cast<const std::int64_t*>(data.data());
44-
const auto compressed_data = data.subspan(8);
4554

46-
if (decompressed_size == -1)
55+
// TODO These functions could be moved to serialize_utils and deserialize_utils if preferred
56+
// as they are handling the header size
57+
std::vector<std::uint8_t> uncompressed_data_with_header(std::span<const std::uint8_t> data)
4758
{
48-
// TODO think of avoiding copy here
49-
return {compressed_data.begin(), compressed_data.end()};
59+
std::vector<std::uint8_t> result;
60+
result.reserve(CompressionHeaderSize + data.size());
61+
const std::int64_t header = -1;
62+
result.insert(result.end(), reinterpret_cast<const uint8_t*>(&header), reinterpret_cast<const uint8_t*>(&header) + sizeof(header));
63+
result.insert(result.end(), data.begin(), data.end());
64+
return result;
5065
}
5166

52-
std::vector<std::uint8_t> decompressed_data(decompressed_size);
53-
LZ4F_dctx* dctx = nullptr;
54-
LZ4F_createDecompressionContext(&dctx, LZ4F_VERSION);
55-
size_t compressed_size_in_out = compressed_data.size();
56-
size_t decompressed_size_in_out = decompressed_size;
57-
size_t result = LZ4F_decompress(dctx, decompressed_data.data(), &decompressed_size_in_out, compressed_data.data(), &compressed_size_in_out, nullptr);
58-
if (LZ4F_isError(result) || (decompressed_size_in_out != (size_t)decompressed_size))
67+
std::vector<std::uint8_t> lz4_compress_with_header(std::span<const std::uint8_t> data)
5968
{
60-
throw std::runtime_error("Failed to decompress data with LZ4 frame format");
69+
const std::int64_t original_size = data.size();
70+
auto compressed_body = lz4_compress(data);
71+
72+
if (compressed_body.size() >= static_cast<size_t>(original_size))
73+
{
74+
return uncompressed_data_with_header(data);
75+
}
76+
77+
std::vector<std::uint8_t> result;
78+
result.reserve(CompressionHeaderSize + compressed_body.size());
79+
result.insert(result.end(), reinterpret_cast<const uint8_t*>(&original_size), reinterpret_cast<const uint8_t*>(&original_size) + sizeof(original_size));
80+
result.insert(result.end(), compressed_body.begin(), compressed_body.end());
81+
return result;
82+
}
83+
84+
/// Decodes a buffer made of an 8-byte size header followed by a body.
///
/// A header of -1 means the body is stored uncompressed; in that case a view over the
/// body is returned and no copy is made, so the result is only valid while `data` is.
/// Any other non-negative header is the decompressed size of the LZ4 frame in the body.
///
/// @param data Header plus body.
/// @return The decompressed bytes (owning vector) or a span over the uncompressed body.
/// @throws std::runtime_error if `data` is shorter than the header, if the header holds
///         a negative size other than -1, or if LZ4 decompression fails.
std::variant<std::vector<std::uint8_t>, std::span<const std::uint8_t>> lz4_decompress_with_header(std::span<const std::uint8_t> data)
{
    if (data.size() < CompressionHeaderSize)
    {
        throw std::runtime_error("Invalid compressed data: missing decompressed size");
    }

    // Copy the header byte by byte instead of dereferencing a casted int64_t pointer:
    // the input is a plain byte buffer with no alignment guarantee.
    static_assert(CompressionHeaderSize == sizeof(std::int64_t), "header must hold one int64");
    std::int64_t decompressed_size = 0;
    auto* size_bytes = reinterpret_cast<std::uint8_t*>(&decompressed_size);
    for (std::size_t i = 0; i < sizeof(decompressed_size); ++i)
    {
        size_bytes[i] = data[i];
    }

    const auto body = data.subspan(CompressionHeaderSize);

    if (decompressed_size == -1)
    {
        // Body was stored as-is; hand back a view rather than copying.
        return body;
    }

    // Reject corrupted headers before they reach the output-buffer allocation.
    if (decompressed_size < 0)
    {
        throw std::runtime_error("Invalid compressed data: negative decompressed size");
    }

    return lz4_decompress(body, decompressed_size);
}
100+
101+
/// Returns a view over everything after the 8-byte size header, without copying.
///
/// @param data Header plus body.
/// @return Span over the body; valid only as long as `data` is.
/// @throws std::runtime_error if `data` is shorter than the header.
std::span<const uint8_t> get_body_from_uncompressed_data(std::span<const uint8_t> data)
{
    const bool has_header = data.size() >= CompressionHeaderSize;
    if (has_header)
    {
        return data.subspan(CompressionHeaderSize);
    }
    throw std::runtime_error("Invalid data: missing header");
}
62-
LZ4F_freeDecompressionContext(dctx);
63-
return decompressed_data;
64109
}
65110

66111
/// Encodes `data` for the given codec, always prefixing the result with an 8-byte size header.
///
/// - LZ4_FRAME: LZ4-compressed body, or the raw bytes with a -1 header when compression does not help.
/// - ZSTD: not supported.
/// - Any other value: raw bytes with a -1 header.
///
/// @param compression_type Codec to apply.
/// @param data             Bytes to encode (may be empty).
/// @return Header plus body.
/// @throws std::runtime_error for ZSTD, or if LZ4 compression fails.
std::vector<std::uint8_t> compress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const std::uint8_t> data)
{
    namespace fb = org::apache::arrow::flatbuf;

    if (compression_type == fb::CompressionType::LZ4_FRAME)
    {
        return lz4_compress_with_header(data);
    }
    if (compression_type == fb::CompressionType::ZSTD)
    {
        throw std::runtime_error("Compression using zstd is not supported yet.");
    }
    // Unknown / no codec: keep the bytes untouched but still emit the header.
    return uncompressed_data_with_header(data);
}
87127

88-
std::vector<std::uint8_t> decompress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const std::uint8_t> data)
128+
/// Decodes a header-prefixed buffer produced for the given codec.
///
/// The result is either an owning vector (data was actually decompressed, or the input
/// was empty) or a span viewing the body inside `data` (body was stored uncompressed).
/// A returned span is only valid while `data` is.
///
/// @param compression_type Codec the buffer was encoded with.
/// @param data             Header plus body, or an empty span.
/// @return Decoded bytes as described above.
/// @throws std::runtime_error for ZSTD, for a non-empty buffer shorter than the header,
///         or if LZ4 decompression fails.
std::variant<std::vector<std::uint8_t>, std::span<const std::uint8_t>> decompress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const std::uint8_t> data)
{
    namespace fb = org::apache::arrow::flatbuf;

    // An empty span is a valid representation of an empty buffer (e.g., the validity
    // bitmap of a column with no nulls) and decodes to an empty output, whatever the codec.
    // TODO if this function is no longer called on validity buffers, remove this special case
    if (data.empty())
    {
        return {};
    }

    if (compression_type == fb::CompressionType::LZ4_FRAME)
    {
        return lz4_decompress_with_header(data);
    }
    if (compression_type == fb::CompressionType::ZSTD)
    {
        throw std::runtime_error("Decompression using zstd is not supported yet.");
    }
    // Unknown / no codec: the body follows the header verbatim.
    return get_body_from_uncompressed_data(data);
}
109154
}

src/deserialize_utils.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,28 @@ namespace sparrow_ipc::utils
6161

6262
if (compression)
6363
{
64-
decompressed_storage.emplace_back(decompress(compression->codec(), buffer_span));
65-
buffer_span = decompressed_storage.back();
64+
auto decompressed_result = decompress(compression->codec(), buffer_span);
65+
return std::visit(
66+
[&decompressed_storage](auto&& arg) -> std::span<const uint8_t>
67+
{
68+
using T = std::decay_t<decltype(arg)>;
69+
if constexpr (std::is_same_v<T, std::vector<uint8_t>>)
70+
{
71+
// Decompression happened
72+
decompressed_storage.emplace_back(std::move(arg));
73+
return decompressed_storage.back();
74+
}
75+
else // It's a std::span<const uint8_t>
76+
{
77+
// No decompression, but we are in a compression context, so we must copy the buffer
78+
// to ensure the owning ArrowArray has access to all its data.
79+
// TODO think about other strategies
80+
decompressed_storage.emplace_back(arg.begin(), arg.end());
81+
return decompressed_storage.back();
82+
}
83+
},
84+
decompressed_result
85+
);
6686
}
6787
return buffer_span;
6888
}

src/serialize_utils.cpp

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ namespace sparrow_ipc
103103
return metadata_size + actual_body_size;
104104
}
105105

106-
[[nodiscard]] SPARROW_IPC_API std::pair<std::vector<uint8_t>, std::vector<org::apache::arrow::flatbuf::Buffer>>
106+
std::pair<std::vector<uint8_t>, std::vector<org::apache::arrow::flatbuf::Buffer>>
107107
generate_compressed_body_and_buffers(const sparrow::record_batch& record_batch, const org::apache::arrow::flatbuf::CompressionType compression_type)
108108
{
109109
std::vector<uint8_t> compressed_body;
@@ -115,23 +115,16 @@ namespace sparrow_ipc
115115
const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column);
116116
for (const auto& buffer : arrow_proxy.buffers())
117117
{
118-
// Original buffer size
119-
const int64_t original_size = static_cast<int64_t>(buffer.size());
118+
// Compress the buffer. The returned buffer already has the correct size header.
119+
std::vector<uint8_t> compressed_buffer_with_header = compress(compression_type, std::span<const uint8_t>(buffer.data(), buffer.size()));
120120

121-
// Compress the buffer
122-
std::vector<uint8_t> compressed_buffer_data = compress(compression_type, std::span<const uint8_t>(buffer.data(), original_size));
123-
const int64_t compressed_data_size = static_cast<int64_t>(compressed_buffer_data.size());
121+
const size_t aligned_chunk_size = utils::align_to_8(compressed_buffer_with_header.size());
122+
const size_t padding_needed = aligned_chunk_size - compressed_buffer_with_header.size();
124123

125-
// Calculate total size of this compressed chunk (original size prefix + compressed data + padding)
126-
const int64_t total_chunk_size = sizeof(int64_t) + compressed_data_size;
127-
const size_t aligned_chunk_size = utils::align_to_8(total_chunk_size);
128-
const size_t padding_needed = aligned_chunk_size - total_chunk_size;
124+
// Write compressed data with header
125+
compressed_body.insert(compressed_body.end(), compressed_buffer_with_header.begin(), compressed_buffer_with_header.end());
129126

130-
// Write original size (8 bytes) followed by compressed data
131-
compressed_body.insert(compressed_body.end(), reinterpret_cast<const uint8_t*>(&original_size), reinterpret_cast<const uint8_t*>(&original_size) + sizeof(int64_t));
132-
compressed_body.insert(compressed_body.end(), compressed_buffer_data.begin(), compressed_buffer_data.end());
133-
134-
// Add padding to the compressed data
127+
// Add padding
135128
compressed_body.insert(compressed_body.end(), padding_needed, 0);
136129

137130
// Update compressed buffer metadata

tests/test_compression.cpp

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ namespace sparrow_ipc
1616
const auto compression_type = org::apache::arrow::flatbuf::CompressionType::ZSTD;
1717

1818
// Test compression with ZSTD
19-
CHECK_THROWS_WITH_AS(sparrow_ipc::compress(compression_type, original_data), "Compression using zstd is not supported yet.", std::runtime_error);
19+
CHECK_THROWS_WITH_AS(compress(compression_type, original_data), "Compression using zstd is not supported yet.", std::runtime_error);
2020

2121
// Test decompression with ZSTD
22-
CHECK_THROWS_WITH_AS(sparrow_ipc::decompress(compression_type, original_data), "Decompression using zstd is not supported yet.", std::runtime_error);
22+
CHECK_THROWS_WITH_AS(decompress(compression_type, original_data), "Decompression using zstd is not supported yet.", std::runtime_error);
2323
}
2424

2525
TEST_CASE("Empty data")
@@ -29,37 +29,63 @@ namespace sparrow_ipc
2929

3030
// Test compression of empty data
3131
auto compressed = compress(compression_type, empty_data);
32-
CHECK(compressed.empty());
32+
CHECK_EQ(compressed.size(), CompressionHeaderSize);
33+
const std::int64_t header = *reinterpret_cast<const std::int64_t*>(compressed.data());
34+
CHECK_EQ(header, -1);
3335

3436
// Test decompression of empty data
35-
auto decompressed = decompress(compression_type, empty_data);
36-
CHECK(decompressed.empty());
37+
auto decompressed = decompress(compression_type, compressed);
38+
std::visit([](const auto& value) { CHECK(value.empty()); }, decompressed);
3739
}
3840

3941
// Compressing then decompressing with LZ4_FRAME must give back the original bytes,
// whichever alternative (owning vector or span) the decompression returns.
TEST_CASE("Data compression and decompression round-trip")
{
    const std::string text = "Hello world, this is a test of compression and decompression. But we need more words to make this compression worth it!";
    const std::vector<uint8_t> expected(text.begin(), text.end());
    const auto codec = org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME;

    // Compress data
    const std::vector<uint8_t> encoded = compress(codec, expected);

    // Decompress
    const auto decoded = decompress(codec, encoded);
    std::visit(
        [&expected](const auto& bytes)
        {
            CHECK_EQ(bytes.size(), expected.size());
            // Copy into a vector so both alternatives compare the same way.
            const std::vector<uint8_t> actual(bytes.begin(), bytes.end());
            CHECK_EQ(actual, expected);
        },
        decoded
    );
}
5162

52-
// Manually create the IPC-formatted compressed buffer by adding the 8-byte prefix
53-
std::vector<uint8_t> ipc_formatted_buffer;
54-
ipc_formatted_buffer.reserve(sizeof(int64_t) + compressed_frame.size());
55-
ipc_formatted_buffer.insert(ipc_formatted_buffer.end(), reinterpret_cast<const uint8_t*>(&original_size), reinterpret_cast<const uint8_t*>(&original_size) + sizeof(int64_t));
56-
ipc_formatted_buffer.insert(ipc_formatted_buffer.end(), compressed_frame.begin(), compressed_frame.end());
63+
// A payload too small for LZ4 to shrink must be stored verbatim behind a -1 header
// and still round-trip through decompress().
TEST_CASE("Data compression with incompressible data")
{
    const std::string text = "abc";
    const std::vector<uint8_t> expected(text.begin(), text.end());
    const auto codec = org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME;

    // Compress data
    const std::vector<uint8_t> encoded = compress(codec, expected);

    // Decompress
    const auto decoded = decompress(codec, encoded);
    std::visit(
        [&expected](const auto& bytes)
        {
            CHECK_EQ(bytes.size(), expected.size());
            const std::vector<uint8_t> actual(bytes.begin(), bytes.end());
            CHECK_EQ(actual, expected);
        },
        decoded
    );

    // Check that the compressed data is just the original data with a -1 header
    const std::int64_t header = *reinterpret_cast<const std::int64_t*>(encoded.data());
    CHECK_EQ(header, -1);
    const std::vector<uint8_t> stored_body(encoded.begin() + sizeof(header), encoded.end());
    CHECK_EQ(stored_body, expected);
}
6490
}
6591
}

0 commit comments

Comments
 (0)