Skip to content

Commit 5ad8ebb

Browse files
authored
Handle compression in deserialize_fixedsizebinary_array (#59)
Add more tests
1 parent b7a3605 commit 5ad8ebb

File tree

5 files changed

+95
-14
lines changed

5 files changed

+95
-14
lines changed

include/sparrow_ipc/serializer.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ namespace sparrow_ipc
8282
requires std::same_as<std::ranges::range_value_t<R>, sparrow::record_batch>
8383
void write(const R& record_batches)
8484
{
85+
if (std::ranges::empty(record_batches))
86+
{
87+
return;
88+
}
89+
8590
if (m_ended)
8691
{
8792
throw std::runtime_error("Cannot append to a serializer that has been ended");

src/deserialize.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ namespace sparrow_ipc
6161
const std::optional<std::vector<sparrow::metadata_pair>>& metadata = field_metadata[field_idx++];
6262
const std::string name = field->name() == nullptr ? "" : field->name()->str();
6363
const auto field_type = field->type_type();
64+
// TODO: rename all the deserialize_non_owning_* functions, since the "non-owning" part of the name is no longer accurate
6465
const auto deserialize_non_owning_primitive_array_lambda = [&]<typename T>()
6566
{
6667
return deserialize_non_owning_primitive_array<T>(
@@ -207,8 +208,20 @@ namespace sparrow_ipc
207208
std::vector<std::optional<std::vector<sparrow::metadata_pair>>> fields_metadata;
208209
do
209210
{
211+
// Check for the end-of-stream marker before extracting a message: the remaining data may consist of only that marker (e.g. when no record batches were written)
212+
if (data.size() >= 8 && is_end_of_stream(data.subspan(0, 8)))
213+
{
214+
break;
215+
}
216+
210217
const auto [encapsulated_message, rest] = extract_encapsulated_message(data);
211218
const org::apache::arrow::flatbuf::Message* message = encapsulated_message.flat_buffer_message();
219+
220+
if (message == nullptr)
221+
{
222+
throw std::invalid_argument("Extracted flatbuffers message is null.");
223+
}
224+
212225
switch (message->header_type())
213226
{
214227
case org::apache::arrow::flatbuf::MessageHeader::Schema:
@@ -269,10 +282,6 @@ namespace sparrow_ipc
269282
throw std::runtime_error("Unknown message header type.");
270283
}
271284
data = rest;
272-
if (is_end_of_stream(data.subspan(0, 8)))
273-
{
274-
break;
275-
}
276285
} while (!data.empty());
277286
return record_batches;
278287
}

src/deserialize_fixedsizebinary_array.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
namespace sparrow_ipc
44
{
5-
// TODO add compression here and tests (not available for this type in apache arrow integration tests files)
65
sparrow::fixed_width_binary_array deserialize_non_owning_fixedwidthbinary(
76
const org::apache::arrow::flatbuf::RecordBatch& record_batch,
87
std::span<const uint8_t> body,
@@ -23,10 +22,22 @@ namespace sparrow_ipc
2322
nullptr
2423
);
2524

25+
const auto compression = record_batch.compression();
2626
std::vector<arrow_array_private_data::optionally_owned_buffer> buffers;
27+
2728
auto validity_buffer_span = utils::get_buffer(record_batch, body, buffer_index);
28-
buffers.push_back(validity_buffer_span);
29-
buffers.push_back(utils::get_buffer(record_batch, body, buffer_index));
29+
auto data_buffer_span = utils::get_buffer(record_batch, body, buffer_index);
30+
31+
if (compression)
32+
{
33+
buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression));
34+
buffers.push_back(utils::get_decompressed_buffer(data_buffer_span, compression));
35+
}
36+
else
37+
{
38+
buffers.push_back(validity_buffer_span);
39+
buffers.push_back(data_buffer_span);
40+
}
3041

3142
// TODO: bitmap_ptr is no longer used. Leave it for now and remove it later once it is confirmed to be unnecessary
3243
const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count(validity_buffer_span, record_batch.length());

src/encapsulated_message.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ namespace sparrow_ipc
106106
const std::span<const uint8_t> continuation_span = data.subspan(0, 4);
107107
if (!is_continuation(continuation_span))
108108
{
109-
throw std::runtime_error("Buffer starts with continuation bytes, expected a valid message.");
109+
throw std::runtime_error("Buffer should start with continuation bytes, expected a valid message.");
110110
}
111111
encapsulated_message message(data);
112112
std::span<const uint8_t> rest = data.subspan(message.total_length());
113113
return {std::move(message), std::move(rest)};
114114
}
115-
}
115+
}

