Skip to content

Commit e618b0b

Browse files
committed
Use std::span
1 parent 2db5795 commit e618b0b

File tree

7 files changed

+47
-37
lines changed

7 files changed

+47
-37
lines changed

include/sparrow_ipc/deserialize.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
namespace sparrow_ipc
1414
{
1515
SPARROW_IPC_API void deserialize_schema_message(
16-
const uint8_t* buf_ptr,
16+
std::span<const uint8_t> data,
1717
size_t& current_offset,
1818
std::optional<std::string>& name,
1919
std::optional<std::vector<sparrow::metadata_pair>>& metadata
2020
);
2121
[[nodiscard]] SPARROW_IPC_API const org::apache::arrow::flatbuf::RecordBatch*
22-
deserialize_record_batch_message(const uint8_t* buf_ptr, size_t& current_offset);
22+
deserialize_record_batch_message(std::span<const uint8_t> data, size_t& current_offset);
2323

24-
[[nodiscard]] SPARROW_IPC_API std::vector<sparrow::record_batch> deserialize_stream(const uint8_t* buf_ptr);
24+
[[nodiscard]] SPARROW_IPC_API std::vector<sparrow::record_batch>
25+
deserialize_stream(std::span<const uint8_t> data);
2526
}

include/sparrow_ipc/encapsulated_message.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace sparrow_ipc
99
{
1010
public:
1111

12-
EncapsulatedMessage(const uint8_t* buf_ptr);
12+
EncapsulatedMessage(std::span<const uint8_t> data);
1313

1414
[[nodiscard]] const org::apache::arrow::flatbuf::Message* flat_buffer_message() const;
1515

@@ -36,8 +36,8 @@ namespace sparrow_ipc
3636

3737
private:
3838

39-
const uint8_t* m_buf_ptr;
39+
std::span<const uint8_t> m_data;
4040
};
4141

42-
[[nodiscard]] EncapsulatedMessage create_encapsulated_message(const uint8_t* buf_ptr);
42+
[[nodiscard]] EncapsulatedMessage create_encapsulated_message(std::span<const uint8_t> buf_ptr);
4343
}

include/sparrow_ipc/serialize_primitive_array.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include "serialize.hpp"
99
#include "utils.hpp"
1010

