diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 2003d778052..e33b6c3fd07 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -259,6 +259,14 @@ class ClientStreamReader : public FlightStreamReader { } return batches; } + + arrow::ipc::ReadStats stats() const override { + if (batch_reader_ == nullptr) { + return ipc::ReadStats{}; + } + return batch_reader_->stats(); + } + arrow::Result> ToTable() override { return ToTable(stop_token_); } @@ -278,7 +286,7 @@ class ClientStreamReader : public FlightStreamReader { StopToken stop_token_; std::shared_ptr memory_manager_; std::shared_ptr peekable_reader_; - std::shared_ptr batch_reader_; + std::shared_ptr batch_reader_; std::shared_ptr app_metadata_; }; diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index ae6011b117a..3ad9f26275b 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -141,6 +141,10 @@ class ARROW_FLIGHT_EXPORT FlightStreamReader : public MetadataRecordBatchReader using MetadataRecordBatchReader::ToTable; /// \brief Consume entire stream as a Table arrow::Result> ToTable(const StopToken& stop_token); + + using MetadataRecordBatchReader::stats; + /// \brief Return current read statistics + virtual arrow::ipc::ReadStats stats() const = 0; }; // Silence warning diff --git a/cpp/src/arrow/flight/transport_server.cc b/cpp/src/arrow/flight/transport_server.cc index b58e6a58b13..197def7cabd 100644 --- a/cpp/src/arrow/flight/transport_server.cc +++ b/cpp/src/arrow/flight/transport_server.cc @@ -137,6 +137,13 @@ class TransportMessageReader final : public FlightMessageReader { return out; } + ipc::ReadStats stats() const override { + if (batch_reader_ == nullptr) { + return ipc::ReadStats{}; + } + return batch_reader_->stats(); + } + private: /// Ensure we are set up to read data. Status EnsureDataStarted() { @@ -157,7 +164,7 @@ class TransportMessageReader final : public FlightMessageReader { FlightDescriptor descriptor_; std::shared_ptr peekable_reader_; std::shared_ptr memory_manager_; - std::shared_ptr batch_reader_; + std::shared_ptr batch_reader_; std::shared_ptr app_metadata_; }; @@ -233,9 +240,6 @@ class TransportMessageWriter final : public FlightMessageWriter { return MakeFlightError(FlightStatusCode::Internal, "Could not write metadata to stream (client disconnect?)"); } - // Those messages are not written through the batch writer, - // count them separately to include them in the stats. - extra_messages_++; return Status::OK(); } @@ -259,9 +263,7 @@ class TransportMessageWriter final : public FlightMessageWriter { ipc::WriteStats stats() const override { ARROW_CHECK_NE(batch_writer_, nullptr); - auto write_stats = batch_writer_->stats(); - write_stats.num_messages += extra_messages_; - return write_stats; + return batch_writer_->stats(); } private: @@ -276,7 +278,6 @@ class TransportMessageWriter final : public FlightMessageWriter { std::unique_ptr batch_writer_; std::shared_ptr app_metadata_; ::arrow::ipc::IpcWriteOptions ipc_options_; - int64_t extra_messages_ = 0; }; /// \brief Adapt TransportDataStream to the FlightMetadataWriter diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 656cc00e676..d498ac67f7a 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -47,6 +47,7 @@ class Table; namespace ipc { class DictionaryMemo; +struct ReadStats; } // namespace ipc namespace util { @@ -1179,6 +1180,9 @@ class ARROW_FLIGHT_EXPORT MetadataRecordBatchReader { /// \brief Consume entire stream as a Table virtual arrow::Result> ToTable(); + + /// \brief Return current read statistics + virtual arrow::ipc::ReadStats stats() const = 0; }; /// \brief Convert a MetadataRecordBatchReader to a regular RecordBatchReader. diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index fe2e1b3d674..836e78c520d 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -1118,6 +1118,19 @@ cdef class _MetadataRecordBatchReader(_Weakrefable, _ReadPandasMixin): return reader + @property + def stats(self): + """ + Current Flight read statistics. + + Returns + ------- + ReadStats + """ + if not self.reader: + raise ValueError("Operation on closed reader") + return _wrap_read_stats(( self.reader.get()).stats()) + cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader): """The base class for readers for Flight streams. diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index b1af6bcb4fa..caf1f67cf8b 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -202,6 +202,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CResult[shared_ptr[CSchema]] GetSchema() CResult[CFlightStreamChunk] Next() CResult[shared_ptr[CTable]] ToTable() + CIpcReadStats stats() const CResult[shared_ptr[CRecordBatchReader]] MakeRecordBatchReader\ " arrow::flight::MakeRecordBatchReader"( diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index d57a899b58d..9608194303d 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -125,7 +125,6 @@ class ReadStats(_ReadStats): __slots__ = () -@staticmethod cdef _wrap_read_stats(CIpcReadStats c): return ReadStats(c.num_messages, c.num_record_batches, c.num_dictionary_batches, c.num_dictionary_deltas, diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index cbdac7f1020..2e397cca043 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -48,6 +48,9 @@ cdef class IpcReadOptions(_Weakrefable): CIpcReadOptions c_options +cdef _wrap_read_stats(CIpcReadStats c) + + cdef class Message(_Weakrefable): cdef: unique_ptr[CMessage] message diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index b1fa2fb310a..bcaf9dcad9b 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -36,7 +36,7 @@ import pytest import pyarrow as pa -from pyarrow.lib import IpcReadOptions, tobytes +from pyarrow.lib import IpcReadOptions, ReadStats, tobytes from pyarrow.util import find_free_port from pyarrow.tests import util @@ -185,6 +185,7 @@ def do_get(self, context, ticket): def do_put(self, context, descriptor, reader, writer): counter = 0 expected_data = [-10, -5, 0, 5, 10] + assert reader.stats.num_messages == 1 for batch, buf in reader: assert batch.equals(pa.RecordBatch.from_arrays( [pa.array([expected_data[counter]])], @@ -195,6 +196,8 @@ def do_put(self, context, descriptor, reader, writer): assert counter == client_counter writer.write(struct.pack('