#include "sparrow_ipc/serialize.hpp"

#include <cstddef>
#include <iterator>
#include <ostream>

#include "Message_generated.h"
#include "sparrow_ipc/magic_values.hpp"
#include "sparrow_ipc/utils.hpp"
8+
9+ namespace sparrow_ipc
10+ {
11+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<org::apache::arrow::flatbuf::Field>>>
12+ create_children (flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema)
13+ {
14+ std::vector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>> children_vec;
15+ children_vec.reserve (arrow_schema.n_children );
16+ for (int i = 0 ; i < arrow_schema.n_children ; ++i)
17+ {
18+ if (arrow_schema.children [i] == nullptr )
19+ {
20+ throw std::invalid_argument (" ArrowSchema has null child at index " + std::to_string (i));
21+ }
22+ flatbuffers::Offset<org::apache::arrow::flatbuf::Field> field = create_field (
23+ builder,
24+ *(arrow_schema.children [i])
25+ );
26+ children_vec.emplace_back (field);
27+ }
28+ return children_vec.empty () ? 0 : builder.CreateVector (children_vec);
29+ }
30+
31+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>>>
32+ create_metadata (flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema)
33+ {
34+ if (arrow_schema.metadata == nullptr )
35+ {
36+ return 0 ;
37+ }
38+
39+ const auto metadata_view = sparrow::key_value_view (arrow_schema.metadata );
40+ std::vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>> kv_offsets;
41+ kv_offsets.reserve (metadata_view.size ());
42+ for (const auto & [key, value] : metadata_view)
43+ {
44+ const auto key_offset = builder.CreateString (std::string (key));
45+ const auto value_offset = builder.CreateString (std::string (value));
46+ kv_offsets.push_back (org::apache::arrow::flatbuf::CreateKeyValue (builder, key_offset, value_offset));
47+ }
48+ return builder.CreateVector (kv_offsets);
49+ }
50+
51+ ::flatbuffers::Offset<org::apache::arrow::flatbuf::Field>
52+ create_field (flatbuffers::FlatBufferBuilder& builder, const ArrowSchema& arrow_schema)
53+ {
54+ flatbuffers::Offset<flatbuffers::String> fb_name_offset = (arrow_schema.name == nullptr )
55+ ? 0
56+ : builder.CreateString (arrow_schema.name );
57+
58+ const auto [type_enum, type_offset] = utils::get_flatbuffer_type (builder, arrow_schema.format );
59+ auto fb_metadata_offset = create_metadata (builder, arrow_schema);
60+ const auto children = create_children (builder, arrow_schema);
61+
62+ const auto fb_field = org::apache::arrow::flatbuf::CreateField (
63+ builder,
64+ fb_name_offset,
65+ (arrow_schema.flags & static_cast <int64_t >(sparrow::ArrowFlag::NULLABLE)) != 0 ,
66+ type_enum,
67+ type_offset,
68+ 0 , // TODO: support dictionary
69+ children,
70+ fb_metadata_offset
71+ );
72+ return fb_field;
73+ }
74+
75+ flatbuffers::FlatBufferBuilder get_schema_message_builder (const ArrowSchema& arrow_schema)
76+ {
77+ flatbuffers::FlatBufferBuilder schema_builder;
78+ const auto fields_vec = create_children (schema_builder, arrow_schema);
79+ const auto schema_offset = org::apache::arrow::flatbuf::CreateSchema (
80+ schema_builder,
81+ org::apache::arrow::flatbuf::Endianness::Little, // TODO: make configurable
82+ fields_vec
83+ );
84+ const auto schema_message_offset = org::apache::arrow::flatbuf::CreateMessage (
85+ schema_builder,
86+ org::apache::arrow::flatbuf::MetadataVersion::V5,
87+ org::apache::arrow::flatbuf::MessageHeader::Schema,
88+ schema_offset.Union (),
89+ 0 // body length IS 0 for schema messages
90+ );
91+ schema_builder.Finish (schema_message_offset);
92+ return schema_builder;
93+ }
94+
95+ std::vector<uint8_t > serialize_schema_message (const ArrowSchema& arrow_schema)
96+ {
97+ std::vector<uint8_t > schema_buffer;
98+
99+ schema_buffer.insert (schema_buffer.end (), continuation.begin (), continuation.end ());
100+ flatbuffers::FlatBufferBuilder schema_builder = get_schema_message_builder (arrow_schema);
101+ const flatbuffers::uoffset_t schema_len = schema_builder.GetSize ();
102+ schema_buffer.reserve (schema_buffer.size () + sizeof (uint32_t ) + schema_len);
103+ // Write the 4-byte length prefix after the continuation bytes
104+ schema_buffer.insert (
105+ schema_buffer.end (),
106+ reinterpret_cast <const uint8_t *>(&schema_len),
107+ reinterpret_cast <const uint8_t *>(&schema_len) + sizeof (uint32_t )
108+ );
109+ // Append the actual message bytes
110+ schema_buffer.insert (
111+ schema_buffer.end (),
112+ schema_builder.GetBufferPointer (),
113+ schema_builder.GetBufferPointer () + schema_len
114+ );
115+ // padding to 8 bytes
116+ schema_buffer.insert (
117+ schema_buffer.end (),
118+ utils::align_to_8 (static_cast <int64_t >(schema_buffer.size ()))
119+ - static_cast <int64_t >(schema_buffer.size ()),
120+ 0
121+ );
122+ return schema_buffer;
123+ }
124+
125+ void fill_fieldnodes (
126+ const sparrow::arrow_proxy& arrow_proxy,
127+ std::vector<org::apache::arrow::flatbuf::FieldNode>& nodes
128+ )
129+ {
130+ nodes.emplace_back (arrow_proxy.length (), arrow_proxy.null_count ());
131+ nodes.reserve (nodes.size () + arrow_proxy.n_children ());
132+ for (const auto & child : arrow_proxy.children ())
133+ {
134+ fill_fieldnodes (child, nodes);
135+ }
136+ }
137+
138+ std::vector<org::apache::arrow::flatbuf::FieldNode>
139+ create_fieldnodes (const sparrow::record_batch& record_batch)
140+ {
141+ std::vector<org::apache::arrow::flatbuf::FieldNode> nodes;
142+ nodes.reserve (record_batch.columns ().size ());
143+ for (const auto & column : record_batch.columns ())
144+ {
145+ fill_fieldnodes (sparrow::detail::array_access::get_arrow_proxy (column), nodes);
146+ }
147+ return nodes;
148+ }
149+
150+ void fill_buffers (
151+ const sparrow::arrow_proxy& arrow_proxy,
152+ std::vector<org::apache::arrow::flatbuf::Buffer>& flatbuf_buffers,
153+ int64_t & offset
154+ )
155+ {
156+ const auto & buffers = arrow_proxy.buffers ();
157+ for (const auto & buffer : buffers)
158+ {
159+ int64_t size = static_cast <int64_t >(buffer.size ());
160+ flatbuf_buffers.emplace_back (offset, size);
161+ offset += utils::align_to_8 (size);
162+ }
163+ for (const auto & child : arrow_proxy.children ())
164+ {
165+ const auto & child_arrow_proxy = sparrow::detail::array_access::get_arrow_proxy (child);
166+ fill_buffers (child_arrow_proxy, flatbuf_buffers, offset);
167+ }
168+ }
169+
170+ std::vector<org::apache::arrow::flatbuf::Buffer> get_buffers (const sparrow::record_batch& record_batch)
171+ {
172+ std::vector<org::apache::arrow::flatbuf::Buffer> buffers;
173+ std::int64_t offset = 0 ;
174+ for (const auto & column : record_batch.columns ())
175+ {
176+ const auto & arrow_proxy = sparrow::detail::array_access::get_arrow_proxy (column);
177+ fill_buffers (arrow_proxy, buffers, offset);
178+ }
179+ return buffers;
180+ }
181+
182+ void fill_body (const sparrow::arrow_proxy& arrow_proxy, std::vector<uint8_t >& body)
183+ {
184+ for (const auto & buffer : arrow_proxy.buffers ())
185+ {
186+ body.insert (body.end (), buffer.begin (), buffer.end ());
187+ const int64_t padding_size = utils::align_to_8 (static_cast <int64_t >(buffer.size ()))
188+ - static_cast <int64_t >(buffer.size ());
189+ body.insert (body.end (), padding_size, 0 );
190+ }
191+ for (const auto & child : arrow_proxy.children ())
192+ {
193+ const auto & child_arrow_proxy = sparrow::detail::array_access::get_arrow_proxy (child);
194+ fill_body (child_arrow_proxy, body);
195+ }
196+ }
197+
198+ std::vector<uint8_t > generate_body (const sparrow::record_batch& record_batch)
199+ {
200+ std::vector<uint8_t > body;
201+ for (const auto & column : record_batch.columns ())
202+ {
203+ const auto & arrow_proxy = sparrow::detail::array_access::get_arrow_proxy (column);
204+ fill_body (arrow_proxy, body);
205+ }
206+ return body;
207+ }
208+
209+ std::vector<uint8_t > serialize_record_batch (const sparrow::record_batch& record_batch)
210+ {
211+ std::vector<org::apache::arrow::flatbuf::FieldNode> nodes = create_fieldnodes (record_batch);
212+ std::vector<org::apache::arrow::flatbuf::Buffer> flatbuf_buffers = get_buffers (record_batch);
213+ flatbuffers::FlatBufferBuilder record_batch_builder;
214+ org::apache::arrow::flatbuf::CreateRecordBatchDirect (
215+ record_batch_builder,
216+ static_cast <int64_t >(record_batch.nb_rows ()),
217+ &nodes,
218+ &flatbuf_buffers
219+ );
220+ std::vector<uint8_t > output;
221+ output.insert (output.end (), continuation.begin (), continuation.end ());
222+ const flatbuffers::uoffset_t record_batch_len = record_batch_builder.GetSize ();
223+ output.insert (
224+ output.end (),
225+ reinterpret_cast <const uint8_t *>(&record_batch_len),
226+ reinterpret_cast <const uint8_t *>(&record_batch_len) + sizeof (record_batch_len)
227+ );
228+ output.insert (
229+ output.end (),
230+ record_batch_builder.GetBufferPointer (),
231+ record_batch_builder.GetBufferPointer () + record_batch_len
232+ );
233+ // padding to 8 bytes
234+ output.insert (
235+ output.end (),
236+ utils::align_to_8 (static_cast <int64_t >(output.size ())) - static_cast <int64_t >(output.size ()),
237+ 0
238+ );
239+ std::vector<uint8_t > body = generate_body (record_batch);
240+ output.insert (output.end (), std::make_move_iterator (body.begin ()), std::make_move_iterator (body.end ()));
241+ return output;
242+ }
243+
244+ template <std::ranges::input_range R>
245+ requires std::same_as<std::ranges::range_value_t <R>, sparrow::record_batch>
246+ std::vector<uint8_t > serialize_record_batches (const R& record_batches)
247+ {
248+ std::vector<uint8_t > output;
249+ for (const auto & record_batch : record_batches)
250+ {
251+ const auto rb_serialized = serialize_record_batch (record_batch);
252+ output.insert (
253+ output.end (),
254+ std::make_move_iterator (rb_serialized.begin ()),
255+ std::make_move_iterator (rb_serialized.end ())
256+ );
257+ }
258+ return output;
259+ }
260+
261+ template <std::ranges::input_range R>
262+ requires std::same_as<std::ranges::range_value_t <R>, sparrow::record_batch>
263+ void serialize (const R& record_batches, std::ostream& out)
264+ {
265+ if (check_record_batches_consistency (record_batches))
266+ {
267+ throw std::invalid_argument (
268+ " All record batches must have the same schema to be serialized together."
269+ );
270+ }
271+ std::vector<uint8_t > serialized_schema = serialize_schema_message (record_batches[0 ].schema ());
272+ std::vector<uint8_t > serialized_record_batches = serialize_record_batches (record_batches);
273+ }
274+ }
0 commit comments