Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,14 @@ class ClientStreamReader : public FlightStreamReader {
}
return batches;
}

/// \brief Return current read statistics
///
/// Yields a default-constructed ipc::ReadStats while batch_reader_ is still
/// null; otherwise forwards to the batch reader's own stats().
arrow::ipc::ReadStats stats() const override {
  return batch_reader_ != nullptr ? batch_reader_->stats() : ipc::ReadStats{};
}

/// \brief Consume the stream as a Table.
///
/// Delegates to the ToTable overload that takes a stop token, passing this
/// reader's own stop_token_.
arrow::Result<std::shared_ptr<Table>> ToTable() override {
return ToTable(stop_token_);
}
Expand All @@ -278,7 +286,7 @@ class ClientStreamReader : public FlightStreamReader {
StopToken stop_token_;
std::shared_ptr<MemoryManager> memory_manager_;
std::shared_ptr<internal::PeekableFlightDataReader> peekable_reader_;
std::shared_ptr<ipc::RecordBatchReader> batch_reader_;
std::shared_ptr<ipc::RecordBatchStreamReader> batch_reader_;
std::shared_ptr<Buffer> app_metadata_;
};

Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/flight/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ class ARROW_FLIGHT_EXPORT FlightStreamReader : public MetadataRecordBatchReader
using MetadataRecordBatchReader::ToTable;
/// \brief Consume entire stream as a Table
arrow::Result<std::shared_ptr<Table>> ToTable(const StopToken& stop_token);

using MetadataRecordBatchReader::stats;
/// \brief Return current read statistics
///
/// Re-declared as pure virtual in this interface. Marking it `override` makes
/// the compiler verify that the signature keeps matching
/// MetadataRecordBatchReader::stats() instead of silently introducing a new
/// virtual function if either declaration drifts.
arrow::ipc::ReadStats stats() const override = 0;
};

// Silence warning
Expand Down
17 changes: 9 additions & 8 deletions cpp/src/arrow/flight/transport_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,13 @@ class TransportMessageReader final : public FlightMessageReader {
return out;
}

/// \brief Return current read statistics
///
/// Forwards to the batch reader when one is set; otherwise the result is a
/// default-constructed ipc::ReadStats.
ipc::ReadStats stats() const override {
  if (batch_reader_ != nullptr) {
    return batch_reader_->stats();
  }
  return ipc::ReadStats{};
}

private:
/// Ensure we are set up to read data.
Status EnsureDataStarted() {
Expand All @@ -157,7 +164,7 @@ class TransportMessageReader final : public FlightMessageReader {
FlightDescriptor descriptor_;
std::shared_ptr<internal::PeekableFlightDataReader> peekable_reader_;
std::shared_ptr<MemoryManager> memory_manager_;
std::shared_ptr<RecordBatchReader> batch_reader_;
std::shared_ptr<ipc::RecordBatchStreamReader> batch_reader_;
std::shared_ptr<Buffer> app_metadata_;
};

Expand Down Expand Up @@ -233,9 +240,6 @@ class TransportMessageWriter final : public FlightMessageWriter {
return MakeFlightError(FlightStatusCode::Internal,
"Could not write metadata to stream (client disconnect?)");
}
// Those messages are not written through the batch writer,
// count them separately to include them in the stats.
extra_messages_++;
return Status::OK();
}

