Skip to content

Commit b697616

Browse files
committed
Add serialize
1 parent b94eea7 commit b697616

File tree

6 files changed

+342
-1
lines changed

6 files changed

+342
-1
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ set(SPARROW_IPC_HEADERS
9797
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_primitive_array.hpp
9898
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_utils.hpp
9999
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize.hpp
100+
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/serialize.hpp
100101
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/encapsulated_message.hpp
101102
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/magic_values.hpp
102103
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/metadata.hpp
@@ -111,6 +112,7 @@ set(SPARROW_IPC_SRC
111112
${SPARROW_IPC_SOURCE_DIR}/deserialize_fixedsizebinary_array.cpp
112113
${SPARROW_IPC_SOURCE_DIR}/deserialize_utils.cpp
113114
${SPARROW_IPC_SOURCE_DIR}/deserialize.cpp
115+
${SPARROW_IPC_SOURCE_DIR}/serialize.cpp
114116
${SPARROW_IPC_SOURCE_DIR}/encapsulated_message.cpp
115117
${SPARROW_IPC_SOURCE_DIR}/metadata.cpp
116118
${SPARROW_IPC_SOURCE_DIR}/utils.cpp

include/sparrow_ipc/serialize.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#pragma once
2+
3+
#include <ostream>
4+
#include <ranges>
5+
6+
#include <sparrow/record_batch.hpp>
7+
8+
namespace sparrow_ipc
9+
{
10+
template <std::ranges::input_range R>
11+
requires std::same_as<std::ranges::range_value_t<R>, sparrow::record_batch>
12+
void serialize(const R& record_batches, std::ostream& out);
13+
14+
std::vector<uint8_t> serialize_schema_message(const ArrowSchema& arrow_schema);
15+
16+
template <std::ranges::input_range R>
17+
requires std::same_as<std::ranges::range_value_t<R>, sparrow::record_batch>
18+
std::vector<uint8_t> serialize_record_batches(const R& record_batches);
19+
20+
21+
}

include/sparrow_ipc/utils.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <string_view>
66
#include <utility>
77

8+
#include <sparrow/record_batch.hpp>
9+
810
#include "Schema_generated.h"
911
#include "sparrow_ipc/config/config.hpp"
1012

@@ -17,4 +19,10 @@ namespace sparrow_ipc::utils
1719
// This function maps a sparrow data type to the corresponding Flatbuffers type
1820
SPARROW_IPC_API std::pair<org::apache::arrow::flatbuf::Type, flatbuffers::Offset<void>>
1921
get_flatbuffer_type(flatbuffers::FlatBufferBuilder& builder, std::string_view format_str);
22+
23+
template <std::ranges::input_range R>
24+
requires std::same_as<std::ranges::range_value_t<R>, sparrow::record_batch>
25+
SPARROW_IPC_API bool check_record_batches_consistency(const R& record_batches);
26+
27+
size_t calculate_output_serialized_size(const sparrow::record_batch& record_batch);
2028
}

src/serialize.cpp

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
#include "sparrow_ipc/serialize.hpp"
2+
3+
#include <iterator>
4+
5+
#include "Message_generated.h"
6+
#include "sparrow_ipc/magic_values.hpp"
7+
#include "sparrow_ipc/utils.hpp"
8+
9+
namespace sparrow_ipc
10+
{
11+
::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<org::apache::arrow::flatbuf::Field>>>
12+
create_children(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema)
13+
{
14+
std::vector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>> children_vec;
15+
children_vec.reserve(arrow_schema.n_children);
16+
for (int i = 0; i < arrow_schema.n_children; ++i)
17+
{
18+
if (arrow_schema.children[i] == nullptr)
19+
{
20+
throw std::invalid_argument("ArrowSchema has null child at index " + std::to_string(i));
21+
}
22+
flatbuffers::Offset<org::apache::arrow::flatbuf::Field> field = create_field(
23+
builder,
24+
*(arrow_schema.children[i])
25+
);
26+
children_vec.emplace_back(field);
27+
}
28+
return children_vec.empty() ? 0 : builder.CreateVector(children_vec);
29+
}
30+
31+
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>>>
32+
create_metadata(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema)
33+
{
34+
if (arrow_schema.metadata == nullptr)
35+
{
36+
return 0;
37+
}
38+
39+
const auto metadata_view = sparrow::key_value_view(arrow_schema.metadata);
40+
std::vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>> kv_offsets;
41+
kv_offsets.reserve(metadata_view.size());
42+
for (const auto& [key, value] : metadata_view)
43+
{
44+
const auto key_offset = builder.CreateString(std::string(key));
45+
const auto value_offset = builder.CreateString(std::string(value));
46+
kv_offsets.push_back(org::apache::arrow::flatbuf::CreateKeyValue(builder, key_offset, value_offset));
47+
}
48+
return builder.CreateVector(kv_offsets);
49+
}
50+
51+
::flatbuffers::Offset<org::apache::arrow::flatbuf::Field>
52+
create_field(flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema)
53+
{
54+
flatbuffers::Offset<flatbuffers::String> fb_name_offset = (arrow_schema.name == nullptr)
55+
? 0
56+
: builder.CreateString(arrow_schema.name);
57+
58+
const auto [type_enum, type_offset] = utils::get_flatbuffer_type(builder, arrow_schema.format);
59+
auto fb_metadata_offset = create_metadata(builder, arrow_schema);
60+
const auto children = create_children(builder, arrow_schema);
61+
62+
const auto fb_field = org::apache::arrow::flatbuf::CreateField(
63+
builder,
64+
fb_name_offset,
65+
(arrow_schema.flags & static_cast<int64_t>(sparrow::ArrowFlag::NULLABLE)) != 0,
66+
type_enum,
67+
type_offset,
68+
0, // TODO: support dictionary
69+
children,
70+
fb_metadata_offset
71+
);
72+
return fb_field;
73+
}
74+
75+
flatbuffers::FlatBufferBuilder get_schema_message_builder(const ArrowSchema& arrow_schema)
76+
{
77+
flatbuffers::FlatBufferBuilder schema_builder;
78+
const auto fields_vec = create_children(schema_builder, arrow_schema);
79+
const auto schema_offset = org::apache::arrow::flatbuf::CreateSchema(
80+
schema_builder,
81+
org::apache::arrow::flatbuf::Endianness::Little, // TODO: make configurable
82+
fields_vec
83+
);
84+
const auto schema_message_offset = org::apache::arrow::flatbuf::CreateMessage(
85+
schema_builder,
86+
org::apache::arrow::flatbuf::MetadataVersion::V5,
87+
org::apache::arrow::flatbuf::MessageHeader::Schema,
88+
schema_offset.Union(),
89+
0 // body length IS 0 for schema messages
90+
);
91+
schema_builder.Finish(schema_message_offset);
92+
return schema_builder;
93+
}
94+
95+
std::vector<uint8_t> serialize_schema_message(const ArrowSchema& arrow_schema)
96+
{
97+
std::vector<uint8_t> schema_buffer;
98+
99+
schema_buffer.insert(schema_buffer.end(), continuation.begin(), continuation.end());
100+
flatbuffers::FlatBufferBuilder schema_builder = get_schema_message_builder(arrow_schema);
101+
const flatbuffers::uoffset_t schema_len = schema_builder.GetSize();
102+
schema_buffer.reserve(schema_buffer.size() + sizeof(uint32_t) + schema_len);
103+
// Write the 4-byte length prefix after the continuation bytes
104+
schema_buffer.insert(
105+
schema_buffer.end(),
106+
reinterpret_cast<const uint8_t*>(&schema_len),
107+
reinterpret_cast<const uint8_t*>(&schema_len) + sizeof(uint32_t)
108+
);
109+
// Append the actual message bytes
110+
schema_buffer.insert(
111+
schema_buffer.end(),
112+
schema_builder.GetBufferPointer(),
113+
schema_builder.GetBufferPointer() + schema_len
114+
);
115+
// padding to 8 bytes
116+
schema_buffer.insert(
117+
schema_buffer.end(),
118+
utils::align_to_8(static_cast<int64_t>(schema_buffer.size()))
119+
- static_cast<int64_t>(schema_buffer.size()),
120+
0
121+
);
122+
return schema_buffer;
123+
}
124+
125+
void fill_fieldnodes(
126+
const sparrow::arrow_proxy& arrow_proxy,
127+
std::vector<org::apache::arrow::flatbuf::FieldNode>& nodes
128+
)
129+
{
130+
nodes.emplace_back(arrow_proxy.length(), arrow_proxy.null_count());
131+
nodes.reserve(nodes.size() + arrow_proxy.n_children());
132+
for (const auto& child : arrow_proxy.children())
133+
{
134+
fill_fieldnodes(child, nodes);
135+
}
136+
}
137+
138+
std::vector<org::apache::arrow::flatbuf::FieldNode>
139+
create_fieldnodes(const sparrow::record_batch& record_batch)
140+
{
141+
std::vector<org::apache::arrow::flatbuf::FieldNode> nodes;
142+
nodes.reserve(record_batch.columns().size());
143+
for (const auto& column : record_batch.columns())
144+
{
145+
fill_fieldnodes(sparrow::detail::array_access::get_arrow_proxy(column), nodes);
146+
}
147+
return nodes;
148+
}
149+
150+
void fill_buffers(
151+
const sparrow::arrow_proxy& arrow_proxy,
152+
std::vector<org::apache::arrow::flatbuf::Buffer>& flatbuf_buffers,
153+
int64_t& offset
154+
)
155+
{
156+
const auto& buffers = arrow_proxy.buffers();
157+
for (const auto& buffer : buffers)
158+
{
159+
int64_t size = static_cast<int64_t>(buffer.size());
160+
flatbuf_buffers.emplace_back(offset, size);
161+
offset += utils::align_to_8(size);
162+
}
163+
for (const auto& child : arrow_proxy.children())
164+
{
165+
const auto& child_arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(child);
166+
fill_buffers(child_arrow_proxy, flatbuf_buffers, offset);
167+
}
168+
}
169+
170+
std::vector<org::apache::arrow::flatbuf::Buffer> get_buffers(const sparrow::record_batch& record_batch)
171+
{
172+
std::vector<org::apache::arrow::flatbuf::Buffer> buffers;
173+
std::int64_t offset = 0;
174+
for (const auto& column : record_batch.columns())
175+
{
176+
const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column);
177+
fill_buffers(arrow_proxy, buffers, offset);
178+
}
179+
return buffers;
180+
}
181+
182+
void fill_body(const sparrow::arrow_proxy& arrow_proxy, std::vector<uint8_t>& body)
183+
{
184+
for (const auto& buffer : arrow_proxy.buffers())
185+
{
186+
body.insert(body.end(), buffer.begin(), buffer.end());
187+
const int64_t padding_size = utils::align_to_8(static_cast<int64_t>(buffer.size()))
188+
- static_cast<int64_t>(buffer.size());
189+
body.insert(body.end(), padding_size, 0);
190+
}
191+
for (const auto& child : arrow_proxy.children())
192+
{
193+
const auto& child_arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(child);
194+
fill_body(child_arrow_proxy, body);
195+
}
196+
}
197+
198+
std::vector<uint8_t> generate_body(const sparrow::record_batch& record_batch)
199+
{
200+
std::vector<uint8_t> body;
201+
for (const auto& column : record_batch.columns())
202+
{
203+
const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column);
204+
fill_body(arrow_proxy, body);
205+
}
206+
return body;
207+
}
208+
209+
std::vector<uint8_t> serialize_record_batch(const sparrow::record_batch& record_batch)
210+
{
211+
std::vector<org::apache::arrow::flatbuf::FieldNode> nodes = create_fieldnodes(record_batch);
212+
std::vector<org::apache::arrow::flatbuf::Buffer> flatbuf_buffers = get_buffers(record_batch);
213+
flatbuffers::FlatBufferBuilder record_batch_builder;
214+
org::apache::arrow::flatbuf::CreateRecordBatchDirect(
215+
record_batch_builder,
216+
static_cast<int64_t>(record_batch.nb_rows()),
217+
&nodes,
218+
&flatbuf_buffers
219+
);
220+
std::vector<uint8_t> output;
221+
output.insert(output.end(), continuation.begin(), continuation.end());
222+
const flatbuffers::uoffset_t record_batch_len = record_batch_builder.GetSize();
223+
output.insert(
224+
output.end(),
225+
reinterpret_cast<const uint8_t*>(&record_batch_len),
226+
reinterpret_cast<const uint8_t*>(&record_batch_len) + sizeof(record_batch_len)
227+
);
228+
output.insert(
229+
output.end(),
230+
record_batch_builder.GetBufferPointer(),
231+
record_batch_builder.GetBufferPointer() + record_batch_len
232+
);
233+
// padding to 8 bytes
234+
output.insert(
235+
output.end(),
236+
utils::align_to_8(static_cast<int64_t>(output.size())) - static_cast<int64_t>(output.size()),
237+
0
238+
);
239+
std::vector<uint8_t> body = generate_body(record_batch);
240+
output.insert(output.end(), std::make_move_iterator(body.begin()), std::make_move_iterator(body.end()));
241+
return output;
242+
}
243+
244+
template <std::ranges::input_range R>
245+
requires std::same_as<std::ranges::range_value_t<R>, sparrow::record_batch>
246+
std::vector<uint8_t> serialize_record_batches(const R& record_batches)
247+
{
248+
std::vector<uint8_t> output;
249+
for (const auto& record_batch : record_batches)
250+
{
251+
const auto rb_serialized = serialize_record_batch(record_batch);
252+
output.insert(
253+
output.end(),
254+
std::make_move_iterator(rb_serialized.begin()),
255+
std::make_move_iterator(rb_serialized.end())
256+
);
257+
}
258+
return output;
259+
}
260+
261+
template <std::ranges::input_range R>
262+
requires std::same_as<std::ranges::range_value_t<R>, sparrow::record_batch>
263+
void serialize(const R& record_batches, std::ostream& out)
264+
{
265+
if (check_record_batches_consistency(record_batches))
266+
{
267+
throw std::invalid_argument(
268+
"All record batches must have the same schema to be serialized together."
269+
);
270+
}
271+
std::vector<uint8_t> serialized_schema = serialize_schema_message(record_batches[0].schema());
272+
std::vector<uint8_t> serialized_record_batches = serialize_record_batches(record_batches);
273+
}
274+
}

