1- use crate :: common:: with_callback ;
1+ use crate :: common:: map_last_stream ;
22use crate :: config_extension_ext:: ContextGrpcMetadata ;
33use crate :: execution_plans:: { DistributedTaskContext , StageExec } ;
44use crate :: flight_service:: service:: ArrowFlightEndpoint ;
@@ -17,15 +17,11 @@ use arrow_flight::flight_service_server::FlightService;
1717use bytes:: Bytes ;
1818
1919use datafusion:: common:: exec_datafusion_err;
20- use datafusion:: execution:: SendableRecordBatchStream ;
21- use datafusion:: physical_plan:: stream:: RecordBatchStreamAdapter ;
2220use datafusion:: prelude:: SessionContext ;
23- use futures:: stream;
24- use futures:: { StreamExt , TryStreamExt } ;
21+ use futures:: TryStreamExt ;
2522use prost:: Message ;
2623use std:: sync:: Arc ;
2724use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
28- use std:: task:: Poll ;
2925use tonic:: { Request , Response , Status } ;
3026
3127#[ derive( Clone , PartialEq , :: prost:: Message ) ]
@@ -132,115 +128,33 @@ impl ArrowFlightEndpoint {
132128 . execute ( doget. target_partition as usize , session_state. task_ctx ( ) )
133129 . map_err ( |err| Status :: internal ( format ! ( "Error executing stage plan: {err:#?}" ) ) ) ?;
134130
135- let schema = stream. schema ( ) ;
136-
137- // TODO: We don't need to do this since the stage / plan is captured again by the
138- // TrailingFlightDataStream. However, we will eventuall only use the TrailingFlightDataStream
139- // if we are running an `explain (analyze)` command. We should update this section
140- // to only use one or the other - not both.
141- let plan_capture = stage. plan . clone ( ) ;
142- let stream = with_callback ( stream, move |_| {
143- // We need to hold a reference to the plan for at least as long as the stream is
144- // execution. Some plans might store state necessary for the stream to work, and
145- // dropping the plan early could drop this state too soon.
146- let _ = plan_capture;
131+ let stream = FlightDataEncoderBuilder :: new ( )
132+ . with_schema ( stream. schema ( ) . clone ( ) )
133+ . build ( stream. map_err ( |err| {
134+ FlightError :: Tonic ( Box :: new ( datafusion_error_to_tonic_status ( & err) ) )
135+ } ) ) ;
136+
137+ let task_data_entries = Arc :: clone ( & self . task_data_entries ) ;
138+ let num_partitions_remaining = Arc :: clone ( & stage_data. num_partitions_remaining ) ;
139+
140+ let stream = map_last_stream ( stream, move |last| {
141+ if num_partitions_remaining. fetch_sub ( 1 , Ordering :: SeqCst ) == 1 {
142+ task_data_entries. remove ( key. clone ( ) ) ;
143+ }
144+ last. and_then ( |el| collect_and_create_metrics_flight_data ( key, stage, el) )
147145 } ) ;
148146
149- let record_batch_stream = Box :: pin ( RecordBatchStreamAdapter :: new ( schema, stream) ) ;
150- let task_data_capture = self . task_data_entries . clone ( ) ;
151- Ok ( flight_stream_from_record_batch_stream (
152- key. clone ( ) ,
153- stage_data. clone ( ) ,
154- move || {
155- task_data_capture. remove ( key. clone ( ) ) ;
156- } ,
157- record_batch_stream,
158- ) )
147+ Ok ( Response :: new ( Box :: pin ( stream. map_err ( |err| match err {
148+ FlightError :: Tonic ( status) => * status,
149+ _ => Status :: internal ( format ! ( "Error during flight stream: {err}" ) ) ,
150+ } ) ) ) )
159151 }
160152}
161153
162154fn missing ( field : & ' static str ) -> impl FnOnce ( ) -> Status {
163155 move || Status :: invalid_argument ( format ! ( "Missing field '{field}'" ) )
164156}
165157
166- /// Creates a tonic response from a stream of record batches. Handles
167- /// - RecordBatch to flight conversion partition tracking, stage eviction, and trailing metrics.
168- fn flight_stream_from_record_batch_stream (
169- stage_key : StageKey ,
170- stage_data : TaskData ,
171- evict_stage : impl FnOnce ( ) + Send + ' static ,
172- stream : SendableRecordBatchStream ,
173- ) -> Response < <ArrowFlightEndpoint as FlightService >:: DoGetStream > {
174- let mut flight_data_stream =
175- FlightDataEncoderBuilder :: new ( )
176- . with_schema ( stream. schema ( ) . clone ( ) )
177- . build ( stream. map_err ( |err| {
178- FlightError :: Tonic ( Box :: new ( datafusion_error_to_tonic_status ( & err) ) )
179- } ) ) ;
180-
181- // executed once when the stream ends
182- // decorates the last flight data with our metrics
183- let mut final_closure = Some ( move |last_flight_data| {
184- if stage_data
185- . num_partitions_remaining
186- . fetch_sub ( 1 , Ordering :: SeqCst )
187- == 1
188- {
189- evict_stage ( ) ;
190-
191- collect_and_create_metrics_flight_data ( stage_key, stage_data. stage , last_flight_data)
192- } else {
193- Ok ( last_flight_data)
194- }
195- } ) ;
196-
197- // this is used to peek the new value
198- // so that we can add our metrics to the last flight data
199- let mut current_value = None ;
200-
201- let stream =
202- stream:: poll_fn (
203- move |cx| match futures:: ready!( flight_data_stream. poll_next_unpin( cx) ) {
204- Some ( Ok ( new_val) ) => {
205- match current_value. take ( ) {
206- // This is the first value, so we store it and repoll to get the next value
207- None => {
208- current_value = Some ( new_val) ;
209- cx. waker ( ) . wake_by_ref ( ) ;
210- Poll :: Pending
211- }
212-
213- Some ( existing) => {
214- current_value = Some ( new_val) ;
215-
216- Poll :: Ready ( Some ( Ok ( existing) ) )
217- }
218- }
219- }
220- // this is our last value, so we add our metrics to this flight data
221- None => match current_value. take ( ) {
222- Some ( existing) => {
223- // make sure we wake ourselves to finish the stream
224- cx. waker ( ) . wake_by_ref ( ) ;
225-
226- if let Some ( closure) = final_closure. take ( ) {
227- Poll :: Ready ( Some ( closure ( existing) ) )
228- } else {
229- unreachable ! ( "the closure is only executed once" )
230- }
231- }
232- None => Poll :: Ready ( None ) ,
233- } ,
234- err => Poll :: Ready ( err) ,
235- } ,
236- ) ;
237-
238- Response :: new ( Box :: pin ( stream. map_err ( |err| match err {
239- FlightError :: Tonic ( status) => * status,
240- _ => Status :: internal ( format ! ( "Error during flight stream: {err}" ) ) ,
241- } ) ) )
242- }
243-
244158/// Collects metrics from the provided stage and includes it in the flight data
245159fn collect_and_create_metrics_flight_data (
246160 stage_key : StageKey ,
0 commit comments