Expand All @@ -259,9 +263,7 @@ class TransportMessageWriter final : public FlightMessageWriter {

/// \brief Return the write statistics for this writer.
///
/// Checks (ARROW_CHECK_NE) that batch_writer_ is non-null before using it.
/// As written, the returned value is the batch writer's stats with
/// num_messages increased by extra_messages_.
ipc::WriteStats stats() const override {
ARROW_CHECK_NE(batch_writer_, nullptr);
auto write_stats = batch_writer_->stats();
// Include the messages tallied separately in extra_messages_.
write_stats.num_messages += extra_messages_;
return write_stats;
// NOTE(review): the statement below is unreachable because of the return
// above. It looks like the replacement line of a diff whose removed lines are
// still present -- confirm which variant is intended and drop the other.
return batch_writer_->stats();
}

private:
Expand All @@ -276,7 +278,6 @@ class TransportMessageWriter final : public FlightMessageWriter {
std::unique_ptr<ipc::RecordBatchWriter> batch_writer_;
std::shared_ptr<Buffer> app_metadata_;
::arrow::ipc::IpcWriteOptions ipc_options_;
int64_t extra_messages_ = 0;
};

/// \brief Adapt TransportDataStream to the FlightMetadataWriter
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/flight/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class Table;

namespace ipc {
class DictionaryMemo;
struct ReadStats;
} // namespace ipc

namespace util {
Expand Down Expand Up @@ -1179,6 +1180,9 @@ class ARROW_FLIGHT_EXPORT MetadataRecordBatchReader {

/// \brief Consume entire stream as a Table
virtual arrow::Result<std::shared_ptr<Table>> ToTable();

/// \brief Return current read statistics
///
/// Pure virtual: each concrete reader supplies its own implementation.
/// ipc::ReadStats is only forward-declared near the top of this header, so
/// code that calls this needs the full definition in scope
/// (NOTE(review): presumably from the Arrow IPC reader header -- confirm).
virtual arrow::ipc::ReadStats stats() const = 0;
};

/// \brief Convert a MetadataRecordBatchReader to a regular RecordBatchReader.
Expand Down
13 changes: 13 additions & 0 deletions python/pyarrow/_flight.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1118,6 +1118,19 @@ cdef class _MetadataRecordBatchReader(_Weakrefable, _ReadPandasMixin):

return reader

@property
def stats(self):
    """
    Current read statistics of the underlying Flight reader.

    Returns
    -------
    ReadStats

    Raises
    ------
    ValueError
        If the underlying reader handle is unset (reported as a
        closed reader).
    """
    if self.reader:
        return _wrap_read_stats(
            (<CMetadataRecordBatchReader*> self.reader.get()).stats())
    raise ValueError("Operation on closed reader")


cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader):
"""The base class for readers for Flight streams.
Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/includes/libarrow_flight.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
CResult[shared_ptr[CSchema]] GetSchema()
CResult[CFlightStreamChunk] Next()
CResult[shared_ptr[CTable]] ToTable()
CIpcReadStats stats() const

CResult[shared_ptr[CRecordBatchReader]] MakeRecordBatchReader\
" arrow::flight::MakeRecordBatchReader"(
Expand Down
1 change: 0 additions & 1 deletion python/pyarrow/ipc.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ class ReadStats(_ReadStats):
__slots__ = ()


@staticmethod
cdef _wrap_read_stats(CIpcReadStats c):
return ReadStats(c.num_messages, c.num_record_batches,
c.num_dictionary_batches, c.num_dictionary_deltas,
Expand Down
3 changes: 3 additions & 0 deletions python/pyarrow/lib.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ cdef class IpcReadOptions(_Weakrefable):
CIpcReadOptions c_options


# Declaration of the helper that wraps a C-level CIpcReadStats value as a
# Python object, exposed here so other Cython modules can cimport it.
# NOTE(review): the definition appears to live in ipc.pxi and to build a
# ReadStats instance -- confirm.
cdef _wrap_read_stats(CIpcReadStats c)


cdef class Message(_Weakrefable):
cdef:
unique_ptr[CMessage] message
Expand Down
94 changes: 70 additions & 24 deletions python/pyarrow/tests/test_flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import pytest
import pyarrow as pa

from pyarrow.lib import IpcReadOptions, tobytes
from pyarrow.lib import IpcReadOptions, ReadStats, tobytes
from pyarrow.util import find_free_port
from pyarrow.tests import util

Expand Down Expand Up @@ -185,6 +185,7 @@ def do_get(self, context, ticket):
def do_put(self, context, descriptor, reader, writer):
counter = 0
expected_data = [-10, -5, 0, 5, 10]
assert reader.stats.num_messages == 1
for batch, buf in reader:
assert batch.equals(pa.RecordBatch.from_arrays(
[pa.array([expected_data[counter]])],
Expand All @@ -195,6 +196,8 @@ def do_put(self, context, descriptor, reader, writer):
assert counter == client_counter
writer.write(struct.pack('<i', counter))
counter += 1
assert reader.stats.num_messages == 6
assert reader.stats.num_record_batches == 5

@staticmethod
def number_batches(table):
Expand Down Expand Up @@ -421,6 +424,7 @@ def __init__(self, options=None, **kwargs):
self.options = options

def do_exchange(self, context, descriptor, reader, writer):
assert reader.stats.num_messages == 0
if descriptor.descriptor_type != flight.DescriptorType.CMD:
raise pa.ArrowInvalid("Must provide a command descriptor")
elif descriptor.command == b"echo":
Expand Down Expand Up @@ -449,11 +453,14 @@ def exchange_do_put(self, context, reader, writer):
for chunk in reader:
if not chunk.data:
raise pa.ArrowInvalid("All chunks must have data.")
assert reader.stats.num_messages != 0
num_batches += 1
assert reader.stats.num_record_batches == num_batches
writer.write_metadata(str(num_batches).encode("utf-8"))

def exchange_echo(self, context, reader, writer):
"""Run a simple echo server."""
assert reader.stats.num_messages == 0
started = False
for chunk in reader:
if not started and chunk.data:
Expand All @@ -464,16 +471,19 @@ def exchange_echo(self, context, reader, writer):
elif chunk.app_metadata:
writer.write_metadata(chunk.app_metadata)
elif chunk.data:
assert reader.stats.num_messages != 0
writer.write_batch(chunk.data)
else:
assert False, "Should not happen"

def exchange_transform(self, context, reader, writer):
"""Sum rows in an uploaded table."""
assert reader.stats.num_messages == 0
for field in reader.schema:
if not pa.types.is_integer(field.type):
raise pa.ArrowInvalid("Invalid field: " + repr(field))
table = reader.read_all()
assert reader.stats.num_messages != 0
sums = [0] * table.num_rows
for column in table:
for row, value in enumerate(column):
Expand Down Expand Up @@ -1170,8 +1180,17 @@ def test_flight_do_get_dicts():

with ConstantFlightServer() as server, \
flight.connect(('localhost', server.port)) as client:
data = client.do_get(flight.Ticket(b'dicts')).read_all()
reader = client.do_get(flight.Ticket(b'dicts'))
assert reader.stats.num_messages == 1
data = reader.read_all()
assert data.equals(table)
assert reader.stats == ReadStats(
num_messages=6,
num_record_batches=3,
num_dictionary_batches=2,
num_dictionary_deltas=0,
num_replaced_dictionaries=1
)


def test_flight_do_get_ticket():
Expand Down Expand Up @@ -2090,6 +2109,8 @@ def test_doexchange_put():
assert chunk.data is None
expected_buf = str(len(batches)).encode("utf-8")
assert chunk.app_metadata == expected_buf
# A metadata-only message is not counted as an IPC data message
assert reader.stats.num_messages == 0


def test_doexchange_echo():
Expand All @@ -2114,12 +2135,15 @@ def test_doexchange_echo():

# Now write data without metadata.
writer.begin(data.schema)
num_batches = 0
for batch in batches:
writer.write_batch(batch)
assert reader.schema == data.schema
chunk = reader.read_chunk()
assert chunk.data == batch
assert chunk.app_metadata is None
num_batches += 1
assert reader.stats.num_record_batches == num_batches

# And write data with metadata.
for i, batch in enumerate(batches):
Expand All @@ -2128,6 +2152,8 @@ def test_doexchange_echo():
chunk = reader.read_chunk()
assert chunk.data == batch
assert chunk.app_metadata == buf
num_batches += 1
assert reader.stats.num_record_batches == num_batches


def test_doexchange_echo_v4():
Expand Down Expand Up @@ -2539,36 +2565,56 @@ def received_headers(self, headers):


def test_flight_dictionary_deltas_do_exchange():
expected_stats = {
'dict_deltas': ReadStats(
num_messages=6,
num_record_batches=3,
num_dictionary_batches=2,
num_dictionary_deltas=1,
num_replaced_dictionaries=0
),
'dict_replacement': ReadStats(
num_messages=6,
num_record_batches=3,
num_dictionary_batches=2,
num_dictionary_deltas=0,
num_replaced_dictionaries=1
)
}

class DeltaFlightServer(ConstantFlightServer):
def do_exchange(self, context, descriptor, reader, writer):
expected_table = simple_dicts_table()
received_table = reader.read_all()
assert received_table.equals(expected_table)
assert reader.stats == expected_stats[descriptor.command.decode()]
if descriptor.command == b'dict_deltas':
expected_table = simple_dicts_table()
received_table = reader.read_all()
assert received_table.equals(expected_table)

options = pa.ipc.IpcWriteOptions(emit_dictionary_deltas=True)
writer.begin(expected_table.schema, options=options)
# TODO: GH-47422: Inspect ReaderStats once exposed and validate deltas
writer.write_table(expected_table)
if descriptor.command == b'dict_replacement':
writer.begin(expected_table.schema)
writer.write_table(expected_table)

with DeltaFlightServer() as server, \
FlightClient(('localhost', server.port)) as client:
expected_table = simple_dicts_table()
for command in ["dict_deltas", "dict_replacement"]:
descriptor = flight.FlightDescriptor.for_command(command)
writer, reader = client.do_exchange(
descriptor,
options=flight.FlightCallOptions(
write_options=pa.ipc.IpcWriteOptions(
emit_dictionary_deltas=True)
)
)
# Send client table with dictionary updates
with writer:
writer.begin(expected_table.schema, options=pa.ipc.IpcWriteOptions(
emit_dictionary_deltas=(command == "dict_deltas")))
writer.write_table(expected_table)
writer.done_writing()
received_table = reader.read_all()

descriptor = flight.FlightDescriptor.for_command(b"dict_deltas")
writer, reader = client.do_exchange(descriptor,
options=flight.FlightCallOptions(
write_options=pa.ipc.IpcWriteOptions(
emit_dictionary_deltas=True)
)
)
# Send client table with dictionary updates (deltas should be sent)
with writer:
writer.begin(expected_table.schema, options=pa.ipc.IpcWriteOptions(
emit_dictionary_deltas=True))
writer.write_table(expected_table)
writer.done_writing()
received_table = reader.read_all()

# TODO: GH-47422: Inspect ReaderStats once exposed and validate deltas
assert received_table.equals(expected_table)
assert received_table.equals(expected_table)
assert reader.stats == expected_stats[command]
Loading