src/utils.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,5 +444,41 @@ namespace sparrow_ipc
444444
}
445445
}
446446
}
447+
448+
template <std::ranges::input_range R>
449+
requires std::same_as<std::ranges::range_value_t<R>, sparrow::record_batch>
450+
bool check_record_batches_consistency(const R& record_batches)
451+
{
452+
if (record_batches.empty())
453+
{
454+
return true;
455+
}
456+
const sparrow::record_batch& first_rb = record_batches[0];
457+
for (const sparrow::record_batch& rb : record_batches)
458+
{
459+
rb.check_consistency();
460+
if (rb.nb_columns() != first_rb.nb_columns())
461+
{
462+
return false;
463+
}
464+
if (rb.nb_rows() != first_rb.nb_rows())
465+
{
466+
return false;
467+
}
468+
for (size_t col_idx = 0; col_idx < rb.nb_columns(); ++col_idx)
469+
{
470+
const sparrow::array& arr = rb.get_column(col_idx);
471+
const sparrow::array& first_arr = first_rb.get_column(col_idx);
472+
if (arr.format() != first_arr.format())
473+
{
474+
return false;
475+
}
476+
}
477+
}
478+
}
479+
480+
size_t calculate_output_serialized_size(const sparrow::record_batch& record_batch)
481+
{
482+
}
447483
}
448484
}

tests/test_deserialization_with_files.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#include "doctest/doctest.h"
1414
#include "sparrow.hpp"
1515
#include "sparrow_ipc/deserialize.hpp"
16-
16+
#include "sparrow_ipc/serialize.hpp"
1717

1818
const std::filesystem::path arrow_testing_data_dir = ARROW_TESTING_DATA_DIR;
1919

0 commit comments

Comments
 (0)