3636import pytest
3737import pyarrow as pa
3838
39- from pyarrow .lib import IpcReadOptions , tobytes
39+ from pyarrow .lib import IpcReadOptions , ReadStats , tobytes
4040from pyarrow .util import find_free_port
4141from pyarrow .tests import util
4242
@@ -185,6 +185,7 @@ def do_get(self, context, ticket):
185185 def do_put (self , context , descriptor , reader , writer ):
186186 counter = 0
187187 expected_data = [- 10 , - 5 , 0 , 5 , 10 ]
188+ assert reader .stats .num_messages == 1
188189 for batch , buf in reader :
189190 assert batch .equals (pa .RecordBatch .from_arrays (
190191 [pa .array ([expected_data [counter ]])],
@@ -195,6 +196,8 @@ def do_put(self, context, descriptor, reader, writer):
195196 assert counter == client_counter
196197 writer .write (struct .pack ('<i' , counter ))
197198 counter += 1
199+ assert reader .stats .num_messages == 6
200+ assert reader .stats .num_record_batches == 5
198201
199202 @staticmethod
200203 def number_batches (table ):
@@ -421,6 +424,7 @@ def __init__(self, options=None, **kwargs):
421424 self .options = options
422425
423426 def do_exchange (self , context , descriptor , reader , writer ):
427+ assert reader .stats .num_messages == 0
424428 if descriptor .descriptor_type != flight .DescriptorType .CMD :
425429 raise pa .ArrowInvalid ("Must provide a command descriptor" )
426430 elif descriptor .command == b"echo" :
@@ -449,11 +453,14 @@ def exchange_do_put(self, context, reader, writer):
449453 for chunk in reader :
450454 if not chunk .data :
451455 raise pa .ArrowInvalid ("All chunks must have data." )
456+ assert reader .stats .num_messages != 0
452457 num_batches += 1
458+ assert reader .stats .num_record_batches == num_batches
453459 writer .write_metadata (str (num_batches ).encode ("utf-8" ))
454460
455461 def exchange_echo (self , context , reader , writer ):
456462 """Run a simple echo server."""
463+ assert reader .stats .num_messages == 0
457464 started = False
458465 for chunk in reader :
459466 if not started and chunk .data :
@@ -464,16 +471,19 @@ def exchange_echo(self, context, reader, writer):
464471 elif chunk .app_metadata :
465472 writer .write_metadata (chunk .app_metadata )
466473 elif chunk .data :
474+ assert reader .stats .num_messages != 0
467475 writer .write_batch (chunk .data )
468476 else :
469477 assert False , "Should not happen"
470478
471479 def exchange_transform (self , context , reader , writer ):
472480 """Sum rows in an uploaded table."""
481+ assert reader .stats .num_messages == 0
473482 for field in reader .schema :
474483 if not pa .types .is_integer (field .type ):
475484 raise pa .ArrowInvalid ("Invalid field: " + repr (field ))
476485 table = reader .read_all ()
486+ assert reader .stats .num_messages != 0
477487 sums = [0 ] * table .num_rows
478488 for column in table :
479489 for row , value in enumerate (column ):
@@ -1170,8 +1180,17 @@ def test_flight_do_get_dicts():
11701180
11711181 with ConstantFlightServer () as server , \
11721182 flight .connect (('localhost' , server .port )) as client :
1173- data = client .do_get (flight .Ticket (b'dicts' )).read_all ()
1183+ reader = client .do_get (flight .Ticket (b'dicts' ))
1184+ assert reader .stats .num_messages == 1
1185+ data = reader .read_all ()
11741186 assert data .equals (table )
1187+ assert reader .stats == ReadStats (
1188+ num_messages = 6 ,
1189+ num_record_batches = 3 ,
1190+ num_dictionary_batches = 2 ,
1191+ num_dictionary_deltas = 0 ,
1192+ num_replaced_dictionaries = 1
1193+ )
11751194
11761195
11771196def test_flight_do_get_ticket ():
@@ -2090,6 +2109,8 @@ def test_doexchange_put():
20902109 assert chunk .data is None
20912110 expected_buf = str (len (batches )).encode ("utf-8" )
20922111 assert chunk .app_metadata == expected_buf
2112+ # Metadata only message is not counted as an ipc data message
2113+ assert reader .stats .num_messages == 0
20932114
20942115
20952116def test_doexchange_echo ():
@@ -2114,12 +2135,15 @@ def test_doexchange_echo():
21142135
21152136 # Now write data without metadata.
21162137 writer .begin (data .schema )
2138+ num_batches = 0
21172139 for batch in batches :
21182140 writer .write_batch (batch )
21192141 assert reader .schema == data .schema
21202142 chunk = reader .read_chunk ()
21212143 assert chunk .data == batch
21222144 assert chunk .app_metadata is None
2145+ num_batches += 1
2146+ assert reader .stats .num_record_batches == num_batches
21232147
21242148 # And write data with metadata.
21252149 for i , batch in enumerate (batches ):
@@ -2128,6 +2152,8 @@ def test_doexchange_echo():
21282152 chunk = reader .read_chunk ()
21292153 assert chunk .data == batch
21302154 assert chunk .app_metadata == buf
2155+ num_batches += 1
2156+ assert reader .stats .num_record_batches == num_batches
21312157
21322158
21332159def test_doexchange_echo_v4 ():
@@ -2539,36 +2565,56 @@ def received_headers(self, headers):
25392565
25402566
def test_flight_dictionary_deltas_do_exchange():
    """Validate ``reader.stats`` on both sides of a dictionary DoExchange.

    For each of the commands ``dict_deltas`` and ``dict_replacement`` the
    client uploads ``simple_dicts_table()`` through ``do_exchange``; the
    server reads it, compares the table and its ``reader.stats`` against
    ``expected_stats`` and writes the same table back.  The client then
    performs the same two comparisons on what it received.
    """
    # Expected reader statistics per command.  The two entries differ only
    # in whether the dictionary update is counted as a delta or as a
    # replaced dictionary.
    expected_stats = {
        'dict_deltas': ReadStats(
            num_messages=6,
            num_record_batches=3,
            num_dictionary_batches=2,
            num_dictionary_deltas=1,
            num_replaced_dictionaries=0
        ),
        'dict_replacement': ReadStats(
            num_messages=6,
            num_record_batches=3,
            num_dictionary_batches=2,
            num_dictionary_deltas=0,
            num_replaced_dictionaries=1
        )
    }

    class DeltaFlightServer(ConstantFlightServer):
        def do_exchange(self, context, descriptor, reader, writer):
            """Check the uploaded table and stats, then send the table back."""
            expected_table = simple_dicts_table()
            received_table = reader.read_all()
            assert received_table.equals(expected_table)
            # The command doubles as the key into expected_stats; an
            # unknown command fails here with KeyError.
            assert reader.stats == expected_stats[descriptor.command.decode()]
            # The branches are mutually exclusive, hence if/elif.
            if descriptor.command == b'dict_deltas':
                options = pa.ipc.IpcWriteOptions(emit_dictionary_deltas=True)
                writer.begin(expected_table.schema, options=options)
                writer.write_table(expected_table)
            elif descriptor.command == b'dict_replacement':
                # Default write options (no emit_dictionary_deltas).
                writer.begin(expected_table.schema)
                writer.write_table(expected_table)

    with DeltaFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        expected_table = simple_dicts_table()
        for command in ["dict_deltas", "dict_replacement"]:
            descriptor = flight.FlightDescriptor.for_command(command)
            writer, reader = client.do_exchange(
                descriptor,
                options=flight.FlightCallOptions(
                    write_options=pa.ipc.IpcWriteOptions(
                        emit_dictionary_deltas=True)
                )
            )
            # Send client table with dictionary updates; deltas are only
            # requested from the writer for the "dict_deltas" command.
            with writer:
                writer.begin(
                    expected_table.schema,
                    options=pa.ipc.IpcWriteOptions(
                        emit_dictionary_deltas=(command == "dict_deltas")))
                writer.write_table(expected_table)
                writer.done_writing()
                received_table = reader.read_all()

            # Checked once per command, inside the loop, so both exchanges
            # are validated rather than only the last one.
            assert received_table.equals(expected_table)
            assert reader.stats == expected_stats[command]
0 commit comments