Skip to content

Commit 413dffe

Browse files
committed
Add lz4 compression in deserialization
1 parent ed87642 commit 413dffe

File tree

5 files changed

+209
-40
lines changed

5 files changed

+209
-40
lines changed

include/sparrow_ipc/compression.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@ namespace sparrow_ipc
1919
// CompressionType to_compression_type(org::apache::arrow::flatbuf::CompressionType compression_type);
2020

2121
std::vector<std::uint8_t> compress(org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const std::uint8_t> data);
22+
std::vector<std::uint8_t> decompress(org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const std::uint8_t> data);
2223
}

include/sparrow_ipc/deserialize_primitive_array.hpp

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,30 @@
99
#include "Message_generated.h"
1010
#include "sparrow_ipc/arrow_interface/arrow_array.hpp"
1111
#include "sparrow_ipc/arrow_interface/arrow_schema.hpp"
12+
#include "sparrow_ipc/compression.hpp"
1213
#include "sparrow_ipc/deserialize_utils.hpp"
1314

1415
namespace sparrow_ipc
1516
{
17+
namespace
18+
{
19+
struct DecompressedBuffers
20+
{
21+
std::vector<uint8_t> validity_buffer;
22+
std::vector<uint8_t> data_buffer;
23+
};
24+
25+
void release_decompressed_buffers(ArrowArray* array)
26+
{
27+
if (array->private_data)
28+
{
29+
delete static_cast<DecompressedBuffers*>(array->private_data);
30+
array->private_data = nullptr;
31+
}
32+
array->release = nullptr;
33+
}
34+
}
35+
1636
template <typename T>
1737
[[nodiscard]] sparrow::primitive_array<T> deserialize_non_owning_primitive_array(
1838
const org::apache::arrow::flatbuf::RecordBatch& record_batch,
@@ -22,6 +42,46 @@ namespace sparrow_ipc
2242
size_t& buffer_index
2343
)
2444
{
45+
const auto compression = record_batch.compression();
46+
DecompressedBuffers* decompressed_buffers_owner = nullptr;
47+
48+
// Validity buffer
49+
const auto validity_buffer_metadata = record_batch.buffers()->Get(buffer_index++);
50+
auto validity_buffer_span = body.subspan(validity_buffer_metadata->offset(), validity_buffer_metadata->length());
51+
if (compression)
52+
{
53+
if (!decompressed_buffers_owner)
54+
{
55+
decompressed_buffers_owner = new DecompressedBuffers();
56+
}
57+
decompressed_buffers_owner->validity_buffer = decompress(compression->codec(), validity_buffer_span);
58+
validity_buffer_span = decompressed_buffers_owner->validity_buffer;
59+
}
60+
auto bitmap_ptr = const_cast<uint8_t*>(validity_buffer_span.data());
61+
const sparrow::dynamic_bitset_view<const std::uint8_t> bitmap_view{
62+
bitmap_ptr,
63+
static_cast<size_t>(record_batch.length())};
64+
auto null_count = bitmap_view.null_count();
65+
if (validity_buffer_metadata->length() == 0)
66+
{
67+
bitmap_ptr = nullptr;
68+
null_count = 0;
69+
}
70+
71+
// Data buffer
72+
const auto primitive_buffer_metadata = record_batch.buffers()->Get(buffer_index++);
73+
auto data_buffer_span = body.subspan(primitive_buffer_metadata->offset(), primitive_buffer_metadata->length());
74+
if (compression)
75+
{
76+
if (!decompressed_buffers_owner)
77+
{
78+
decompressed_buffers_owner = new DecompressedBuffers();
79+
}
80+
decompressed_buffers_owner->data_buffer = decompress(compression->codec(), data_buffer_span);
81+
data_buffer_span = decompressed_buffers_owner->data_buffer;
82+
}
83+
auto primitives_ptr = const_cast<uint8_t*>(data_buffer_span.data());
84+
2585
const std::string_view format = data_type_to_format(
2686
sparrow::detail::get_data_type_from_array<sparrow::primitive_array<T>>::get()
2787
);
@@ -34,17 +94,7 @@ namespace sparrow_ipc
3494
nullptr,
3595
nullptr
3696
);
37-
const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count(
38-
record_batch,
39-
body,
40-
buffer_index++
41-
);
42-
const auto primitive_buffer_metadata = record_batch.buffers()->Get(buffer_index++);
43-
if (body.size() < (primitive_buffer_metadata->offset() + primitive_buffer_metadata->length()))
44-
{
45-
throw std::runtime_error("Primitive buffer exceeds body size");
46-
}
47-
auto primitives_ptr = const_cast<uint8_t*>(body.data() + primitive_buffer_metadata->offset());
97+
4898
std::vector<std::uint8_t*> buffers = {bitmap_ptr, primitives_ptr};
4999
ArrowArray array = make_non_owning_arrow_array(
50100
record_batch.length(),
@@ -55,7 +105,14 @@ namespace sparrow_ipc
55105
nullptr,
56106
nullptr
57107
);
108+
109+
if (decompressed_buffers_owner)
110+
{
111+
array.private_data = decompressed_buffers_owner;
112+
array.release = release_decompressed_buffers;
113+
}
114+
58115
sparrow::arrow_proxy ap{std::move(array), std::move(schema)};
59116
return sparrow::primitive_array<T>{std::move(ap)};
60117
}
61-
}
118+
}

