@@ -3,9 +3,17 @@ use crate::config_extension_ext::ContextGrpcMetadata;
 use crate::execution_plans::{DistributedTaskContext, StageExec};
 use crate::flight_service::service::ArrowFlightEndpoint;
 use crate::flight_service::session_builder::DistributedSessionBuilderContext;
+use crate::flight_service::trailing_flight_data_stream::TrailingFlightDataStream;
+use crate::metrics::TaskMetricsCollector;
+use crate::metrics::proto::df_metrics_set_to_proto;
 use crate::protobuf::{
-    DistributedCodec, StageKey, datafusion_error_to_tonic_status, stage_from_proto,
+    AppMetadata, DistributedCodec, FlightAppMetadata, MetricsCollection, StageKey, TaskMetrics,
+    datafusion_error_to_tonic_status, stage_from_proto,
 };
+use arrow::array::RecordBatch;
+use arrow::datatypes::SchemaRef;
+use arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
+use arrow_flight::FlightData;
 use arrow_flight::Ticket;
 use arrow_flight::encode::FlightDataEncoderBuilder;
 use arrow_flight::error::FlightError;
@@ -15,6 +23,7 @@ use datafusion::common::exec_datafusion_err;
 use datafusion::execution::SendableRecordBatchStream;
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use futures::TryStreamExt;
+use futures::{Stream, stream};
 use prost::Message;
 use std::sync::Arc;
 use std::sync::atomic::{AtomicUsize, Ordering};
@@ -97,12 +106,6 @@ impl ArrowFlightEndpoint {
             })
             .await?;
         let stage = Arc::clone(&stage_data.stage);
-        let num_partitions_remaining = Arc::clone(&stage_data.num_partitions_remaining);
-
-        // If all the partitions are done, remove the stage from the cache.
-        if num_partitions_remaining.fetch_sub(1, Ordering::SeqCst) <= 1 {
-            self.task_data_entries.remove(key);
-        }
 
         // Find out which partition group we are executing
         let cfg = session_state.config_mut();
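
The block removed above decremented `num_partitions_remaining` as soon as a partition's request arrived, which could evict the stage while its stream was still running; the countdown now happens when each stream finishes (see the `TrailingFlightDataStream` closure later in this diff). A minimal sketch of the last-one-out pattern, with a hypothetical `on_partition_done` helper that is not part of the crate:

```rust
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

// Hypothetical helper: last-one-out eviction. `fetch_sub` returns the value
// *before* the subtraction, so exactly one caller observes 1 and performs
// the cleanup.
fn on_partition_done(remaining: &Arc<AtomicUsize>, evict: impl FnOnce()) {
    if remaining.fetch_sub(1, Ordering::SeqCst) == 1 {
        evict();
    }
}
```
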
@@ -130,24 +133,42 @@ impl ArrowFlightEndpoint {
             .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
 
         let schema = stream.schema();
+
+        // TODO: We don't need to do this since the stage / plan is captured again by the
+        // TrailingFlightDataStream. However, we will eventually only use the TrailingFlightDataStream
+        // if we are running an `explain (analyze)` command. We should update this section
+        // to only use one or the other - not both.
+        let plan_capture = stage.plan.clone();
         let stream = with_callback(stream, move |_| {
             // We need to hold a reference to the plan for at least as long as the stream is
             // executing. Some plans might store state necessary for the stream to work, and
             // dropping the plan early could drop this state too soon.
-            let _ = stage.plan;
+            let _ = plan_capture;
         });
 
-        Ok(record_batch_stream_to_response(Box::pin(
-            RecordBatchStreamAdapter::new(schema, stream),
-        )))
+        let record_batch_stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream));
+        let task_data_capture = self.task_data_entries.clone();
+        Ok(flight_stream_from_record_batch_stream(
+            key.clone(),
+            stage_data.clone(),
+            move || {
+                task_data_capture.remove(key.clone());
+            },
+            record_batch_stream,
+        ))
     }
 }
 
 fn missing(field: &'static str) -> impl FnOnce() -> Status {
     move || Status::invalid_argument(format!("Missing field '{field}'"))
 }
 
-fn record_batch_stream_to_response(
+/// Creates a tonic response from a stream of record batches. Handles RecordBatch-to-FlightData
+/// conversion, partition tracking, stage eviction, and trailing metrics.
+fn flight_stream_from_record_batch_stream(
+    stage_key: StageKey,
+    stage_data: TaskData,
+    evict_stage: impl FnOnce() + Send + 'static,
     stream: SendableRecordBatchStream,
 ) -> Response<<ArrowFlightEndpoint as FlightService>::DoGetStream> {
     let flight_data_stream =
@@ -157,12 +178,109 @@ fn record_batch_stream_to_response(
             FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
         }));
 
-    Response::new(Box::pin(flight_data_stream.map_err(|err| match err {
+    let trailing_metrics_stream = TrailingFlightDataStream::new(
+        move || {
+            if stage_data
+                .num_partitions_remaining
+                .fetch_sub(1, Ordering::SeqCst)
+                == 1
+            {
+                evict_stage();
+
+                let metrics_stream =
+                    collect_and_create_metrics_flight_data(stage_key, stage_data.stage).map_err(
+                        |err| {
+                            Status::internal(format!(
+                                "error collecting metrics in arrow flight endpoint: {err}"
+                            ))
+                        },
+                    )?;
+
+                return Ok(Some(metrics_stream));
+            }
+
+            Ok(None)
+        },
+        flight_data_stream,
+    );
+
+    Response::new(Box::pin(trailing_metrics_stream.map_err(|err| match err {
         FlightError::Tonic(status) => *status,
         _ => Status::internal(format!("Error during flight stream: {err}")),
     })))
 }
 
+// Collects metrics from the provided stage and encodes them into a stream of flight data using
+// the schema of the stage.
+fn collect_and_create_metrics_flight_data(
+    stage_key: StageKey,
+    stage: Arc<StageExec>,
+) -> Result<impl Stream<Item = Result<FlightData, FlightError>> + Send + 'static, FlightError> {
+    // Get the metrics for the task executed on this worker. Separately, collect metrics for child tasks.
+    let mut result = TaskMetricsCollector::new()
+        .collect(stage.plan.clone())
+        .map_err(|err| FlightError::ProtocolError(err.to_string()))?;
+
+    // Add the metrics for this task into the collection of task metrics.
+    // Skip any metrics that can't be converted to proto (unsupported types).
+    let proto_task_metrics = result
+        .task_metrics
+        .iter()
+        .map(|metrics| {
+            df_metrics_set_to_proto(metrics)
+                .map_err(|err| FlightError::ProtocolError(err.to_string()))
+        })
+        .collect::<Result<Vec<_>, _>>()?;
+    result
+        .input_task_metrics
+        .insert(stage_key, proto_task_metrics);
+
+    // Serialize the metrics for all tasks.
+    let mut task_metrics_set = vec![];
+    for (stage_key, metrics) in result.input_task_metrics.into_iter() {
+        task_metrics_set.push(TaskMetrics {
+            stage_key: Some(stage_key),
+            metrics,
+        });
+    }
+
+    let flight_app_metadata = FlightAppMetadata {
+        content: Some(AppMetadata::MetricsCollection(MetricsCollection {
+            tasks: task_metrics_set,
+        })),
+    };
+
+    let metrics_flight_data =
+        empty_flight_data_with_app_metadata(flight_app_metadata, stage.plan.schema())?;
+    Ok(Box::pin(stream::once(
+        async move { Ok(metrics_flight_data) },
+    )))
+}
+
+/// Creates a FlightData with the given app_metadata and an empty RecordBatch using the provided schema.
+/// We don't use [arrow_flight::encode::FlightDataEncoder] (and by extension, the [arrow_flight::encode::FlightDataEncoderBuilder])
+/// since they skip messages with empty RecordBatch data.
+pub fn empty_flight_data_with_app_metadata(
+    metadata: FlightAppMetadata,
+    schema: SchemaRef,
+) -> Result<FlightData, FlightError> {
+    let mut buf = vec![];
+    metadata
+        .encode(&mut buf)
+        .map_err(|err| FlightError::ProtocolError(err.to_string()))?;
+
+    let empty_batch = RecordBatch::new_empty(schema);
+    let options = IpcWriteOptions::default();
+    let data_gen = IpcDataGenerator::default();
+    let mut dictionary_tracker = DictionaryTracker::new(true);
+    let (_, encoded_data) = data_gen
+        .encoded_batch(&empty_batch, &mut dictionary_tracker, &options)
+        .map_err(|e| {
+            FlightError::ProtocolError(format!("Failed to create empty batch FlightData: {e}"))
+        })?;
+    Ok(FlightData::from(encoded_data).with_app_metadata(buf))
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
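
`TrailingFlightDataStream` is the piece that makes the above work: it forwards the inner flight stream and, once that stream is exhausted, invokes a closure that may produce one final message (here, the metrics trailer). A minimal sketch of those assumed semantics using plain `futures` combinators — `with_trailer` is a hypothetical stand-in, not the crate's implementation:

```rust
use futures::{Stream, StreamExt, stream};

// Hypothetical sketch: yield everything from `inner`, then ask `on_end` for an
// optional trailing item. `stream::once` defers the closure until `inner` is
// exhausted, so the trailer is produced only after the last batch.
fn with_trailer<T: Send + 'static>(
    inner: impl Stream<Item = T> + Send + 'static,
    on_end: impl FnOnce() -> Option<T> + Send + 'static,
) -> impl Stream<Item = T> + Send + 'static {
    inner.chain(
        stream::once(async move { on_end() }).filter_map(|item| async move { item }),
    )
}
```
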
@@ -228,24 +346,27 @@ mod tests {
         let stage_proto = proto_from_stage(&stage, &DefaultPhysicalExtensionCodec {}).unwrap();
         let stage_proto_for_closure = stage_proto.clone();
         let endpoint_ref = &endpoint;
+
         let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| {
             let stage_proto = stage_proto_for_closure.clone();
-            // Create DoGet message
             let doget = DoGet {
                 stage_proto: stage_proto.encode_to_vec().into(),
                 target_task_index: task_number,
                 target_partition: partition,
                 stage_key: Some(stage_key),
             };
 
-            // Create Flight ticket
             let ticket = Ticket {
                 ticket: Bytes::from(doget.encode_to_vec()),
             };
 
-            // Call the actual get() method
             let request = Request::new(ticket);
-            endpoint_ref.get(request).await
+            let response = endpoint_ref.get(request).await?;
+            let mut stream = response.into_inner();
+
+            // Consume the stream.
+            while let Some(_flight_data) = stream.try_next().await? {}
+            Ok::<(), Status>(())
         };
 
         // For each task, call do_get() for each partition except the last.
@@ -261,22 +382,22 @@ mod tests {
 
         // Run the last partition of task 0. Any partition number works. Verify that the task state
         // is evicted because all partitions have been processed.
-        let result = do_get(1, 0, task_keys[0].clone()).await;
+        let result = do_get(2, 0, task_keys[0].clone()).await;
         assert!(result.is_ok());
         let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
         assert_eq!(stored_stage_keys.len(), 2);
         assert!(stored_stage_keys.contains(&task_keys[1]));
         assert!(stored_stage_keys.contains(&task_keys[2]));
 
         // Run the last partition of task 1.
-        let result = do_get(1, 1, task_keys[1].clone()).await;
+        let result = do_get(2, 1, task_keys[1].clone()).await;
         assert!(result.is_ok());
         let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
         assert_eq!(stored_stage_keys.len(), 1);
         assert!(stored_stage_keys.contains(&task_keys[2]));
 
         // Run the last partition of the last task.
-        let result = do_get(1, 2, task_keys[2].clone()).await;
+        let result = do_get(2, 2, task_keys[2].clone()).await;
         assert!(result.is_ok());
         let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
         assert_eq!(stored_stage_keys.len(), 0);
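
On the consumer side, the trailing message can be told apart from regular data because it carries no rows and a non-empty `app_metadata` payload. A hypothetical client-side sketch, assuming the crate's generated protobuf types (`FlightAppMetadata`) are in scope:

```rust
use arrow_flight::FlightData;
use prost::Message;

// Hypothetical helper: returns the decoded metrics if this FlightData is the
// trailing metrics message, or None for an ordinary data message.
fn decode_trailing_metrics(data: &FlightData) -> Option<FlightAppMetadata> {
    if data.app_metadata.is_empty() {
        return None;
    }
    FlightAppMetadata::decode(data.app_metadata.as_ref()).ok()
}
```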