11use crate :: common:: with_callback;
22use crate :: config_extension_ext:: ContextGrpcMetadata ;
3- use crate :: execution_plans:: { DistributedTaskContext , StageExec } ;
3+ use crate :: execution_plans:: {
4+ DistributedTaskContext , StageExec , collect_and_create_metrics_flight_data,
5+ } ;
46use crate :: flight_service:: service:: ArrowFlightEndpoint ;
57use crate :: flight_service:: session_builder:: DistributedSessionBuilderContext ;
8+ use crate :: flight_service:: trailing_flight_data_stream:: TrailingFlightDataStream ;
69use crate :: protobuf:: {
710 DistributedCodec , StageKey , datafusion_error_to_tonic_status, stage_from_proto,
811} ;
@@ -97,12 +100,6 @@ impl ArrowFlightEndpoint {
97100 } )
98101 . await ?;
99102 let stage = Arc :: clone ( & stage_data. stage ) ;
100- let num_partitions_remaining = Arc :: clone ( & stage_data. num_partitions_remaining ) ;
101-
102- // If all the partitions are done, remove the stage from the cache.
103- if num_partitions_remaining. fetch_sub ( 1 , Ordering :: SeqCst ) <= 1 {
104- self . task_data_entries . remove ( key) ;
105- }
106103
107104 // Find out which partition group we are executing
108105 let cfg = session_state. config_mut ( ) ;
@@ -130,24 +127,44 @@ impl ArrowFlightEndpoint {
130127 . map_err ( |err| Status :: internal ( format ! ( "Error executing stage plan: {err:#?}" ) ) ) ?;
131128
132129 let schema = stream. schema ( ) ;
130+
131+ // TODO: We don't need to do this since the stage / plan is captured again by the
132+ // TrailingFlightDataStream. However, we will eventually only use the TrailingFlightDataStream
133+ // if we are running an `explain (analyze)` command. We should update this section
134+ // to only use one or the other - not both.
135+ let plan_capture = stage. plan . clone ( ) ;
133136 let stream = with_callback ( stream, move |_| {
134137 // We need to hold a reference to the plan for at least as long as the stream is
135138 // execution. Some plans might store state necessary for the stream to work, and
136139 // dropping the plan early could drop this state too soon.
137- let _ = stage . plan ;
140+ let _ = plan_capture ;
138141 } ) ;
139142
140- Ok ( record_batch_stream_to_response ( Box :: pin (
141- RecordBatchStreamAdapter :: new ( schema, stream) ,
142- ) ) )
143+ let record_batch_stream = Box :: pin ( RecordBatchStreamAdapter :: new ( schema, stream) ) ;
144+ let task_data_capture = self . task_data_entries . clone ( ) ;
145+ Ok ( flight_stream_from_record_batch_stream (
146+ key. clone ( ) ,
147+ stage,
148+ stage_data. clone ( ) ,
149+ move || {
150+ task_data_capture. remove ( key. clone ( ) ) ;
151+ } ,
152+ record_batch_stream,
153+ ) )
143154 }
144155}
145156
146157fn missing ( field : & ' static str ) -> impl FnOnce ( ) -> Status {
147158 move || Status :: invalid_argument ( format ! ( "Missing field '{field}'" ) )
148159}
149160
150- fn record_batch_stream_to_response (
161+ // Creates a tonic response from a stream of record batches. Handles
162+ // - RecordBatch-to-flight conversion, partition tracking, stage eviction, and trailing metrics.
163+ fn flight_stream_from_record_batch_stream (
164+ stage_key : StageKey ,
165+ stage : Arc < StageExec > ,
166+ stage_data : TaskData ,
167+ evict_stage : impl FnOnce ( ) + Send + ' static ,
151168 stream : SendableRecordBatchStream ,
152169) -> Response < <ArrowFlightEndpoint as FlightService >:: DoGetStream > {
153170 let flight_data_stream =
@@ -157,7 +174,31 @@ fn record_batch_stream_to_response(
157174 FlightError :: Tonic ( Box :: new ( datafusion_error_to_tonic_status ( & err) ) )
158175 } ) ) ;
159176
160- Response :: new ( Box :: pin ( flight_data_stream. map_err ( |err| match err {
177+ let trailing_metrics_stream = TrailingFlightDataStream :: new (
178+ move || {
179+ if stage_data
180+ . num_partitions_remaining
181+ . fetch_sub ( 1 , Ordering :: SeqCst )
182+ == 1
183+ {
184+ evict_stage ( ) ;
185+
186+ let metrics_stream = collect_and_create_metrics_flight_data ( stage_key, stage)
187+ . map_err ( |err| {
188+ Status :: internal ( format ! (
189+ "error collecting metrics in arrow flight endpoint: {err}"
190+ ) )
191+ } ) ?;
192+
193+ return Ok ( Some ( metrics_stream) ) ;
194+ }
195+
196+ Ok ( None )
197+ } ,
198+ flight_data_stream,
199+ ) ;
200+
201+ Response :: new ( Box :: pin ( trailing_metrics_stream. map_err ( |err| match err {
161202 FlightError :: Tonic ( status) => * status,
162203 _ => Status :: internal ( format ! ( "Error during flight stream: {err}" ) ) ,
163204 } ) ) )
@@ -228,24 +269,27 @@ mod tests {
228269 let stage_proto = proto_from_stage ( & stage, & DefaultPhysicalExtensionCodec { } ) . unwrap ( ) ;
229270 let stage_proto_for_closure = stage_proto. clone ( ) ;
230271 let endpoint_ref = & endpoint;
272+
231273 let do_get = async move |partition : u64 , task_number : u64 , stage_key : StageKey | {
232274 let stage_proto = stage_proto_for_closure. clone ( ) ;
233- // Create DoGet message
234275 let doget = DoGet {
235276 stage_proto : stage_proto. encode_to_vec ( ) . into ( ) ,
236277 target_task_index : task_number,
237278 target_partition : partition,
238279 stage_key : Some ( stage_key) ,
239280 } ;
240281
241- // Create Flight ticket
242282 let ticket = Ticket {
243283 ticket : Bytes :: from ( doget. encode_to_vec ( ) ) ,
244284 } ;
245285
246- // Call the actual get() method
247286 let request = Request :: new ( ticket) ;
248- endpoint_ref. get ( request) . await
287+ let response = endpoint_ref. get ( request) . await ?;
288+ let mut stream = response. into_inner ( ) ;
289+
290+ // Consume the stream.
291+ while let Some ( _flight_data) = stream. try_next ( ) . await ? { }
292+ Ok :: < ( ) , Status > ( ( ) )
249293 } ;
250294
251295 // For each task, call do_get() for each partition except the last.
@@ -261,22 +305,22 @@ mod tests {
261305
262306 // Run the last partition of task 0. Any partition number works. Verify that the task state
263307 // is evicted because all partitions have been processed.
264- let result = do_get ( 1 , 0 , task_keys[ 0 ] . clone ( ) ) . await ;
308+ let result = do_get ( 2 , 0 , task_keys[ 0 ] . clone ( ) ) . await ;
265309 assert ! ( result. is_ok( ) ) ;
266310 let stored_stage_keys = endpoint. task_data_entries . keys ( ) . collect :: < Vec < StageKey > > ( ) ;
267311 assert_eq ! ( stored_stage_keys. len( ) , 2 ) ;
268312 assert ! ( stored_stage_keys. contains( & task_keys[ 1 ] ) ) ;
269313 assert ! ( stored_stage_keys. contains( & task_keys[ 2 ] ) ) ;
270314
271315 // Run the last partition of task 1.
272- let result = do_get ( 1 , 1 , task_keys[ 1 ] . clone ( ) ) . await ;
316+ let result = do_get ( 2 , 1 , task_keys[ 1 ] . clone ( ) ) . await ;
273317 assert ! ( result. is_ok( ) ) ;
274318 let stored_stage_keys = endpoint. task_data_entries . keys ( ) . collect :: < Vec < StageKey > > ( ) ;
275319 assert_eq ! ( stored_stage_keys. len( ) , 1 ) ;
276320 assert ! ( stored_stage_keys. contains( & task_keys[ 2 ] ) ) ;
277321
278322 // Run the last partition of the last task.
279- let result = do_get ( 1 , 2 , task_keys[ 2 ] . clone ( ) ) . await ;
323+ let result = do_get ( 2 , 2 , task_keys[ 2 ] . clone ( ) ) . await ;
280324 assert ! ( result. is_ok( ) ) ;
281325 let stored_stage_keys = endpoint. task_data_entries . keys ( ) . collect :: < Vec < StageKey > > ( ) ;
282326 assert_eq ! ( stored_stage_keys. len( ) , 0 ) ;
0 commit comments