include/sparrow_ipc/deserialize_variable_size_binary_array.hpp

Lines changed: 85 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,32 @@
88
#include "Message_generated.h"
99
#include "sparrow_ipc/arrow_interface/arrow_array.hpp"
1010
#include "sparrow_ipc/arrow_interface/arrow_schema.hpp"
11+
#include "sparrow_ipc/compression.hpp"
1112
#include "sparrow_ipc/deserialize_utils.hpp"
1213

1314
namespace sparrow_ipc
1415
{
16+
// TODO after handling deserialize_primitive_array, do the same here and then in other data types
17+
namespace
18+
{
19+
struct DecompressedBinaryBuffers
20+
{
21+
std::vector<uint8_t> validity_buffer;
22+
std::vector<uint8_t> offset_buffer;
23+
std::vector<uint8_t> data_buffer;
24+
};
25+
26+
void release_decompressed_binary_buffers(ArrowArray* array)
27+
{
28+
if (array->private_data)
29+
{
30+
delete static_cast<DecompressedBinaryBuffers*>(array->private_data);
31+
array->private_data = nullptr;
32+
}
33+
array->release = nullptr;
34+
}
35+
}
36+
1537
template <typename T>
1638
[[nodiscard]] T deserialize_non_owning_variable_size_binary(
1739
const org::apache::arrow::flatbuf::RecordBatch& record_batch,
@@ -21,6 +43,61 @@ namespace sparrow_ipc
2143
size_t& buffer_index
2244
)
2345
{
46+
const auto compression = record_batch.compression();
47+
DecompressedBinaryBuffers* decompressed_buffers_owner = nullptr;
48+
49+
// Validity buffer
50+
const auto validity_buffer_metadata = record_batch.buffers()->Get(buffer_index++);
51+
auto validity_buffer_span = body.subspan(validity_buffer_metadata->offset(), validity_buffer_metadata->length());
52+
if (compression && validity_buffer_metadata->length() > 0)
53+
{
54+
if (!decompressed_buffers_owner)
55+
{
56+
decompressed_buffers_owner = new DecompressedBinaryBuffers();
57+
}
58+
decompressed_buffers_owner->validity_buffer = decompress(compression->codec(), validity_buffer_span);
59+
validity_buffer_span = decompressed_buffers_owner->validity_buffer;
60+
}
61+
62+
uint8_t* bitmap_ptr = nullptr;
63+
int64_t null_count = 0;
64+
if (validity_buffer_metadata->length() > 0)
65+
{
66+
bitmap_ptr = const_cast<uint8_t*>(validity_buffer_span.data());
67+
const sparrow::dynamic_bitset_view<const std::uint8_t> bitmap_view{
68+
bitmap_ptr,
69+
static_cast<size_t>(record_batch.length())};
70+
null_count = bitmap_view.null_count();
71+
}
72+
73+
// Offset buffer
74+
const auto offset_metadata = record_batch.buffers()->Get(buffer_index++);
75+
auto offset_buffer_span = body.subspan(offset_metadata->offset(), offset_metadata->length());
76+
if (compression)
77+
{
78+
if (!decompressed_buffers_owner)
79+
{
80+
decompressed_buffers_owner = new DecompressedBinaryBuffers();
81+
}
82+
decompressed_buffers_owner->offset_buffer = decompress(compression->codec(), offset_buffer_span);
83+
offset_buffer_span = decompressed_buffers_owner->offset_buffer;
84+
}
85+
auto offset_ptr = const_cast<uint8_t*>(offset_buffer_span.data());
86+
87+
// Data buffer
88+
const auto buffer_metadata = record_batch.buffers()->Get(buffer_index++);
89+
auto data_buffer_span = body.subspan(buffer_metadata->offset(), buffer_metadata->length());
90+
if (compression)
91+
{
92+
if (!decompressed_buffers_owner)
93+
{
94+
decompressed_buffers_owner = new DecompressedBinaryBuffers();
95+
}
96+
decompressed_buffers_owner->data_buffer = decompress(compression->codec(), data_buffer_span);
97+
data_buffer_span = decompressed_buffers_owner->data_buffer;
98+
}
99+
auto buffer_ptr = const_cast<uint8_t*>(data_buffer_span.data());
100+
24101
const std::string_view format = data_type_to_format(sparrow::detail::get_data_type_from_array<T>::get());
25102
ArrowSchema schema = make_non_owning_arrow_schema(
26103
format,
@@ -31,24 +108,7 @@ namespace sparrow_ipc
31108
nullptr,
32109
nullptr
33110
);
34-
const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count(
35-
record_batch,
36-
body,
37-
buffer_index++
38-
);
39111

40-
const auto offset_metadata = record_batch.buffers()->Get(buffer_index++);
41-
if ((offset_metadata->offset() + offset_metadata->length()) > body.size())
42-
{
43-
throw std::runtime_error("Offset buffer exceeds body size");
44-
}
45-
auto offset_ptr = const_cast<uint8_t*>(body.data() + offset_metadata->offset());
46-
const auto buffer_metadata = record_batch.buffers()->Get(buffer_index++);
47-
if ((buffer_metadata->offset() + buffer_metadata->length()) > body.size())
48-
{
49-
throw std::runtime_error("Data buffer exceeds body size");
50-
}
51-
auto buffer_ptr = const_cast<uint8_t*>(body.data() + buffer_metadata->offset());
52112
std::vector<std::uint8_t*> buffers = {bitmap_ptr, offset_ptr, buffer_ptr};
53113
ArrowArray array = make_non_owning_arrow_array(
54114
record_batch.length(),
@@ -59,7 +119,14 @@ namespace sparrow_ipc
59119
nullptr,
60120
nullptr
61121
);
122+
123+
if (decompressed_buffers_owner)
124+
{
125+
array.private_data = decompressed_buffers_owner;
126+
array.release = release_decompressed_binary_buffers;
127+
}
128+
62129
sparrow::arrow_proxy ap{std::move(array), std::move(schema)};
63130
return T{std::move(ap)};
64131
}
65-
}
132+
}

