 use crate::config_extension_ext::ContextGrpcMetadata;
-use crate::execution_plans::{DistributedTaskContext, StageExec};
+use crate::execution_plans::{
+    DistributedTaskContext, StageExec, collect_and_create_metrics_flight_data,
+};
 use crate::flight_service::service::ArrowFlightEndpoint;
 use crate::flight_service::session_builder::DistributedSessionBuilderContext;
+use crate::flight_service::trailing_flight_data_stream::TrailingFlightDataStream;
 use crate::protobuf::{
     DistributedCodec, StageExecProto, StageKey, datafusion_error_to_tonic_status, stage_from_proto,
 };
@@ -94,12 +97,6 @@ impl ArrowFlightEndpoint {
         })
         .await?;
         let stage = Arc::clone(&stage_data.stage);
-        let num_partitions_remaining = Arc::clone(&stage_data.num_partitions_remaining);
-
-        // If all the partitions are done, remove the stage from the cache.
-        if num_partitions_remaining.fetch_sub(1, Ordering::SeqCst) <= 1 {
-            self.task_data_entries.remove(key);
-        }
 
         // Find out which partition group we are executing
         let cfg = session_state.config_mut();
@@ -126,15 +123,30 @@ impl ArrowFlightEndpoint {
             .execute(doget.target_partition as usize, session_state.task_ctx())
             .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
 
-        Ok(record_batch_stream_to_response(stream))
+        let task_data_capture = self.task_data_entries.clone();
+        Ok(flight_stream_from_record_batch_stream(
+            key.clone(),
+            stage,
+            stage_data.clone(),
+            move || {
+                task_data_capture.remove(key.clone());
+            },
+            stream,
+        ))
     }
 }
 
 fn missing(field: &'static str) -> impl FnOnce() -> Status {
     move || Status::invalid_argument(format!("Missing field '{field}'"))
 }
 
-fn record_batch_stream_to_response(
+// Creates a tonic response from a stream of record batches. Handles RecordBatch-to-FlightData
+// conversion, partition tracking, stage eviction, and trailing metrics.
+fn flight_stream_from_record_batch_stream(
+    stage_key: StageKey,
+    stage: Arc<StageExec>,
+    stage_data: TaskData,
+    evict_stage: impl FnOnce() + Send + 'static,
     stream: SendableRecordBatchStream,
 ) -> Response<<ArrowFlightEndpoint as FlightService>::DoGetStream> {
     let flight_data_stream =
@@ -144,7 +156,31 @@ fn record_batch_stream_to_response(
             FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
         }));
 
-    Response::new(Box::pin(flight_data_stream.map_err(|err| match err {
+    let trailing_metrics_stream = TrailingFlightDataStream::new(
+        move || {
+            if stage_data
+                .num_partitions_remaining
+                .fetch_sub(1, Ordering::SeqCst)
+                == 1
+            {
+                evict_stage();
+
+                let metrics_stream = collect_and_create_metrics_flight_data(stage_key, stage)
+                    .map_err(|err| {
+                        Status::internal(format!(
+                            "error collecting metrics in arrow flight endpoint: {err}"
+                        ))
+                    })?;
+
+                return Ok(Some(metrics_stream));
+            }
+
+            Ok(None)
+        },
+        flight_data_stream,
+    );
+
+    Response::new(Box::pin(trailing_metrics_stream.map_err(|err| match err {
         FlightError::Tonic(status) => *status,
         _ => Status::internal(format!("Error during flight stream: {err}")),
     })))
@@ -215,24 +251,27 @@ mod tests {
         let stage_proto = proto_from_stage(&stage, &DefaultPhysicalExtensionCodec {}).unwrap();
         let stage_proto_for_closure = stage_proto.clone();
         let endpoint_ref = &endpoint;
+
         let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| {
             let stage_proto = stage_proto_for_closure.clone();
-            // Create DoGet message
             let doget = DoGet {
                 stage_proto: Some(stage_proto),
                 target_task_index: task_number,
                 target_partition: partition,
                 stage_key: Some(stage_key),
             };
 
-            // Create Flight ticket
             let ticket = Ticket {
                 ticket: Bytes::from(doget.encode_to_vec()),
             };
 
-            // Call the actual get() method
             let request = Request::new(ticket);
-            endpoint_ref.get(request).await
+            let response = endpoint_ref.get(request).await?;
+            let mut stream = response.into_inner();
+
+            // Consume the stream.
+            while let Some(_flight_data) = stream.try_next().await? {}
+            Ok::<(), Status>(())
         };
 
         // For each task, call do_get() for each partition except the last.
@@ -248,22 +287,22 @@ mod tests {
 
         // Run the last partition of task 0. Any partition number works. Verify that the task state
         // is evicted because all partitions have been processed.
-        let result = do_get(1, 0, task_keys[0].clone()).await;
+        let result = do_get(2, 0, task_keys[0].clone()).await;
         assert!(result.is_ok());
         let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
         assert_eq!(stored_stage_keys.len(), 2);
         assert!(stored_stage_keys.contains(&task_keys[1]));
         assert!(stored_stage_keys.contains(&task_keys[2]));
 
         // Run the last partition of task 1.
-        let result = do_get(1, 1, task_keys[1].clone()).await;
+        let result = do_get(2, 1, task_keys[1].clone()).await;
         assert!(result.is_ok());
         let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
         assert_eq!(stored_stage_keys.len(), 1);
         assert!(stored_stage_keys.contains(&task_keys[2]));
 
         // Run the last partition of the last task.
-        let result = do_get(1, 2, task_keys[2].clone()).await;
+        let result = do_get(2, 2, task_keys[2].clone()).await;
         assert!(result.is_ok());
         let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
         assert_eq!(stored_stage_keys.len(), 0);