Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 21 additions & 33 deletions src/compression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,18 +158,19 @@ namespace sparrow_ipc
// Callable that takes a read-only byte span and returns the produced bytes in a new vector.
using compress_func = std::function<std::vector<uint8_t>(std::span<const uint8_t>)>;
// Callable that takes a read-only byte span plus an int64_t (the decompressed size, going by the
// parameter name used by lz4_decompress / zstd_decompress) and returns the produced bytes.
using decompress_func = std::function<std::vector<uint8_t>(std::span<const uint8_t>, int64_t)>;

std::vector<std::uint8_t> lz4_compress(std::span<const std::uint8_t> data)
std::vector<std::uint8_t> lz4_compress_with_header(std::span<const std::uint8_t> data)
{
const std::int64_t uncompressed_size = data.size();
const size_t max_compressed_size = LZ4F_compressFrameBound(uncompressed_size, nullptr);
std::vector<std::uint8_t> compressed_data(max_compressed_size);
const size_t compressed_size = LZ4F_compressFrame(compressed_data.data(), max_compressed_size, data.data(), uncompressed_size, nullptr);
std::vector<std::uint8_t> result(details::CompressionHeaderSize + max_compressed_size);
const size_t compressed_size = LZ4F_compressFrame(result.data() + details::CompressionHeaderSize, max_compressed_size, data.data(), uncompressed_size, nullptr);
if (LZ4F_isError(compressed_size))
{
throw std::runtime_error("Failed to compress data with LZ4 frame format");
}
compressed_data.resize(compressed_size);
return compressed_data;
memcpy(result.data(), &uncompressed_size, sizeof(uncompressed_size));
result.resize(details::CompressionHeaderSize + compressed_size);
return result;
}

std::vector<std::uint8_t> lz4_decompress(std::span<const std::uint8_t> data, const std::int64_t decompressed_size)
Expand All @@ -188,18 +189,19 @@ namespace sparrow_ipc
return decompressed_data;
}

std::vector<std::uint8_t> zstd_compress(std::span<const std::uint8_t> data)
std::vector<std::uint8_t> zstd_compress_with_header(std::span<const std::uint8_t> data)
{
const std::int64_t uncompressed_size = data.size();
const size_t max_compressed_size = ZSTD_compressBound(uncompressed_size);
std::vector<std::uint8_t> compressed_data(max_compressed_size);
const size_t compressed_size = ZSTD_compress(compressed_data.data(), max_compressed_size, data.data(), uncompressed_size, 1);
std::vector<std::uint8_t> result(details::CompressionHeaderSize + max_compressed_size);
const size_t compressed_size = ZSTD_compress(result.data() + details::CompressionHeaderSize, max_compressed_size, data.data(), uncompressed_size, 1);
if (ZSTD_isError(compressed_size))
{
throw std::runtime_error("Failed to compress data with ZSTD");
}
compressed_data.resize(compressed_size);
return compressed_data;
memcpy(result.data(), &uncompressed_size, sizeof(uncompressed_size));
result.resize(details::CompressionHeaderSize + compressed_size);
return result;
}

std::vector<std::uint8_t> zstd_decompress(std::span<const std::uint8_t> data, const std::int64_t decompressed_size)
Expand All @@ -213,14 +215,6 @@ namespace sparrow_ipc
return decompressed_data;
}

/**
 * @brief Appends a size header followed by the compressed bytes to @p result.
 *
 * The header is the raw object representation of @p original_size
 * (sizeof(std::int64_t) bytes, host byte order), followed by every byte of
 * @p compressed_body.
 *
 * @param result          Destination buffer; existing contents are preserved.
 * @param original_size   Uncompressed size to record in the header.
 * @param compressed_body Compressed payload; its bytes are copied, the vector itself is left intact.
 */
void insert_compressed_data(std::vector<uint8_t>& result, std::int64_t original_size, std::vector<uint8_t>&& compressed_body)
{
    // reserve() takes a total capacity, not an increment, so account for what is already stored.
    result.reserve(result.size() + details::CompressionHeaderSize + compressed_body.size());

    const auto* header_bytes = reinterpret_cast<const uint8_t*>(&original_size);
    result.insert(result.end(), header_bytes, header_bytes + sizeof(original_size));

    // uint8_t is trivially copyable: move iterators would not save anything over a plain range insert.
    // TODO: avoid this copy altogether, e.g. by keeping header and body as separate buffers
    // and serializing them individually at the top level.
    result.insert(result.end(), compressed_body.begin(), compressed_body.end());
}

void insert_uncompressed_data(std::vector<uint8_t>& result, const std::span<const uint8_t>& data)
{
const std::int64_t header = -1;
Expand All @@ -244,24 +238,18 @@ namespace sparrow_ipc
}

// Not in cache, compress and store
const std::int64_t original_size = data.size();

std::vector<std::uint8_t> compressed_body;
if (comp_func)
{
compressed_body = comp_func(data);
auto compressed_with_header = comp_func(data);
// Compression is effective
if (compressed_with_header.size() - details::CompressionHeaderSize < data.size())
{
return cache.store(buffer_ptr, buffer_size, std::move(compressed_with_header));
}
}

std::vector<uint8_t> result_vec;
if (comp_func && compressed_body.size() < static_cast<size_t>(original_size))
{
insert_compressed_data(result_vec, original_size, std::move(compressed_body));
}
else
{
insert_uncompressed_data(result_vec, data);
}

insert_uncompressed_data(result_vec, data);
return cache.store(buffer_ptr, buffer_size, std::move(result_vec));
}

Expand Down Expand Up @@ -301,11 +289,11 @@ namespace sparrow_ipc
{
case CompressionType::LZ4_FRAME:
{
return compress_with_header(data, lz4_compress, cache);
return compress_with_header(data, lz4_compress_with_header, cache);
}
case CompressionType::ZSTD:
{
return compress_with_header(data, zstd_compress, cache);
return compress_with_header(data, zstd_compress_with_header, cache);
}
}
assert(false && "Unhandled compression type");
Expand Down
Loading