src/compression.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,44 @@ namespace sparrow_ipc
4545
return {data.begin(), data.end()};
4646
}
4747
}
48+
49+
std::vector<std::uint8_t> decompress(org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const std::uint8_t> data)
50+
{
51+
if (data.empty())
52+
{
53+
return {};
54+
}
55+
switch (compression_type)
56+
{
57+
case org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME:
58+
{
59+
if (data.size() < 8)
60+
{
61+
throw std::runtime_error("Invalid compressed data: missing decompressed size");
62+
}
63+
const std::int64_t decompressed_size = *reinterpret_cast<const std::int64_t*>(data.data());
64+
const auto compressed_data = data.subspan(8);
65+
66+
if (decompressed_size == -1)
67+
{
68+
return {compressed_data.begin(), compressed_data.end()};
69+
}
70+
71+
std::vector<std::uint8_t> decompressed_data(decompressed_size);
72+
LZ4F_dctx* dctx = nullptr;
73+
LZ4F_createDecompressionContext(&dctx, LZ4F_VERSION);
74+
size_t compressed_size_in_out = compressed_data.size();
75+
size_t decompressed_size_in_out = decompressed_size;
76+
size_t result = LZ4F_decompress(dctx, decompressed_data.data(), &decompressed_size_in_out, compressed_data.data(), &compressed_size_in_out, nullptr);
77+
if (LZ4F_isError(result))
78+
{
79+
throw std::runtime_error("Failed to decompress data with LZ4 frame format");
80+
}
81+
LZ4F_freeDecompressionContext(dctx);
82+
return decompressed_data;
83+
}
84+
default:
85+
return {data.begin(), data.end()};
86+
}
87+
}
4888
}

