@@ -3,7 +3,6 @@ use crate::config_extension_ext::ContextGrpcMetadata;
33use crate :: execution_plans:: { DistributedTaskContext , StageExec } ;
44use crate :: flight_service:: service:: ArrowFlightEndpoint ;
55use crate :: flight_service:: session_builder:: DistributedSessionBuilderContext ;
6- use crate :: flight_service:: trailing_flight_data_stream:: TrailingFlightDataStream ;
76use crate :: metrics:: TaskMetricsCollector ;
87use crate :: metrics:: proto:: df_metrics_set_to_proto;
98use crate :: protobuf:: {
@@ -16,18 +15,17 @@ use arrow_flight::encode::FlightDataEncoderBuilder;
1615use arrow_flight:: error:: FlightError ;
1716use arrow_flight:: flight_service_server:: FlightService ;
1817use bytes:: Bytes ;
19- use datafusion:: arrow:: array:: RecordBatch ;
20- use datafusion:: arrow:: datatypes:: SchemaRef ;
21- use datafusion:: arrow:: ipc:: writer:: { DictionaryTracker , IpcDataGenerator , IpcWriteOptions } ;
18+
2219use datafusion:: common:: exec_datafusion_err;
2320use datafusion:: execution:: SendableRecordBatchStream ;
2421use datafusion:: physical_plan:: stream:: RecordBatchStreamAdapter ;
2522use datafusion:: prelude:: SessionContext ;
26- use futures:: TryStreamExt ;
27- use futures:: { Stream , stream } ;
23+ use futures:: stream ;
24+ use futures:: { StreamExt , TryStreamExt } ;
2825use prost:: Message ;
2926use std:: sync:: Arc ;
3027use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
28+ use std:: task:: Poll ;
3129use tonic:: { Request , Response , Status } ;
3230
3331#[ derive( Clone , PartialEq , :: prost:: Message ) ]
@@ -173,51 +171,82 @@ fn flight_stream_from_record_batch_stream(
173171 evict_stage : impl FnOnce ( ) + Send + ' static ,
174172 stream : SendableRecordBatchStream ,
175173) -> Response < <ArrowFlightEndpoint as FlightService >:: DoGetStream > {
176- let flight_data_stream =
174+ let mut flight_data_stream =
177175 FlightDataEncoderBuilder :: new ( )
178176 . with_schema ( stream. schema ( ) . clone ( ) )
179177 . build ( stream. map_err ( |err| {
180178 FlightError :: Tonic ( Box :: new ( datafusion_error_to_tonic_status ( & err) ) )
181179 } ) ) ;
182180
183- let trailing_metrics_stream = TrailingFlightDataStream :: new (
184- move || {
185- if stage_data
186- . num_partitions_remaining
187- . fetch_sub ( 1 , Ordering :: SeqCst )
188- == 1
189- {
190- evict_stage ( ) ;
191-
192- let metrics_stream =
193- collect_and_create_metrics_flight_data ( stage_key, stage_data. stage ) . map_err (
194- |err| {
195- Status :: internal ( format ! (
196- "error collecting metrics in arrow flight endpoint: {err}"
197- ) )
198- } ,
199- ) ?;
200-
201- return Ok ( Some ( metrics_stream) ) ;
202- }
203-
204- Ok ( None )
205- } ,
206- flight_data_stream,
207- ) ;
181+ // executed once when the stream ends
182+ // decorates the last flight data with our metrics
183+ let mut final_closure = Some ( move |last_flight_data| {
184+ if stage_data
185+ . num_partitions_remaining
186+ . fetch_sub ( 1 , Ordering :: SeqCst )
187+ == 1
188+ {
189+ evict_stage ( ) ;
190+
191+ collect_and_create_metrics_flight_data ( stage_key, stage_data. stage , last_flight_data)
192+ } else {
193+ Ok ( last_flight_data)
194+ }
195+ } ) ;
196+
197+ // this is used to peek the new value
198+ // so that we can add our metrics to the last flight data
199+ let mut current_value = None ;
200+
201+ let stream =
202+ stream:: poll_fn (
203+ move |cx| match futures:: ready!( flight_data_stream. poll_next_unpin( cx) ) {
204+ Some ( Ok ( new_val) ) => {
205+ match current_value. take ( ) {
206+ // This is the first value, so we store it and repoll to get the next value
207+ None => {
208+ current_value = Some ( new_val) ;
209+ cx. waker ( ) . wake_by_ref ( ) ;
210+ Poll :: Pending
211+ }
212+
213+ Some ( existing) => {
214+ current_value = Some ( new_val) ;
215+
216+ Poll :: Ready ( Some ( Ok ( existing) ) )
217+ }
218+ }
219+ }
220+ // this is our last value, so we add our metrics to this flight data
221+ None => match current_value. take ( ) {
222+ Some ( existing) => {
223+ // make sure we wake ourselves to finish the stream
224+ cx. waker ( ) . wake_by_ref ( ) ;
225+
226+ if let Some ( closure) = final_closure. take ( ) {
227+ Poll :: Ready ( Some ( closure ( existing) ) )
228+ } else {
229+ unreachable ! ( "the closure is only executed once" )
230+ }
231+ }
232+ None => Poll :: Ready ( None ) ,
233+ } ,
234+ err => Poll :: Ready ( err) ,
235+ } ,
236+ ) ;
208237
209- Response :: new ( Box :: pin ( trailing_metrics_stream . map_err ( |err| match err {
238+ Response :: new ( Box :: pin ( stream . map_err ( |err| match err {
210239 FlightError :: Tonic ( status) => * status,
211240 _ => Status :: internal ( format ! ( "Error during flight stream: {err}" ) ) ,
212241 } ) ) )
213242}
214243
215- // Collects metrics from the provided stage and encodes it into a stream of flight data using
216- // the schema of the stage.
244+ /// Collects metrics from the provided stage and attaches them, encoded as app metadata, to the incoming flight data.
217245fn collect_and_create_metrics_flight_data (
218246 stage_key : StageKey ,
219247 stage : Arc < StageExec > ,
220- ) -> Result < impl Stream < Item = Result < FlightData , FlightError > > + Send + ' static , FlightError > {
248+ incoming : FlightData ,
249+ ) -> Result < FlightData , FlightError > {
221250 // Get the metrics for the task executed on this worker. Separately, collect metrics for child tasks.
222251 let mut result = TaskMetricsCollector :: new ( )
223252 . collect ( stage. plan . clone ( ) )
@@ -252,35 +281,12 @@ fn collect_and_create_metrics_flight_data(
252281 } ) ) ,
253282 } ;
254283
255- let metrics_flight_data =
256- empty_flight_data_with_app_metadata ( flight_app_metadata, stage. plan . schema ( ) ) ?;
257- Ok ( Box :: pin ( stream:: once (
258- async move { Ok ( metrics_flight_data) } ,
259- ) ) )
260- }
261-
262- /// Creates a FlightData with the given app_metadata and empty RecordBatch using the provided schema.
263- /// We don't use [arrow_flight::encode::FlightDataEncoder] (and by extension, the [arrow_flight::encode::FlightDataEncoderBuilder])
264- /// since they skip messages with empty RecordBatch data.
265- pub fn empty_flight_data_with_app_metadata (
266- metadata : FlightAppMetadata ,
267- schema : SchemaRef ,
268- ) -> Result < FlightData , FlightError > {
269284 let mut buf = vec ! [ ] ;
270- metadata
285+ flight_app_metadata
271286 . encode ( & mut buf)
272287 . map_err ( |err| FlightError :: ProtocolError ( err. to_string ( ) ) ) ?;
273288
274- let empty_batch = RecordBatch :: new_empty ( schema) ;
275- let options = IpcWriteOptions :: default ( ) ;
276- let data_gen = IpcDataGenerator :: default ( ) ;
277- let mut dictionary_tracker = DictionaryTracker :: new ( false ) ;
278- let ( _, encoded_data) = data_gen
279- . encoded_batch ( & empty_batch, & mut dictionary_tracker, & options)
280- . map_err ( |e| {
281- FlightError :: ProtocolError ( format ! ( "Failed to create empty batch FlightData: {e}" ) )
282- } ) ?;
283- Ok ( FlightData :: from ( encoded_data) . with_app_metadata ( buf) )
289+ Ok ( incoming. with_app_metadata ( buf) )
284290}
285291
286292#[ cfg( test) ]
0 commit comments