@@ -9,14 +9,12 @@ use arrow_flight::encode::FlightDataEncoderBuilder;
99use arrow_flight:: error:: FlightError ;
1010use arrow_flight:: flight_service_server:: FlightService ;
1111use arrow_flight:: Ticket ;
12- use datafusion:: execution:: { SendableRecordBatchStream , SessionState } ;
12+ use datafusion:: execution:: SendableRecordBatchStream ;
1313use futures:: TryStreamExt ;
14- use http:: HeaderMap ;
1514use prost:: Message ;
1615use std:: fmt:: Display ;
1716use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
1817use std:: sync:: Arc ;
19- use tokio:: sync:: OnceCell ;
2018use tonic:: { Request , Response , Status } ;
2119
2220#[ derive( Clone , PartialEq , :: prost:: Message ) ]
@@ -42,7 +40,6 @@ pub struct DoGet {
4240/// TaskData stores state for a single task being executed by this Endpoint. It may be shared
4341/// by concurrent requests for the same task which execute separate partitions.
4442pub struct TaskData {
45- pub ( super ) session_state : SessionState ,
4643 pub ( super ) stage : Arc < StageExec > ,
4744 /// `num_partitions_remaining` is initialized to the total number of partitions in the task (not
4845 /// only tasks in the partition group). This is decremented for each request to the endpoint
@@ -62,15 +59,47 @@ impl ArrowFlightEndpoint {
6259 Status :: invalid_argument ( format ! ( "Cannot decode DoGet message: {err}" ) )
6360 } ) ?;
6461
62+ let mut session_state = self
63+ . session_builder
64+ . build_session_state ( DistributedSessionBuilderContext {
65+ runtime_env : Arc :: clone ( & self . runtime ) ,
66+ headers : metadata. clone ( ) . into_headers ( ) ,
67+ } )
68+ . await
69+ . map_err ( |err| datafusion_error_to_tonic_status ( & err) ) ?;
70+
71+ let codec = DistributedCodec :: new_combined_with_user ( session_state. config ( ) ) ;
72+
6573 // There's only 1 `StageExec` responsible for all requests that share the same `stage_key`,
6674 // so here we either retrieve the existing one or create a new one if it does not exist.
67- let ( mut session_state, stage) = self
68- . get_state_and_stage (
69- doget. stage_key . ok_or_else ( missing ( "stage_key" ) ) ?,
70- doget. stage_proto . ok_or_else ( missing ( "stage_proto" ) ) ?,
71- metadata. clone ( ) . into_headers ( ) ,
72- )
75+ let key = doget. stage_key . ok_or_else ( missing ( "stage_key" ) ) ?;
76+ let once = self
77+ . task_data_entries
78+ . get_or_init ( key. clone ( ) , Default :: default) ;
79+
80+ let stage_data = once
81+ . get_or_try_init ( || async {
82+ let stage_proto = doget. stage_proto . ok_or_else ( missing ( "stage_proto" ) ) ?;
83+ let stage = stage_from_proto ( stage_proto, & session_state, & self . runtime , & codec)
84+ . map_err ( |err| {
85+ Status :: invalid_argument ( format ! ( "Cannot decode stage proto: {err}" ) )
86+ } ) ?;
87+
88+ // Initialize partition count to the number of partitions in the stage
89+ let total_partitions = stage. plan . properties ( ) . partitioning . partition_count ( ) ;
90+ Ok :: < _ , Status > ( TaskData {
91+ stage : Arc :: new ( stage) ,
92+ num_partitions_remaining : Arc :: new ( AtomicUsize :: new ( total_partitions) ) ,
93+ } )
94+ } )
7395 . await ?;
96+ let stage = Arc :: clone ( & stage_data. stage ) ;
97+ let num_partitions_remaining = Arc :: clone ( & stage_data. num_partitions_remaining ) ;
98+
99+ // If all the partitions are done, remove the stage from the cache.
100+ if num_partitions_remaining. fetch_sub ( 1 , Ordering :: SeqCst ) <= 1 {
101+ self . task_data_entries . remove ( key) ;
102+ }
74103
75104 // Find out which partition group we are executing
76105 let partition = doget. partition as usize ;
@@ -95,55 +124,6 @@ impl ArrowFlightEndpoint {
95124
96125 Ok ( record_batch_stream_to_response ( stream) )
97126 }
98-
99- async fn get_state_and_stage (
100- & self ,
101- key : StageKey ,
102- stage_proto : StageExecProto ,
103- headers : HeaderMap ,
104- ) -> Result < ( SessionState , Arc < StageExec > ) , Status > {
105- let once = self
106- . task_data_entries
107- . get_or_init ( key. clone ( ) , || Arc :: new ( OnceCell :: < TaskData > :: new ( ) ) ) ;
108-
109- let stage_data = once
110- . get_or_try_init ( || async {
111- let session_state = self
112- . session_builder
113- . build_session_state ( DistributedSessionBuilderContext {
114- runtime_env : Arc :: clone ( & self . runtime ) ,
115- headers,
116- } )
117- . await
118- . map_err ( |err| datafusion_error_to_tonic_status ( & err) ) ?;
119-
120- let codec = DistributedCodec :: new_combined_with_user ( session_state. config ( ) ) ;
121-
122- let stage = stage_from_proto ( stage_proto, & session_state, & self . runtime , & codec)
123- . map_err ( |err| {
124- Status :: invalid_argument ( format ! ( "Cannot decode stage proto: {err}" ) )
125- } ) ?;
126-
127- // Initialize partition count to the number of partitions in the stage
128- let total_partitions = stage. plan . properties ( ) . partitioning . partition_count ( ) ;
129- Ok :: < _ , Status > ( TaskData {
130- session_state,
131- stage : Arc :: new ( stage) ,
132- num_partitions_remaining : Arc :: new ( AtomicUsize :: new ( total_partitions) ) ,
133- } )
134- } )
135- . await ?;
136-
137- // If all the partitions are done, remove the stage from the cache.
138- let remaining_partitions = stage_data
139- . num_partitions_remaining
140- . fetch_sub ( 1 , Ordering :: SeqCst ) ;
141- if remaining_partitions <= 1 {
142- self . task_data_entries . remove ( key) ;
143- }
144-
145- Ok ( ( stage_data. session_state . clone ( ) , stage_data. stage . clone ( ) ) )
146- }
147127}
148128
149129fn missing ( field : & ' static str ) -> impl FnOnce ( ) -> Status {
0 commit comments