tests/test_de_serialization_with_files.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ const std::vector<std::filesystem::path> files_paths_to_test = {
3232
};
3333

3434
const std::vector<std::filesystem::path> files_paths_to_test_with_compression = {
35-
tests_resources_files_path_with_compression / "generated_lz4"
36-
// tests_resources_files_path_with_compression/ "generated_uncompressible_lz4"
35+
tests_resources_files_path_with_compression / "generated_lz4",
36+
tests_resources_files_path_with_compression/ "generated_uncompressible_lz4"
3737
// tests_resources_files_path_with_compression / "generated_zstd"
3838
// tests_resources_files_path_with_compression/ "generated_uncompressible_zstd"
3939
};
@@ -66,21 +66,27 @@ void compare_record_batches(
6666
)
6767
{
6868
REQUIRE_EQ(record_batches_1.size(), record_batches_2.size());
69+
// std::cout << "record_batches1 size: " << record_batches_1.size() << " record_batches2 size: " << record_batches_2.size() << std::endl;
6970
for (size_t i = 0; i < record_batches_1.size(); ++i)
7071
{
7172
for (size_t y = 0; y < record_batches_1[i].nb_columns(); y++)
7273
{
74+
// std::cout << "record_batches1 nb cols: " << record_batches_1[i].nb_columns() << " record_batches2 nb cols: " << record_batches_2[i].nb_columns() << std::endl;
7375
const auto& column_1 = record_batches_1[i].get_column(y);
7476
const auto& column_2 = record_batches_2[i].get_column(y);
7577
REQUIRE_EQ(column_1.size(), column_2.size());
78+
// std::cout << "column_1.size(): " << column_1.size() << " column_2.size(): " << column_2.size() << std::endl;
7679
for (size_t z = 0; z < column_1.size(); z++)
7780
{
7881
const auto col_name = column_1.name().value_or("NA");
7982
INFO("Comparing batch " << i << ", column " << y << " named :" << col_name << " , row " << z);
83+
// std::cout << "Comparing batch " << i << ", column " << y << " named :" << col_name << " , row " << z << std::endl;
8084
REQUIRE_EQ(column_1.data_type(), column_2.data_type());
8185
CHECK_EQ(column_1.name(), column_2.name());
86+
// std::cout << "column_1.name() :" << column_1.name() << " and " << column_2.name() << std::endl;
8287
const auto& column_1_value = column_1[z];
8388
const auto& column_2_value = column_2[z];
89+
// std::cout << "column_1_value :" << column_1_value << " and " << column_2_value << std::endl;
8490
CHECK_EQ(column_1_value, column_2_value);
8591
}
8692
}
@@ -182,7 +188,7 @@ TEST_SUITE("Integration tests")
182188
}
183189
}
184190

185-
TEST_CASE("Serialization with LZ4 compression")
191+
TEST_CASE("Compare record_batch serialization with stream file using LZ4 compression")
186192
{
187193
for (const auto& file_path : files_paths_to_test_with_compression)
188194
{
@@ -220,14 +226,12 @@ TEST_SUITE("Integration tests")
220226
const auto record_batches_from_stream = sparrow_ipc::deserialize_stream(
221227
std::span<const uint8_t>(stream_data)
222228
);
223-
224-
const auto serialized_data = sparrow_ipc::serialize(record_batches_from_json, std::nullopt);
225-
// const auto deserialized_serialized_data = sparrow_ipc::deserialize_stream(
226-
// std::span<const uint8_t>(serialized_data)
227-
// );
228-
// compare_record_batches(record_batches_from_stream, deserialized_serialized_data);
229+
const auto serialized_data = sparrow_ipc::serialize(record_batches_from_json, org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME);
230+
const auto deserialized_serialized_data = sparrow_ipc::deserialize_stream(
231+
std::span<const uint8_t>(serialized_data)
232+
);
233+
compare_record_batches(record_batches_from_stream, deserialized_serialized_data);
229234
}
230-
231235
}
232236
}
233237
}

0 commit comments

Comments
 (0)