22
33#include < iterator>
44
5- #include " Message_generated.h"
65#include " sparrow_ipc/magic_values.hpp"
76#include " sparrow_ipc/utils.hpp"
87
98namespace sparrow_ipc
109{
11- ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<org::apache::arrow::flatbuf::Field>>>
12- create_children (flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema)
13- {
14- std::vector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>> children_vec;
15- children_vec.reserve (arrow_schema.n_children );
16- for (int i = 0 ; i < arrow_schema.n_children ; ++i)
17- {
18- if (arrow_schema.children [i] == nullptr )
19- {
20- throw std::invalid_argument (" ArrowSchema has null child at index " + std::to_string (i));
21- }
22- flatbuffers::Offset<org::apache::arrow::flatbuf::Field> field = create_field (
23- builder,
24- *(arrow_schema.children [i])
25- );
26- children_vec.emplace_back (field);
27- }
28- return children_vec.empty () ? 0 : builder.CreateVector (children_vec);
29- }
30-
3110 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>>>
3211 create_metadata (flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema)
3312 {
@@ -72,6 +51,26 @@ namespace sparrow_ipc
7251 return fb_field;
7352 }
7453
54+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<org::apache::arrow::flatbuf::Field>>>
55+ create_children (flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema)
56+ {
57+ std::vector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>> children_vec;
58+ children_vec.reserve (arrow_schema.n_children );
59+ for (int i = 0 ; i < arrow_schema.n_children ; ++i)
60+ {
61+ if (arrow_schema.children [i] == nullptr )
62+ {
63+ throw std::invalid_argument (" ArrowSchema has null child at index " + std::to_string (i));
64+ }
65+ flatbuffers::Offset<org::apache::arrow::flatbuf::Field> field = create_field (
66+ builder,
67+ *(arrow_schema.children [i])
68+ );
69+ children_vec.emplace_back (field);
70+ }
71+ return children_vec.empty () ? 0 : builder.CreateVector (children_vec);
72+ }
73+
7574 flatbuffers::FlatBufferBuilder get_schema_message_builder (const ArrowSchema& arrow_schema)
7675 {
7776 flatbuffers::FlatBufferBuilder schema_builder;
@@ -86,7 +85,8 @@ namespace sparrow_ipc
8685 org::apache::arrow::flatbuf::MetadataVersion::V5,
8786 org::apache::arrow::flatbuf::MessageHeader::Schema,
8887 schema_offset.Union (),
89- 0 // body length IS 0 for schema messages
88+ 0 , // body length IS 0 for schema messages
89+ 0 // custom metadata
9090 );
9191 schema_builder.Finish (schema_message_offset);
9292 return schema_builder;
@@ -206,17 +206,78 @@ namespace sparrow_ipc
206206 return body;
207207 }
208208
209- std::vector<uint8_t > serialize_record_batch (const sparrow::record_batch& record_batch)
209+ int64_t calculate_body_size (const sparrow::arrow_proxy& arrow_proxy)
210+ {
211+ int64_t total_size = 0 ;
212+ for (const auto & buffer : arrow_proxy.buffers ())
213+ {
214+ total_size += utils::align_to_8 (static_cast <int64_t >(buffer.size ()));
215+ }
216+ for (const auto & child : arrow_proxy.children ())
217+ {
218+ const auto & child_arrow_proxy = sparrow::detail::array_access::get_arrow_proxy (child);
219+ total_size += calculate_body_size (child_arrow_proxy);
220+ }
221+ return total_size;
222+ }
223+
224+ int64_t calculate_body_size (const sparrow::record_batch& record_batch)
225+ {
226+ return std::accumulate (
227+ record_batch.columns ().begin (),
228+ record_batch.columns ().end (),
229+ 0 ,
230+ [](int64_t acc, const sparrow::array& arr)
231+ {
232+ const auto & arrow_proxy = sparrow::detail::array_access::get_arrow_proxy (arr);
233+ return acc + calculate_body_size (arrow_proxy);
234+ }
235+ );
236+ }
237+
238+ flatbuffers::FlatBufferBuilder get_record_batch_message_builder (
239+ const sparrow::record_batch& record_batch,
240+ const std::vector<org::apache::arrow::flatbuf::FieldNode>& nodes,
241+ const std::vector<org::apache::arrow::flatbuf::Buffer>& buffers
242+ )
210243 {
211- std::vector<org::apache::arrow::flatbuf::FieldNode> nodes = create_fieldnodes (record_batch);
212- std::vector<org::apache::arrow::flatbuf::Buffer> flatbuf_buffers = get_buffers (record_batch);
213244 flatbuffers::FlatBufferBuilder record_batch_builder;
214- org::apache::arrow::flatbuf::CreateRecordBatchDirect (
245+ auto nodes_offset = record_batch_builder.CreateVectorOfStructs (nodes);
246+ auto buffers_offset = record_batch_builder.CreateVectorOfStructs (buffers);
247+ const auto record_batch_offset = org::apache::arrow::flatbuf::CreateRecordBatch (
215248 record_batch_builder,
216249 static_cast <int64_t >(record_batch.nb_rows ()),
217- &nodes,
218- &flatbuf_buffers
250+ nodes_offset,
251+ buffers_offset,
252+ 0 , // TODO: Compression
253+ 0 // TODO :variadic buffer Counts
254+ );
255+
256+ const int64_t body_size = calculate_body_size (record_batch);
257+ const auto record_batch_message_offset = org::apache::arrow::flatbuf::CreateMessage (
258+ record_batch_builder,
259+ org::apache::arrow::flatbuf::MetadataVersion::V5,
260+ org::apache::arrow::flatbuf::MessageHeader::RecordBatch,
261+ record_batch_offset.Union (),
262+ body_size, // body length
263+ 0 // custom metadata
219264 );
265+ record_batch_builder.Finish (record_batch_message_offset);
266+ return record_batch_builder;
267+ }
268+
269+ std::vector<uint8_t > serialize_record_batch (const sparrow::record_batch& record_batch)
270+ {
271+ std::vector<org::apache::arrow::flatbuf::FieldNode> nodes = create_fieldnodes (record_batch);
272+ std::vector<org::apache::arrow::flatbuf::Buffer> flatbuf_buffers = get_buffers (record_batch);
273+ flatbuffers::FlatBufferBuilder record_batch_builder;
274+ ::flatbuffers::Offset<org::apache::arrow::flatbuf::RecordBatch>
275+ record_batch_offset = org::apache::arrow::flatbuf::CreateRecordBatchDirect (
276+ record_batch_builder,
277+ static_cast <int64_t >(record_batch.nb_rows ()),
278+ &nodes,
279+ &flatbuf_buffers
280+ );
220281 std::vector<uint8_t > output;
221282 output.insert (output.end (), continuation.begin (), continuation.end ());
222283 const flatbuffers::uoffset_t record_batch_len = record_batch_builder.GetSize ();
@@ -260,7 +321,7 @@ namespace sparrow_ipc
260321
261322 template <std::ranges::input_range R>
262323 requires std::same_as<std::ranges::range_value_t <R>, sparrow::record_batch>
263- void serialize (const R& record_batches, std::ostream& out )
324+ std::vector< uint8_t > serialize (const R& record_batches)
264325 {
265326 if (check_record_batches_consistency (record_batches))
266327 {
@@ -270,5 +331,11 @@ namespace sparrow_ipc
270331 }
271332 std::vector<uint8_t > serialized_schema = serialize_schema_message (record_batches[0 ].schema ());
272333 std::vector<uint8_t > serialized_record_batches = serialize_record_batches (record_batches);
334+ serialized_schema.insert (
335+ serialized_schema.end (),
336+ std::make_move_iterator (serialized_record_batches.begin ()),
337+ std::make_move_iterator (serialized_record_batches.end ())
338+ );
339+ return serialized_schema;
273340 }
274341}
0 commit comments