@@ -9,13 +9,14 @@ use arrow_flight::encode::FlightDataEncoderBuilder;
99use arrow_flight:: error:: FlightError ;
1010use arrow_flight:: flight_service_server:: FlightService ;
1111use arrow_flight:: Ticket ;
12- use datafusion:: execution:: SessionState ;
12+ use datafusion:: execution:: { SendableRecordBatchStream , SessionState } ;
1313use futures:: TryStreamExt ;
14+ use http:: HeaderMap ;
1415use prost:: Message ;
16+ use std:: fmt:: Display ;
1517use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
1618use std:: sync:: Arc ;
1719use tokio:: sync:: OnceCell ;
18- use tonic:: metadata:: MetadataMap ;
1920use tonic:: { Request , Response , Status } ;
2021
2122#[ derive( Clone , PartialEq , :: prost:: Message ) ]
@@ -41,9 +42,9 @@ pub struct DoGet {
4142/// TaskData stores state for a single task being executed by this Endpoint. It may be shared
4243/// by concurrent requests for the same task which execute separate partitions.
4344pub struct TaskData {
44- pub ( super ) state : SessionState ,
45+ pub ( super ) session_state : SessionState ,
4546 pub ( super ) stage : Arc < StageExec > ,
46- ///num_partitions_remaining is initialized to the total number of partitions in the task (not
47+ /// ` num_partitions_remaining` is initialized to the total number of partitions in the task (not
4748 /// only tasks in the partition group). This is decremented for each request to the endpoint
4849 /// for this task. Once this count is zero, the task is likely complete. The task may not be
4950 /// complete because it's possible that the same partition was retried and this count was
@@ -56,98 +57,78 @@ impl ArrowFlightEndpoint {
5657 & self ,
5758 request : Request < Ticket > ,
5859 ) -> Result < Response < <ArrowFlightEndpoint as FlightService >:: DoGetStream > , Status > {
59- let ( metadata, _ext, ticket) = request. into_parts ( ) ;
60- let Ticket { ticket } = ticket;
61- let doget = DoGet :: decode ( ticket) . map_err ( |err| {
60+ let ( metadata, _ext, body) = request. into_parts ( ) ;
61+ let doget = DoGet :: decode ( body. ticket ) . map_err ( |err| {
6262 Status :: invalid_argument ( format ! ( "Cannot decode DoGet message: {err}" ) )
6363 } ) ?;
6464
65+ // There's only 1 `StageExec` responsible for all requests that share the same `stage_key`,
66+ // so here we either retrieve the existing one or create a new one if it does not exist.
67+ let ( mut session_state, stage) = self
68+ . get_state_and_stage (
69+ doget. stage_key . ok_or_else ( missing ( "stage_key" ) ) ?,
70+ doget. stage_proto . ok_or_else ( missing ( "stage_proto" ) ) ?,
71+ metadata. clone ( ) . into_headers ( ) ,
72+ )
73+ . await ?;
74+
75+ // Find out which partition group we are executing
6576 let partition = doget. partition as usize ;
6677 let task_number = doget. task_number as usize ;
67- let task_data = self . get_state_and_stage ( doget, metadata) . await ?;
68-
69- let stage = task_data. stage ;
70- let mut state = task_data. state ;
71-
72- // find out which partition group we are executing
73- let task = stage
74- . tasks
75- . get ( task_number)
76- . ok_or ( Status :: invalid_argument ( format ! (
77- "Task number {} not found in stage {}" ,
78- task_number,
79- stage. name( )
80- ) ) ) ?;
81-
82- let partition_group = PartitionGroup ( task. partition_group . clone ( ) ) ;
83- state. config_mut ( ) . set_extension ( Arc :: new ( partition_group) ) ;
84-
85- let inner_plan = stage. plan . clone ( ) ;
86-
87- let stream = inner_plan
88- . execute ( partition, state. task_ctx ( ) )
78+ let task = stage. tasks . get ( task_number) . ok_or_else ( invalid ( format ! (
79+ "Task number {task_number} not found in stage {}" ,
80+ stage. num
81+ ) ) ) ?;
82+
83+ let cfg = session_state. config_mut ( ) ;
84+ cfg. set_extension ( Arc :: new ( PartitionGroup ( task. partition_group . clone ( ) ) ) ) ;
85+ cfg. set_extension ( Arc :: clone ( & stage) ) ;
86+ cfg. set_extension ( Arc :: new ( ContextGrpcMetadata ( metadata. into_headers ( ) ) ) ) ;
87+
88+ // Rather than executing the `StageExec` itself, we want to execute the inner plan instead,
89+ // as executing `StageExec` performs some worker assignation that should have already been
90+ // done in the head stage.
91+ let stream = stage
92+ . plan
93+ . execute ( partition, session_state. task_ctx ( ) )
8994 . map_err ( |err| Status :: internal ( format ! ( "Error executing stage plan: {err:#?}" ) ) ) ?;
9095
91- let flight_data_stream = FlightDataEncoderBuilder :: new ( )
92- . with_schema ( inner_plan. schema ( ) . clone ( ) )
93- . build ( stream. map_err ( |err| {
94- FlightError :: Tonic ( Box :: new ( datafusion_error_to_tonic_status ( & err) ) )
95- } ) ) ;
96-
97- Ok ( Response :: new ( Box :: pin ( flight_data_stream. map_err (
98- |err| match err {
99- FlightError :: Tonic ( status) => * status,
100- _ => Status :: internal ( format ! ( "Error during flight stream: {err}" ) ) ,
101- } ,
102- ) ) ) )
96+ Ok ( record_batch_stream_to_response ( stream) )
10397 }
10498
10599 async fn get_state_and_stage (
106100 & self ,
107- doget : DoGet ,
108- metadata_map : MetadataMap ,
109- ) -> Result < TaskData , Status > {
110- let key = doget
111- . stage_key
112- . ok_or ( Status :: invalid_argument ( "DoGet is missing the stage key" ) ) ?;
113- let once_stage = self
114- . stages
101+ key : StageKey ,
102+ stage_proto : StageExecProto ,
103+ headers : HeaderMap ,
104+ ) -> Result < ( SessionState , Arc < StageExec > ) , Status > {
105+ let once = self
106+ . task_data_entries
115107 . get_or_init ( key. clone ( ) , || Arc :: new ( OnceCell :: < TaskData > :: new ( ) ) ) ;
116108
117- let stage_data = once_stage
109+ let stage_data = once
118110 . get_or_try_init ( || async {
119- let stage_proto = doget
120- . stage_proto
121- . ok_or ( Status :: invalid_argument ( "DoGet is missing the stage proto" ) ) ?;
122-
123- let headers = metadata_map. into_headers ( ) ;
124- let mut state = self
111+ let session_state = self
125112 . session_builder
126113 . build_session_state ( DistributedSessionBuilderContext {
127114 runtime_env : Arc :: clone ( & self . runtime ) ,
128- headers : headers . clone ( ) ,
115+ headers,
129116 } )
130117 . await
131118 . map_err ( |err| datafusion_error_to_tonic_status ( & err) ) ?;
132119
133- let codec = DistributedCodec :: new_combined_with_user ( state . config ( ) ) ;
120+ let codec = DistributedCodec :: new_combined_with_user ( session_state . config ( ) ) ;
134121
135- let stage = stage_from_proto ( stage_proto, & state, self . runtime . as_ref ( ) , & codec)
136- . map ( Arc :: new)
122+ let stage = stage_from_proto ( stage_proto, & session_state, & self . runtime , & codec)
137123 . map_err ( |err| {
138124 Status :: invalid_argument ( format ! ( "Cannot decode stage proto: {err}" ) )
139125 } ) ?;
140126
141- // Add the extensions that might be required for ExecutionPlan nodes in the plan
142- let config = state. config_mut ( ) ;
143- config. set_extension ( stage. clone ( ) ) ;
144- config. set_extension ( Arc :: new ( ContextGrpcMetadata ( headers) ) ) ;
145-
146127 // Initialize partition count to the number of partitions in the stage
147128 let total_partitions = stage. plan . properties ( ) . partitioning . partition_count ( ) ;
148129 Ok :: < _ , Status > ( TaskData {
149- state ,
150- stage,
130+ session_state ,
131+ stage : Arc :: new ( stage ) ,
151132 num_partitions_remaining : Arc :: new ( AtomicUsize :: new ( total_partitions) ) ,
152133 } )
153134 } )
@@ -158,13 +139,37 @@ impl ArrowFlightEndpoint {
158139 . num_partitions_remaining
159140 . fetch_sub ( 1 , Ordering :: SeqCst ) ;
160141 if remaining_partitions <= 1 {
161- self . stages . remove ( key. clone ( ) ) ;
142+ self . task_data_entries . remove ( key) ;
162143 }
163144
164- Ok ( stage_data. clone ( ) )
145+ Ok ( ( stage_data. session_state . clone ( ) , stage_data . stage . clone ( ) ) )
165146 }
166147}
167148
149+ fn missing ( field : & ' static str ) -> impl FnOnce ( ) -> Status {
150+ move || Status :: invalid_argument ( format ! ( "Missing field '{field}'" ) )
151+ }
152+
153+ fn invalid ( msg : impl Display ) -> impl FnOnce ( ) -> Status {
154+ move || Status :: invalid_argument ( msg. to_string ( ) )
155+ }
156+
157+ fn record_batch_stream_to_response (
158+ stream : SendableRecordBatchStream ,
159+ ) -> Response < <ArrowFlightEndpoint as FlightService >:: DoGetStream > {
160+ let flight_data_stream =
161+ FlightDataEncoderBuilder :: new ( )
162+ . with_schema ( stream. schema ( ) . clone ( ) )
163+ . build ( stream. map_err ( |err| {
164+ FlightError :: Tonic ( Box :: new ( datafusion_error_to_tonic_status ( & err) ) )
165+ } ) ) ;
166+
167+ Response :: new ( Box :: pin ( flight_data_stream. map_err ( |err| match err {
168+ FlightError :: Tonic ( status) => * status,
169+ _ => Status :: internal ( format ! ( "Error during flight stream: {err}" ) ) ,
170+ } ) ) )
171+ }
172+
168173#[ cfg( test) ]
169174mod tests {
170175 use super :: * ;
@@ -262,28 +267,28 @@ mod tests {
262267 }
263268
264269 // Check that the endpoint has not evicted any task states.
265- assert_eq ! ( endpoint. stages . len( ) , num_tasks) ;
270+ assert_eq ! ( endpoint. task_data_entries . len( ) , num_tasks) ;
266271
267272 // Run the last partition of task 0. Any partition number works. Verify that the task state
268273 // is evicted because all partitions have been processed.
269274 let result = do_get ( 1 , 0 , task_keys[ 0 ] . clone ( ) ) . await ;
270275 assert ! ( result. is_ok( ) ) ;
271- let stored_stage_keys = endpoint. stages . keys ( ) . collect :: < Vec < StageKey > > ( ) ;
276+ let stored_stage_keys = endpoint. task_data_entries . keys ( ) . collect :: < Vec < StageKey > > ( ) ;
272277 assert_eq ! ( stored_stage_keys. len( ) , 2 ) ;
273278 assert ! ( stored_stage_keys. contains( & task_keys[ 1 ] ) ) ;
274279 assert ! ( stored_stage_keys. contains( & task_keys[ 2 ] ) ) ;
275280
276281 // Run the last partition of task 1.
277282 let result = do_get ( 1 , 1 , task_keys[ 1 ] . clone ( ) ) . await ;
278283 assert ! ( result. is_ok( ) ) ;
279- let stored_stage_keys = endpoint. stages . keys ( ) . collect :: < Vec < StageKey > > ( ) ;
284+ let stored_stage_keys = endpoint. task_data_entries . keys ( ) . collect :: < Vec < StageKey > > ( ) ;
280285 assert_eq ! ( stored_stage_keys. len( ) , 1 ) ;
281286 assert ! ( stored_stage_keys. contains( & task_keys[ 2 ] ) ) ;
282287
283288 // Run the last partition of the last task.
284289 let result = do_get ( 1 , 2 , task_keys[ 2 ] . clone ( ) ) . await ;
285290 assert ! ( result. is_ok( ) ) ;
286- let stored_stage_keys = endpoint. stages . keys ( ) . collect :: < Vec < StageKey > > ( ) ;
291+ let stored_stage_keys = endpoint. task_data_entries . keys ( ) . collect :: < Vec < StageKey > > ( ) ;
287292 assert_eq ! ( stored_stage_keys. len( ) , 0 ) ;
288293 }
289294
0 commit comments