Skip to content

Commit de52048

Browse files
authored
GH-26727: [C++][Flight] Use ipc::RecordBatchWriter with custom IpcPayloadWriter for TransportMessageWriter (DoExchange) (#47410)
### Rationale for this change

The Flight Server DoExchange method currently does not support Dictionary replacement or Dictionary Deltas. Similar to how the client currently behaves and to what we do for DoGet, we should use an `ipc::RecordBatchWriter` with a custom `IpcPayloadWriter` instead of reimplementing the Dictionary Replacement / Deltas logic.

### What changes are included in this PR?

Removes the manual generation of individual IPC payloads and uses an `ipc::RecordBatchWriter` with a custom `TransportMessagePayloadWriter` to convert the `IpcPayloads` into `FlightPayloads`.

### Are these changes tested?

Yes, existing tests cover the DoExchange functionality, and a new Python test has been added in which Dictionary deltas are sent via DoExchange. The test was failing before this change because the dictionary wasn't updated:

```python
received_table = reader.read_all() expected_table = simple_dicts_table() > assert received_table.equals(expected_table) E assert False E + where False = equals(pyarrow.Table\nsome_dicts: dictionary<values=string, indices=int64, ordered=0>\n----\nsome_dicts: [ -- dictionary:\n["foo... -- dictionary:\n["foo","baz","quux"] -- indices:\n[2,1], -- dictionary:\n["foo","baz","quux","new"] -- indices:\n[0,3]]) E + where equals = pyarrow.Table\nsome_dicts: dictionary<values=string, indices=int64, ordered=0>\n----\nsome_dicts: [ -- dictionary:\n["foo...ull], -- dictionary:\n["foo","baz","quux"] -- indices:\n[2,1], -- dictionary:\n["foo","baz","quux"] -- indices:\n[0,3]].equals pyarrow/tests/test_flight.py:2596: AssertionError ================================================================================== short test summary info =================================================================================== FAILED pyarrow/tests/test_flight.py::test_flight_dictionary_deltas_do_exchange - assert False
```

### Are there any user-facing changes?

No, only that the expected dictionary replacement/deltas will work for DoExchange.
* GitHub Issue: #26727 Authored-by: Raúl Cumplido <[email protected]> Signed-off-by: Raúl Cumplido <[email protected]>
1 parent 3469def commit de52048

File tree

2 files changed

+114
-54
lines changed

2 files changed

+114
-54
lines changed

cpp/src/arrow/flight/transport_server.cc

Lines changed: 76 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "arrow/ipc/reader.h"
2727
#include "arrow/result.h"
2828
#include "arrow/status.h"
29+
#include "arrow/util/logging_internal.h"
2930

3031
namespace arrow {
3132
namespace flight {
@@ -160,25 +161,63 @@ class TransportMessageReader final : public FlightMessageReader {
160161
std::shared_ptr<Buffer> app_metadata_;
161162
};
162163

163-
// TODO(ARROW-10787): this should use the same writer/ipc trick as client
164+
/// \brief An IpcPayloadWriter for ServerDataStream.
165+
///
166+
/// To support app_metadata and reuse the existing IPC infrastructure,
167+
/// this takes a pointer to a buffer to be combined with the IPC
168+
/// payload when writing a Flight payload.
169+
class TransportMessagePayloadWriter : public ipc::internal::IpcPayloadWriter {
170+
public:
171+
TransportMessagePayloadWriter(ServerDataStream* stream,
172+
std::shared_ptr<Buffer>* app_metadata)
173+
: stream_(stream), app_metadata_(app_metadata) {}
174+
175+
Status Start() override { return Status::OK(); }
176+
Status WritePayload(const ipc::IpcPayload& ipc_payload) override {
177+
FlightPayload payload;
178+
payload.ipc_message = ipc_payload;
179+
180+
if (ipc_payload.type == ipc::MessageType::RECORD_BATCH && *app_metadata_) {
181+
payload.app_metadata = std::move(*app_metadata_);
182+
}
183+
ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload));
184+
if (!success) {
185+
return MakeFlightError(
186+
FlightStatusCode::Internal,
187+
"Could not write record batch to stream (client disconnect?)");
188+
}
189+
return arrow::Status::OK();
190+
}
191+
Status Close() override {
192+
// Closing is handled one layer up in TransportMessageWriter::Close
193+
return Status::OK();
194+
}
195+
196+
private:
197+
ServerDataStream* stream_;
198+
std::shared_ptr<Buffer>* app_metadata_;
199+
};
200+
164201
class TransportMessageWriter final : public FlightMessageWriter {
165202
public:
166203
explicit TransportMessageWriter(ServerDataStream* stream)
167-
: stream_(stream), ipc_options_(::arrow::ipc::IpcWriteOptions::Defaults()) {}
204+
: stream_(stream),
205+
app_metadata_(nullptr),
206+
ipc_options_(::arrow::ipc::IpcWriteOptions::Defaults()) {}
168207

169208
Status Begin(const std::shared_ptr<Schema>& schema,
170209
const ipc::IpcWriteOptions& options) override {
171-
if (started_) {
210+
if (batch_writer_) {
172211
return Status::Invalid("This writer has already been started.");
173212
}
174-
started_ = true;
175213
ipc_options_ = options;
214+
std::unique_ptr<ipc::internal::IpcPayloadWriter> payload_writer(
215+
new TransportMessagePayloadWriter(stream_, &app_metadata_));
176216

177-
RETURN_NOT_OK(mapper_.AddSchemaFields(*schema));
178-
FlightPayload schema_payload;
179-
RETURN_NOT_OK(ipc::GetSchemaPayload(*schema, ipc_options_, mapper_,
180-
&schema_payload.ipc_message));
181-
return WritePayload(schema_payload);
217+
ARROW_ASSIGN_OR_RAISE(batch_writer_,
218+
ipc::internal::OpenRecordBatchWriter(std::move(payload_writer),
219+
schema, ipc_options_));
220+
return Status::OK();
182221
}
183222

184223
Status WriteRecordBatch(const RecordBatch& batch) override {
@@ -188,71 +227,56 @@ class TransportMessageWriter final : public FlightMessageWriter {
188227
Status WriteMetadata(std::shared_ptr<Buffer> app_metadata) override {
189228
FlightPayload payload{};
190229
payload.app_metadata = app_metadata;
191-
return WritePayload(payload);
230+
ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload));
231+
if (!success) {
232+
ARROW_RETURN_NOT_OK(Close());
233+
return MakeFlightError(FlightStatusCode::Internal,
234+
"Could not write metadata to stream (client disconnect?)");
235+
}
236+
// Those messages are not written through the batch writer,
237+
// count them separately to include them in the stats.
238+
extra_messages_++;
239+
return Status::OK();
192240
}
193241

194242
Status WriteWithMetadata(const RecordBatch& batch,
195243
std::shared_ptr<Buffer> app_metadata) override {
196244
RETURN_NOT_OK(CheckStarted());
197-
RETURN_NOT_OK(EnsureDictionariesWritten(batch));
198-
FlightPayload payload{};
199-
if (app_metadata) {
200-
payload.app_metadata = app_metadata;
245+
app_metadata_ = app_metadata;
246+
auto status = batch_writer_->WriteRecordBatch(batch);
247+
if (!status.ok()) {
248+
ARROW_RETURN_NOT_OK(Close());
201249
}
202-
RETURN_NOT_OK(ipc::GetRecordBatchPayload(batch, ipc_options_, &payload.ipc_message));
203-
RETURN_NOT_OK(WritePayload(payload));
204-
++stats_.num_record_batches;
205-
return Status::OK();
250+
return status;
206251
}
207252

208253
Status Close() override {
209-
// It's fine to Close() without writing data
254+
if (batch_writer_) {
255+
RETURN_NOT_OK(batch_writer_->Close());
256+
}
210257
return Status::OK();
211258
}
212259

213-
ipc::WriteStats stats() const override { return stats_; }
214-
215-
private:
216-
Status WritePayload(const FlightPayload& payload) {
217-
ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload));
218-
if (!success) {
219-
return MakeFlightError(FlightStatusCode::Internal,
220-
"Could not write metadata to stream (client disconnect?)");
221-
}
222-
++stats_.num_messages;
223-
return Status::OK();
260+
ipc::WriteStats stats() const override {
261+
ARROW_CHECK_NE(batch_writer_, nullptr);
262+
auto write_stats = batch_writer_->stats();
263+
write_stats.num_messages += extra_messages_;
264+
return write_stats;
224265
}
225266

267+
private:
226268
Status CheckStarted() {
227-
if (!started_) {
269+
if (!batch_writer_) {
228270
return Status::Invalid("This writer is not started. Call Begin() with a schema");
229271
}
230272
return Status::OK();
231273
}
232274

233-
Status EnsureDictionariesWritten(const RecordBatch& batch) {
234-
if (dictionaries_written_) {
235-
return Status::OK();
236-
}
237-
dictionaries_written_ = true;
238-
ARROW_ASSIGN_OR_RAISE(const auto dictionaries,
239-
ipc::CollectDictionaries(batch, mapper_));
240-
for (const auto& pair : dictionaries) {
241-
FlightPayload payload{};
242-
RETURN_NOT_OK(ipc::GetDictionaryPayload(pair.first, pair.second, ipc_options_,
243-
&payload.ipc_message));
244-
RETURN_NOT_OK(WritePayload(payload));
245-
++stats_.num_dictionary_batches;
246-
}
247-
return Status::OK();
248-
}
249-
250275
ServerDataStream* stream_;
276+
std::unique_ptr<ipc::RecordBatchWriter> batch_writer_;
277+
std::shared_ptr<Buffer> app_metadata_;
251278
::arrow::ipc::IpcWriteOptions ipc_options_;
252-
ipc::DictionaryFieldMapper mapper_;
253-
ipc::WriteStats stats_;
254-
bool started_ = false;
255-
bool dictionaries_written_ = false;
279+
int64_t extra_messages_ = 0;
256280
};
257281

258282
/// \brief Adapt TransportDataStream to the FlightMetadataWriter

python/pyarrow/tests/test_flight.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,12 @@ def simple_ints_table():
114114

115115
def simple_dicts_table():
116116
dict_values = pa.array(["foo", "baz", "quux"], type=pa.utf8())
117-
new_dict_values = pa.array(["bar", "qux"], type=pa.utf8())
117+
new_dict_values = pa.array(["foo", "baz", "quux", "new"], type=pa.utf8())
118118
data = [
119119
pa.chunked_array([
120120
pa.DictionaryArray.from_arrays([1, 0, None], dict_values),
121121
pa.DictionaryArray.from_arrays([2, 1], dict_values),
122-
pa.DictionaryArray.from_arrays([0, 1], new_dict_values)
122+
pa.DictionaryArray.from_arrays([0, 3], new_dict_values)
123123
])
124124
]
125125
return pa.Table.from_arrays(data, names=['some_dicts'])
@@ -2536,3 +2536,39 @@ def received_headers(self, headers):
25362536
assert ("x-header-bin", b"header\x01value") in factory.headers
25372537
assert ("x-trailer", "trailer-value") in factory.headers
25382538
assert ("x-trailer-bin", b"trailer\x01value") in factory.headers
2539+
2540+
2541+
def test_flight_dictionary_deltas_do_exchange():
2542+
class DeltaFlightServer(ConstantFlightServer):
2543+
def do_exchange(self, context, descriptor, reader, writer):
2544+
if descriptor.command == b'dict_deltas':
2545+
expected_table = simple_dicts_table()
2546+
received_table = reader.read_all()
2547+
assert received_table.equals(expected_table)
2548+
2549+
options = pa.ipc.IpcWriteOptions(emit_dictionary_deltas=True)
2550+
writer.begin(expected_table.schema, options=options)
2551+
# TODO: GH-47422: Inspect ReaderStats once exposed and validate deltas
2552+
writer.write_table(expected_table)
2553+
2554+
with DeltaFlightServer() as server, \
2555+
FlightClient(('localhost', server.port)) as client:
2556+
expected_table = simple_dicts_table()
2557+
2558+
descriptor = flight.FlightDescriptor.for_command(b"dict_deltas")
2559+
writer, reader = client.do_exchange(descriptor,
2560+
options=flight.FlightCallOptions(
2561+
write_options=pa.ipc.IpcWriteOptions(
2562+
emit_dictionary_deltas=True)
2563+
)
2564+
)
2565+
# Send client table with dictionary updates (deltas should be sent)
2566+
with writer:
2567+
writer.begin(expected_table.schema, options=pa.ipc.IpcWriteOptions(
2568+
emit_dictionary_deltas=True))
2569+
writer.write_table(expected_table)
2570+
writer.done_writing()
2571+
received_table = reader.read_all()
2572+
2573+
# TODO: GH-47422: Inspect ReaderStats once exposed and validate deltas
2574+
assert received_table.equals(expected_table)

0 commit comments

Comments
 (0)