diff --git a/src/compression.cpp b/src/compression.cpp index 0d0da2e..7579f6e 100644 --- a/src/compression.cpp +++ b/src/compression.cpp @@ -158,18 +158,19 @@ namespace sparrow_ipc using compress_func = std::function(std::span)>; using decompress_func = std::function(std::span, int64_t)>; - std::vector lz4_compress(std::span data) + std::vector lz4_compress_with_header(std::span data) { const std::int64_t uncompressed_size = data.size(); const size_t max_compressed_size = LZ4F_compressFrameBound(uncompressed_size, nullptr); - std::vector compressed_data(max_compressed_size); - const size_t compressed_size = LZ4F_compressFrame(compressed_data.data(), max_compressed_size, data.data(), uncompressed_size, nullptr); + std::vector result(details::CompressionHeaderSize + max_compressed_size); + const size_t compressed_size = LZ4F_compressFrame(result.data() + details::CompressionHeaderSize, max_compressed_size, data.data(), uncompressed_size, nullptr); if (LZ4F_isError(compressed_size)) { throw std::runtime_error("Failed to compress data with LZ4 frame format"); } - compressed_data.resize(compressed_size); - return compressed_data; + memcpy(result.data(), &uncompressed_size, sizeof(uncompressed_size)); + result.resize(details::CompressionHeaderSize + compressed_size); + return result; } std::vector lz4_decompress(std::span data, const std::int64_t decompressed_size) @@ -188,18 +189,19 @@ namespace sparrow_ipc return decompressed_data; } - std::vector zstd_compress(std::span data) + std::vector zstd_compress_with_header(std::span data) { const std::int64_t uncompressed_size = data.size(); const size_t max_compressed_size = ZSTD_compressBound(uncompressed_size); - std::vector compressed_data(max_compressed_size); - const size_t compressed_size = ZSTD_compress(compressed_data.data(), max_compressed_size, data.data(), uncompressed_size, 1); + std::vector result(details::CompressionHeaderSize + max_compressed_size); + const size_t compressed_size = ZSTD_compress(result.data() + details::CompressionHeaderSize, max_compressed_size, data.data(), uncompressed_size, 1); if (ZSTD_isError(compressed_size)) { throw std::runtime_error("Failed to compress data with ZSTD"); } - compressed_data.resize(compressed_size); - return compressed_data; + memcpy(result.data(), &uncompressed_size, sizeof(uncompressed_size)); + result.resize(details::CompressionHeaderSize + compressed_size); + return result; } std::vector zstd_decompress(std::span data, const std::int64_t decompressed_size) @@ -213,14 +215,6 @@ namespace sparrow_ipc return decompressed_data; } - void insert_compressed_data(std::vector& result, std::int64_t original_size, std::vector&& compressed_body) - { - result.reserve(details::CompressionHeaderSize + compressed_body.size()); - result.insert(result.end(), reinterpret_cast(&original_size), reinterpret_cast(&original_size) + sizeof(original_size)); - // TODO Think about avoid copying here (on every uint8_t), maybe use a list of vectors (header + body) and serialize separately on top level code instead of including header at this point? - result.insert(result.end(), std::make_move_iterator(compressed_body.begin()), std::make_move_iterator(compressed_body.end())); - } - void insert_uncompressed_data(std::vector& result, const std::span& data) { const std::int64_t header = -1; @@ -244,24 +238,18 @@ namespace sparrow_ipc } // Not in cache, compress and store - const std::int64_t original_size = data.size(); - - std::vector compressed_body; if (comp_func) { - compressed_body = comp_func(data); + auto compressed_with_header = comp_func(data); + // Compression is effective + if (compressed_with_header.size() - details::CompressionHeaderSize < data.size()) + { + return cache.store(buffer_ptr, buffer_size, std::move(compressed_with_header)); + } } std::vector result_vec; - if (comp_func && compressed_body.size() < static_cast(original_size)) - { - insert_compressed_data(result_vec, original_size, std::move(compressed_body)); - } - else - { - insert_uncompressed_data(result_vec, data); - } - + insert_uncompressed_data(result_vec, data); return cache.store(buffer_ptr, buffer_size, std::move(result_vec)); } @@ -301,11 +289,11 @@ namespace sparrow_ipc { case CompressionType::LZ4_FRAME: { - return compress_with_header(data, lz4_compress, cache); + return compress_with_header(data, lz4_compress_with_header, cache); } case CompressionType::ZSTD: { - return compress_with_header(data, zstd_compress, cache); + return compress_with_header(data, zstd_compress_with_header, cache); } } assert(false && "Unhandled compression type");