tests/test_de_serialization_with_files.cpp

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,19 @@ const std::filesystem::path tests_resources_files_path_with_compression = arrow_
2727

2828
const std::vector<std::filesystem::path> files_paths_to_test = {
2929
tests_resources_files_path / "generated_primitive",
30-
// tests_resources_files_path / "generated_primitive_large_offsets",
3130
tests_resources_files_path / "generated_primitive_zerolength",
32-
// tests_resources_files_path / "generated_primitive_no_batches"
31+
tests_resources_files_path / "generated_primitive_no_batches",
32+
tests_resources_files_path / "generated_binary",
33+
tests_resources_files_path / "generated_large_binary",
34+
tests_resources_files_path / "generated_binary_zerolength",
35+
tests_resources_files_path / "generated_binary_no_batches",
3336
};
3437

3538
const std::vector<std::filesystem::path> files_paths_to_test_with_compression = {
3639
tests_resources_files_path_with_compression / "generated_lz4",
37-
tests_resources_files_path_with_compression/ "generated_uncompressible_lz4"
38-
// tests_resources_files_path_with_compression / "generated_zstd"
39-
// tests_resources_files_path_with_compression/ "generated_uncompressible_zstd"
40+
tests_resources_files_path_with_compression/ "generated_uncompressible_lz4",
41+
// tests_resources_files_path_with_compression / "generated_zstd",
42+
// tests_resources_files_path_with_compression/ "generated_uncompressible_zstd",
4043
};
4144

4245

@@ -236,4 +239,57 @@ TEST_SUITE("Integration tests")
236239
}
237240
}
238241
}
242+
243+
TEST_CASE("Round trip of classic test files serialization/deserialization using LZ4 compression")
244+
{
245+
for (const auto& file_path : files_paths_to_test)
246+
{
247+
std::filesystem::path json_path = file_path;
248+
json_path.replace_extension(".json");
249+
250+
// Load the JSON file
251+
auto json_data = load_json_file(json_path);
252+
CHECK(json_data != nullptr);
253+
254+
const size_t num_batches = get_number_of_batches(json_path);
255+
std::vector<sparrow::record_batch> record_batches_from_json;
256+
for (size_t batch_idx = 0; batch_idx < num_batches; ++batch_idx)
257+
{
258+
INFO("Processing batch " << batch_idx << " of " << num_batches);
259+
record_batches_from_json.emplace_back(
260+
sparrow::json_reader::build_record_batch_from_json(json_data, batch_idx)
261+
);
262+
}
263+
264+
// Load stream file
265+
std::filesystem::path stream_file_path = file_path;
266+
stream_file_path.replace_extension(".stream");
267+
std::ifstream stream_file(stream_file_path, std::ios::in | std::ios::binary);
268+
REQUIRE(stream_file.is_open());
269+
const std::vector<uint8_t> stream_data(
270+
(std::istreambuf_iterator<char>(stream_file)),
271+
(std::istreambuf_iterator<char>())
272+
);
273+
stream_file.close();
274+
275+
// Process the stream file
276+
const auto record_batches_from_stream = sparrow_ipc::deserialize_stream(
277+
std::span<const uint8_t>(stream_data)
278+
);
279+
280+
// Serialize from json with LZ4 compression
281+
std::vector<uint8_t> serialized_data;
282+
sparrow_ipc::memory_output_stream stream(serialized_data);
283+
sparrow_ipc::serializer serializer(stream, sparrow_ipc::CompressionType::LZ4_FRAME);
284+
serializer << record_batches_from_json << sparrow_ipc::end_stream;
285+
286+
// Deserialize
287+
const auto deserialized_serialized_data = sparrow_ipc::deserialize_stream(
288+
std::span<const uint8_t>(serialized_data)
289+
);
290+
291+
// Compare
292+
compare_record_batches(record_batches_from_stream, deserialized_serialized_data);
293+
}
294+
}
239295
}

0 commit comments

Comments
 (0)