Skip to content

Commit 0d949aa

Browse files
committed
Minor changes
1 parent 66b7660 commit 0d949aa

File tree

4 files changed

+106
-68
lines changed

4 files changed

+106
-68
lines changed

environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ dependencies:
88
- cxx-compiler
99
# Libraries dependencies
1010
- flatbuffers
11-
- lz4
11+
- lz4-c
1212
- nlohmann_json
1313
- sparrow-devel >=1.1.2
1414
# Testing dependencies

include/sparrow_ipc/chunk_memory_serializer.hpp

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ namespace sparrow_ipc
4444
* @param stream Reference to a chunked memory output stream that will receive the serialized chunks
4545
* @param compression Optional: The compression type to use for record batch bodies.
4646
*/
47+
// TODO Use enums and such to avoid including flatbuffers headers
4748
chunk_serializer(chunked_memory_output_stream<std::vector<std::vector<uint8_t>>>& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression = std::nullopt);
4849

4950
/**
@@ -144,21 +145,7 @@ namespace sparrow_ipc
144145
throw std::runtime_error("Cannot append record batches to a serializer that has been ended");
145146
}
146147

147-
const auto reserve_function = [&record_batches, this]()
148-
{
149-
return std::accumulate(
150-
record_batches.begin(),
151-
record_batches.end(),
152-
m_pstream->size(),
153-
[this](size_t acc, const sparrow::record_batch& rb)
154-
{
155-
return acc + calculate_record_batch_message_size(rb, m_compression);
156-
}
157-
)
158-
+ (m_schema_received ? 0 : calculate_schema_message_size(*record_batches.begin()));
159-
};
160-
161-
m_pstream->reserve(reserve_function);
148+
m_pstream->reserve((m_schema_received ? 0 : 1) + m_pstream->size() + record_batches.size());
162149

163150
if (!m_schema_received)
164151
{

src/compression.cpp

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,49 @@ namespace sparrow_ipc
2020
// }
2121
// }
2222

23+
std::vector<std::uint8_t> lz4_compress(std::span<const std::uint8_t> data)
24+
{
25+
const std::int64_t uncompressed_size = data.size();
26+
const size_t max_compressed_size = LZ4F_compressFrameBound(uncompressed_size, nullptr);
27+
std::vector<std::uint8_t> compressed_data(max_compressed_size);
28+
const size_t compressed_size = LZ4F_compressFrame(compressed_data.data(), max_compressed_size, data.data(), uncompressed_size, nullptr);
29+
if (LZ4F_isError(compressed_size))
30+
{
31+
throw std::runtime_error("Failed to compress data with LZ4 frame format");
32+
}
33+
compressed_data.resize(compressed_size);
34+
return compressed_data;
35+
}
36+
37+
std::vector<std::uint8_t> lz4_decompress(std::span<const std::uint8_t> data)
38+
{
39+
if (data.size() < 8)
40+
{
41+
throw std::runtime_error("Invalid compressed data: missing decompressed size");
42+
}
43+
const std::int64_t decompressed_size = *reinterpret_cast<const std::int64_t*>(data.data());
44+
const auto compressed_data = data.subspan(8);
45+
46+
if (decompressed_size == -1)
47+
{
48+
// TODO think of avoiding copy here
49+
return {compressed_data.begin(), compressed_data.end()};
50+
}
51+
52+
std::vector<std::uint8_t> decompressed_data(decompressed_size);
53+
LZ4F_dctx* dctx = nullptr;
54+
LZ4F_createDecompressionContext(&dctx, LZ4F_VERSION);
55+
size_t compressed_size_in_out = compressed_data.size();
56+
size_t decompressed_size_in_out = decompressed_size;
57+
size_t result = LZ4F_decompress(dctx, decompressed_data.data(), &decompressed_size_in_out, compressed_data.data(), &compressed_size_in_out, nullptr);
58+
if (LZ4F_isError(result) || (decompressed_size_in_out != (size_t)decompressed_size))
59+
{
60+
throw std::runtime_error("Failed to decompress data with LZ4 frame format");
61+
}
62+
LZ4F_freeDecompressionContext(dctx);
63+
return decompressed_data;
64+
}
65+
2366
std::vector<std::uint8_t> compress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const std::uint8_t> data)
2467
{
2568
if (data.empty())
@@ -30,18 +73,14 @@ namespace sparrow_ipc
3073
{
3174
case org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME:
3275
{
33-
const std::int64_t uncompressed_size = data.size();
34-
const size_t max_compressed_size = LZ4F_compressFrameBound(uncompressed_size, nullptr);
35-
std::vector<std::uint8_t> compressed_data(max_compressed_size);
36-
const size_t compressed_size = LZ4F_compressFrame(compressed_data.data(), max_compressed_size, data.data(), uncompressed_size, nullptr);
37-
if (LZ4F_isError(compressed_size))
38-
{
39-
throw std::runtime_error("Failed to compress data with LZ4 frame format");
40-
}
41-
compressed_data.resize(compressed_size);
42-
return compressed_data;
76+
return lz4_compress(data);
77+
}
78+
case org::apache::arrow::flatbuf::CompressionType::ZSTD:
79+
{
80+
throw std::runtime_error("Compression using zstd is not supported yet.");
4381
}
4482
default:
83+
// TODO think of avoiding copy here
4584
return {data.begin(), data.end()};
4685
}
4786
}
@@ -56,32 +95,14 @@ namespace sparrow_ipc
5695
{
5796
case org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME:
5897
{
59-
if (data.size() < 8)
60-
{
61-
throw std::runtime_error("Invalid compressed data: missing decompressed size");
62-
}
63-
const std::int64_t decompressed_size = *reinterpret_cast<const std::int64_t*>(data.data());
64-
const auto compressed_data = data.subspan(8);
65-
66-
if (decompressed_size == -1)
67-
{
68-
return {compressed_data.begin(), compressed_data.end()};
69-
}
70-
71-
std::vector<std::uint8_t> decompressed_data(decompressed_size);
72-
LZ4F_dctx* dctx = nullptr;
73-
LZ4F_createDecompressionContext(&dctx, LZ4F_VERSION);
74-
size_t compressed_size_in_out = compressed_data.size();
75-
size_t decompressed_size_in_out = decompressed_size;
76-
size_t result = LZ4F_decompress(dctx, decompressed_data.data(), &decompressed_size_in_out, compressed_data.data(), &compressed_size_in_out, nullptr);
77-
if (LZ4F_isError(result))
78-
{
79-
throw std::runtime_error("Failed to decompress data with LZ4 frame format");
80-
}
81-
LZ4F_freeDecompressionContext(dctx);
82-
return decompressed_data;
98+
return lz4_decompress(data);
99+
}
100+
case org::apache::arrow::flatbuf::CompressionType::ZSTD:
101+
{
102+
throw std::runtime_error("Decompression using zstd is not supported yet.");
83103
}
84104
default:
105+
// TODO think of avoiding copy here
85106
return {data.begin(), data.end()};
86107
}
87108
}

tests/test_compression.cpp

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,59 @@
77

88
namespace sparrow_ipc
99
{
10-
TEST_CASE("Compression and Decompression Round-trip")
10+
TEST_SUITE("De/Compression")
1111
{
12-
std::string original_string = "hello world, this is a test of compression and decompression!!";
13-
std::vector<uint8_t> original_data(original_string.begin(), original_string.end());
14-
const int64_t original_size = original_data.size();
12+
TEST_CASE("Unsupported ZSTD de/compression")
13+
{
14+
std::string original_string = "some data to compress";
15+
std::vector<uint8_t> original_data(original_string.begin(), original_string.end());
16+
const auto compression_type = org::apache::arrow::flatbuf::CompressionType::ZSTD;
1517

16-
// Compress data
17-
auto compression_type = org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME;
18-
std::vector<uint8_t> compressed_frame = compress(compression_type, original_data);
18+
// Test compression with ZSTD
19+
CHECK_THROWS_WITH_AS(sparrow_ipc::compress(compression_type, original_data), "Compression using zstd is not supported yet.", std::runtime_error);
1920

20-
CHECK_GT(compressed_frame.size(), 0);
21-
CHECK_NE(compressed_frame, original_data);
21+
// Test decompression with ZSTD
22+
CHECK_THROWS_WITH_AS(sparrow_ipc::decompress(compression_type, original_data), "Decompression using zstd is not supported yet.", std::runtime_error);
23+
}
2224

23-
// Manually create the IPC-formatted compressed buffer by adding the 8-byte prefix
24-
std::vector<uint8_t> ipc_formatted_buffer;
25-
ipc_formatted_buffer.reserve(sizeof(int64_t) + compressed_frame.size());
26-
ipc_formatted_buffer.insert(ipc_formatted_buffer.end(), reinterpret_cast<const uint8_t*>(&original_size), reinterpret_cast<const uint8_t*>(&original_size) + sizeof(int64_t));
27-
ipc_formatted_buffer.insert(ipc_formatted_buffer.end(), compressed_frame.begin(), compressed_frame.end());
25+
TEST_CASE("Empty data")
26+
{
27+
const std::vector<uint8_t> empty_data;
28+
const auto compression_type = org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME;
2829

29-
// Decompress
30-
std::vector<uint8_t> decompressed_data = decompress(compression_type, ipc_formatted_buffer);
30+
// Test compression of empty data
31+
auto compressed = compress(compression_type, empty_data);
32+
CHECK(compressed.empty());
3133

32-
CHECK_EQ(decompressed_data.size(), original_data.size());
33-
CHECK_EQ(decompressed_data, original_data);
34+
// Test decompression of empty data
35+
auto decompressed = decompress(compression_type, empty_data);
36+
CHECK(decompressed.empty());
37+
}
38+
39+
TEST_CASE("Data compression and decompression round-trip")
40+
{
41+
std::string original_string = "hello world, this is a test of compression and decompression!!";
42+
std::vector<uint8_t> original_data(original_string.begin(), original_string.end());
43+
const int64_t original_size = original_data.size();
44+
45+
// Compress data
46+
auto compression_type = org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME;
47+
std::vector<uint8_t> compressed_frame = compress(compression_type, original_data);
48+
49+
CHECK_GT(compressed_frame.size(), 0);
50+
CHECK_NE(compressed_frame, original_data);
51+
52+
// Manually create the IPC-formatted compressed buffer by adding the 8-byte prefix
53+
std::vector<uint8_t> ipc_formatted_buffer;
54+
ipc_formatted_buffer.reserve(sizeof(int64_t) + compressed_frame.size());
55+
ipc_formatted_buffer.insert(ipc_formatted_buffer.end(), reinterpret_cast<const uint8_t*>(&original_size), reinterpret_cast<const uint8_t*>(&original_size) + sizeof(int64_t));
56+
ipc_formatted_buffer.insert(ipc_formatted_buffer.end(), compressed_frame.begin(), compressed_frame.end());
57+
58+
// Decompress
59+
std::vector<uint8_t> decompressed_data = decompress(compression_type, ipc_formatted_buffer);
60+
61+
CHECK_EQ(decompressed_data.size(), original_data.size());
62+
CHECK_EQ(decompressed_data, original_data);
63+
}
3464
}
3565
}

0 commit comments

Comments
 (0)