11-
1211
namespace sparrow_ipc
1312
{
1413
// TODO Use `arr` as const after fixing the issue upstream in sparrow::get_arrow_structures
@@ -62,11 +61,14 @@ namespace sparrow_ipc
6261
// I - Deserialize the Schema message
6362
std::optional<std::string> name;
6463
std::optional<std::vector<sparrow::metadata_pair>> metadata;
65-
deserialize_schema_message(buf_ptr, current_offset, name, metadata);
64+
deserialize_schema_message(std::span<const uint8_t>(buffer), current_offset, name, metadata);
6665

6766
// II - Deserialize the RecordBatch message
6867
const uint32_t batch_meta_len = *(reinterpret_cast<const uint32_t*>(buf_ptr + current_offset));
69-
const auto* record_batch = deserialize_record_batch_message(buf_ptr, current_offset);
68+
const auto* record_batch = deserialize_record_batch_message(
69+
std::span<const uint8_t>(buffer),
70+
current_offset
71+
);
7072

7173
current_offset += utils::align_to_8(batch_meta_len);
7274
const uint8_t* body_ptr = buf_ptr + current_offset;

src/deserialize.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111
namespace sparrow_ipc
1212
{
1313
void deserialize_schema_message(
14-
const uint8_t* buf_ptr,
14+
std::span<const uint8_t> data,
1515
size_t& current_offset,
1616
std::optional<std::string>& name,
1717
std::optional<std::vector<sparrow::metadata_pair>>& metadata
1818
)
1919
{
20-
const uint32_t schema_meta_len = *(reinterpret_cast<const uint32_t*>(buf_ptr + current_offset));
20+
const uint32_t schema_meta_len = *(reinterpret_cast<const uint32_t*>(data.data() + current_offset));
2121
current_offset += sizeof(uint32_t);
22-
const auto schema_message = org::apache::arrow::flatbuf::GetMessage(buf_ptr + current_offset);
22+
const auto schema_message = org::apache::arrow::flatbuf::GetMessage(data.data() + current_offset);
2323
if (schema_message->header_type() != org::apache::arrow::flatbuf::MessageHeader::Schema)
2424
{
2525
throw std::runtime_error("Expected Schema message at the start of the buffer.");
@@ -56,10 +56,10 @@ namespace sparrow_ipc
5656
}
5757

5858
const org::apache::arrow::flatbuf::RecordBatch*
59-
deserialize_record_batch_message(const uint8_t* buf_ptr, size_t& current_offset)
59+
deserialize_record_batch_message(std::span<const uint8_t> data, size_t& current_offset)
6060
{
6161
current_offset += sizeof(uint32_t);
62-
const auto batch_message = org::apache::arrow::flatbuf::GetMessage(buf_ptr + current_offset);
62+
const auto batch_message = org::apache::arrow::flatbuf::GetMessage(data.data() + current_offset);
6363
if (batch_message->header_type() != org::apache::arrow::flatbuf::MessageHeader::RecordBatch)
6464
{
6565
throw std::runtime_error("Expected RecordBatch message, but got a different type.");
@@ -225,7 +225,7 @@ namespace sparrow_ipc
225225
return arrays;
226226
}
227227

228-
std::vector<sparrow::record_batch> deserialize_stream(const uint8_t* buf_ptr)
228+
std::vector<sparrow::record_batch> deserialize_stream(std::span<const uint8_t> data)
229229
{
230230
const org::apache::arrow::flatbuf::Schema* schema = nullptr;
231231
std::vector<sparrow::record_batch> record_batches;
@@ -234,7 +234,7 @@ namespace sparrow_ipc
234234
std::vector<sparrow::data_type> field_types;
235235
do
236236
{
237-
const EncapsulatedMessage encapsulated_message = create_encapsulated_message(buf_ptr);
237+
const EncapsulatedMessage encapsulated_message = create_encapsulated_message(data);
238238
const org::apache::arrow::flatbuf::Message* message = encapsulated_message.flat_buffer_message();
239239
switch (message->header_type())
240240
{
@@ -280,8 +280,8 @@ namespace sparrow_ipc
280280
throw std::runtime_error("Unknown message header type.");
281281
}
282282
const size_t encapsulated_message_total_length = encapsulated_message.total_length();
283-
buf_ptr += encapsulated_message_total_length;
284-
if (is_end_of_stream(std::span<const uint8_t>{buf_ptr, 8}))
283+
data = data.subspan(encapsulated_message_total_length);
284+
if (is_end_of_stream(data.subspan(0, 8)))
285285
{
286286
break;
287287
}

src/encapsulated_message.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,21 @@
77

88
namespace sparrow_ipc
99
{
10-
EncapsulatedMessage::EncapsulatedMessage(const uint8_t* buf_ptr)
11-
: m_buf_ptr(buf_ptr)
10+
EncapsulatedMessage::EncapsulatedMessage(std::span<const uint8_t> data)
11+
: m_data(data)
1212
{
1313
}
1414

1515
const org::apache::arrow::flatbuf::Message* EncapsulatedMessage::flat_buffer_message() const
1616
{
17-
const uint8_t* message_ptr = m_buf_ptr + (sizeof(uint32_t) * 2); // 4 bytes continuation + 4 bytes
18-
// metadata size
17+
const uint8_t* message_ptr = m_data.data() + (sizeof(uint32_t) * 2); // 4 bytes continuation + 4
18+
// bytes metadata size
1919
return org::apache::arrow::flatbuf::GetMessage(message_ptr);
2020
}
2121

2222
size_t EncapsulatedMessage::metadata_length() const
2323
{
24-
return *(reinterpret_cast<const uint32_t*>(m_buf_ptr + sizeof(uint32_t)));
24+
return *(reinterpret_cast<const uint32_t*>(m_data.data() + sizeof(uint32_t)));
2525
}
2626

2727
[[nodiscard]] std::variant<
@@ -76,8 +76,7 @@ namespace sparrow_ipc
7676
const size_t offset = sizeof(uint32_t) * 2 // 4 bytes continuation + 4 bytes metadata size
7777
+ metadata_length();
7878
const size_t padded_offset = utils::align_to_8(offset); // Round up to 8-byte boundary
79-
const uint8_t* body_ptr = m_buf_ptr + padded_offset;
80-
return {body_ptr, body_length()};
79+
return m_data.subspan(padded_offset, body_length());
8180
}
8281

8382
size_t EncapsulatedMessage::total_length() const
@@ -90,20 +89,20 @@ namespace sparrow_ipc
9089

9190
std::span<const uint8_t> EncapsulatedMessage::as_span() const
9291
{
93-
return {m_buf_ptr, total_length()};
92+
return m_data;
9493
}
9594

96-
EncapsulatedMessage create_encapsulated_message(const uint8_t* buf_ptr)
95+
EncapsulatedMessage create_encapsulated_message(std::span<const uint8_t> data)
9796
{
98-
if (!buf_ptr)
97+
if (!data.size() || data.size() < 8)
9998
{
100-
throw std::invalid_argument("Buffer pointer cannot be null.");
99+
throw std::invalid_argument("Buffer is too small to contain a valid message.");
101100
}
102-
const std::span<const uint8_t> continuation_span(buf_ptr, 4);
101+
const std::span<const uint8_t> continuation_span = data.subspan(0, 4);
103102
if (!is_continuation(continuation_span))
104103
{
105104
throw std::runtime_error("Buffer starts with continuation bytes, expected a valid message.");
106105
}
107-
return {buf_ptr};
106+
return {data};
108107
}
109108
}

src/serialize_null_array.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,13 @@ namespace sparrow_ipc
3232
// I - Deserialize the Schema message
3333
std::optional<std::string> name;
3434
std::optional<std::vector<sparrow::metadata_pair>> metadata;
35-
deserialize_schema_message(buf_ptr, current_offset, name, metadata);
35+
deserialize_schema_message(std::span<const uint8_t>(buffer), current_offset, name, metadata);
3636

3737
// II - Deserialize the RecordBatch message
38-
const auto* record_batch = deserialize_record_batch_message(buf_ptr, current_offset);
38+
const auto* record_batch = deserialize_record_batch_message(
39+
std::span<const uint8_t>(buffer),
40+
current_offset
41+
);
3942

4043
// The body is empty, so we don't need to read any further.
4144
// Construct the null_array from the deserialized metadata.

tests/test_primitive_array_with_files.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,21 +84,26 @@ TEST_SUITE("Integration tests")
8484
stream_file.close();
8585

8686
// Process the stream file
87-
const auto record_batches_from_stream = sparrow_ipc::deserialize_stream(stream_data.data());
87+
const auto record_batches_from_stream = sparrow_ipc::deserialize_stream(
88+
std::span<const uint8_t>(stream_data)
89+
);
8890

8991
// Compare record batches
9092
REQUIRE_EQ(record_batches_from_stream.size(), record_batches_from_json.size());
9193
for (size_t i = 0; i < record_batches_from_stream.size(); ++i)
9294
{
93-
for(size_t y = 0; y < record_batches_from_stream[i].nb_columns(); y++)
95+
for (size_t y = 0; y < record_batches_from_stream[i].nb_columns(); y++)
9496
{
9597
const auto& column_stream = record_batches_from_stream[i].get_column(y);
9698
const auto& column_json = record_batches_from_json[i].get_column(y);
9799
REQUIRE_EQ(column_stream.size(), column_json.size());
98-
for(size_t z = 0 ; z < column_json.size(); z++)
100+
for (size_t z = 0; z < column_json.size(); z++)
99101
{
100102
const auto col_name = column_stream.name().value_or("NA");
101-
INFO("Comparing batch " << i << ", column " << y << " named :"<< col_name <<" , row " << z);
103+
INFO(
104+
"Comparing batch " << i << ", column " << y << " named :" << col_name
105+
<< " , row " << z
106+
);
102107
const auto& column_stream_value = column_stream[z];
103108
const auto& column_json_value = column_json[z];
104109
CHECK_EQ(column_stream_value, column_json_value);

0 commit comments

Comments
 (0)