@@ -14,7 +14,9 @@ use arrow_flight::Ticket;
1414use datafusion:: execution:: SessionState ;
1515use futures:: TryStreamExt ;
1616use prost:: Message ;
17+ use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
1718use std:: sync:: Arc ;
19+ use tokio:: sync:: OnceCell ;
1820use tonic:: metadata:: MetadataMap ;
1921use tonic:: { Request , Response , Status } ;
2022
@@ -37,6 +39,19 @@ pub struct DoGet {
3739 pub stage_key : Option < StageKey > ,
3840}
3941
42+ #[ derive( Clone , Debug ) ]
43+ /// TaskData stores state for a single task being executed by this Endpoint. It may be shared
44+ /// by concurrent requests for the same task which execute separate partitions.
45+ pub struct TaskData {
46+ pub ( super ) state : SessionState ,
47+ pub ( super ) stage : Arc < ExecutionStage > ,
48+ /// Initialized to the total number of partitions in the task (not only tasks in the partition
49+ /// group). This is decremented for each request to the endpoint for this task. Once this count
50+ /// is zero, the task is likely complete. This count does not account for retried requests
51+ /// for the same partition.
52+ approx_partitions_remaining : Arc < AtomicUsize > ,
53+ }
54+
4055impl ArrowFlightEndpoint {
4156 pub ( super ) async fn get (
4257 & self ,
@@ -50,7 +65,10 @@ impl ArrowFlightEndpoint {
5065
5166 let partition = doget. partition as usize ;
5267 let task_number = doget. task_number as usize ;
53- let ( mut state, stage) = self . get_state_and_stage ( doget, metadata) . await ?;
68+ let task_data = self . get_state_and_stage ( doget, metadata) . await ?;
69+
70+ let stage = task_data. stage ;
71+ let mut state = task_data. state ;
5472
5573 // find out which partition group we are executing
5674 let task = stage
@@ -90,16 +108,15 @@ impl ArrowFlightEndpoint {
90108 & self ,
91109 doget : DoGet ,
92110 metadata_map : MetadataMap ,
93- ) -> Result < ( SessionState , Arc < ExecutionStage > ) , Status > {
111+ ) -> Result < TaskData , Status > {
94112 let key = doget
95113 . stage_key
96114 . ok_or ( Status :: invalid_argument ( "DoGet is missing the stage key" ) ) ?;
97- let once_stage = {
98- let entry = self . stages . entry ( key) . or_default ( ) ;
99- Arc :: clone ( & entry)
100- } ;
115+ let once_stage = self
116+ . stages
117+ . get_or_init ( key. clone ( ) , || Arc :: new ( OnceCell :: < TaskData > :: new ( ) ) ) ;
101118
102- let ( state , stage ) = once_stage
119+ let stage_data = once_stage
103120 . get_or_try_init ( || async {
104121 let stage_proto = doget
105122 . stage_proto
@@ -134,10 +151,206 @@ impl ArrowFlightEndpoint {
134151 config. set_extension ( stage. clone ( ) ) ;
135152 config. set_extension ( Arc :: new ( ContextGrpcMetadata ( headers) ) ) ;
136153
137- Ok :: < _ , Status > ( ( state, stage) )
154+ // Initialize partition count to the number of partitions in the stage
155+ let total_partitions = stage. plan . properties ( ) . partitioning . partition_count ( ) ;
156+ Ok :: < _ , Status > ( TaskData {
157+ state,
158+ stage,
159+ approx_partitions_remaining : Arc :: new ( AtomicUsize :: new ( total_partitions) ) ,
160+ } )
138161 } )
139162 . await ?;
140163
141- Ok ( ( state. clone ( ) , stage. clone ( ) ) )
164+ // If all the partitions are done, remove the stage from the cache.
165+ let remaining_partitions = stage_data
166+ . approx_partitions_remaining
167+ . fetch_sub ( 1 , Ordering :: SeqCst ) ;
168+ if remaining_partitions <= 1 {
169+ self . stages . remove ( key. clone ( ) ) ;
170+ }
171+
172+ Ok ( stage_data. clone ( ) )
173+ }
174+ }
175+
176+ #[ cfg( test) ]
177+ mod tests {
178+ use super :: * ;
179+ use arrow_flight:: Ticket ;
180+ use prost:: { bytes:: Bytes , Message } ;
181+ use uuid:: Uuid ;
182+
183+ #[ tokio:: test]
184+ async fn test_task_data_partition_counting ( ) {
185+ use crate :: task:: ExecutionTask ;
186+ use arrow_flight:: Ticket ;
187+ use prost:: { bytes:: Bytes , Message } ;
188+ use tonic:: Request ;
189+ use url:: Url ;
190+
191+ // Create a mock channel resolver for ArrowFlightEndpoint
192+ #[ derive( Clone ) ]
193+ struct MockChannelResolver ;
194+
195+ #[ async_trait:: async_trait]
196+ impl crate :: ChannelResolver for MockChannelResolver {
197+ fn get_urls ( & self ) -> Result < Vec < Url > , datafusion:: error:: DataFusionError > {
198+ Ok ( vec ! [ ] )
199+ }
200+
201+ async fn get_channel_for_url (
202+ & self ,
203+ _url : & Url ,
204+ ) -> Result < crate :: BoxCloneSyncChannel , datafusion:: error:: DataFusionError >
205+ {
206+ Err ( datafusion:: error:: DataFusionError :: NotImplemented (
207+ "Mock resolver" . to_string ( ) ,
208+ ) )
209+ }
210+ }
211+
212+ // Create ArrowFlightEndpoint
213+ let endpoint =
214+ ArrowFlightEndpoint :: new ( MockChannelResolver ) . expect ( "Failed to create endpoint" ) ;
215+
216+ // Create 3 tasks with 3 partitions each.
217+ let num_tasks = 3 ;
218+ let num_partitions_per_task = 3 ;
219+ let stage_id = 1 ;
220+ let query_id_uuid = Uuid :: new_v4 ( ) ;
221+ let query_id = query_id_uuid. as_bytes ( ) . to_vec ( ) ;
222+
223+ // Set up protos.
224+ let mut tasks = Vec :: new ( ) ;
225+ for i in 0 ..num_tasks {
226+ tasks. push ( ExecutionTask {
227+ url_str : None ,
228+ partition_group : vec ! [ i] , // Set a random partition in the partition group.
229+ } ) ;
230+ }
231+
232+ let stage_proto = ExecutionStageProto {
233+ query_id : query_id. clone ( ) ,
234+ num : 1 ,
235+ name : format ! ( "test_stage_{}" , 1 ) ,
236+ plan : Some ( Box :: new ( create_mock_physical_plan_proto (
237+ num_partitions_per_task,
238+ ) ) ) ,
239+ inputs : vec ! [ ] ,
240+ tasks,
241+ } ;
242+
243+ let task_keys = vec ! [
244+ StageKey {
245+ query_id: query_id_uuid. to_string( ) ,
246+ stage_id,
247+ task_number: 0 ,
248+ } ,
249+ StageKey {
250+ query_id: query_id_uuid. to_string( ) ,
251+ stage_id,
252+ task_number: 1 ,
253+ } ,
254+ StageKey {
255+ query_id: query_id_uuid. to_string( ) ,
256+ stage_id,
257+ task_number: 2 ,
258+ } ,
259+ ] ;
260+
261+ let stage_proto_for_closure = stage_proto. clone ( ) ;
262+ let endpoint_ref = & endpoint;
263+ let do_get = async move |partition : u64 , task_number : u64 , stage_key : StageKey | {
264+ let stage_proto = stage_proto_for_closure. clone ( ) ;
265+ // Create DoGet message
266+ let doget = DoGet {
267+ stage_proto : Some ( stage_proto) ,
268+ task_number,
269+ partition,
270+ stage_key : Some ( stage_key) ,
271+ } ;
272+
273+ // Create Flight ticket
274+ let ticket = Ticket {
275+ ticket : Bytes :: from ( doget. encode_to_vec ( ) ) ,
276+ } ;
277+
278+ // Call the actual get() method
279+ let request = Request :: new ( ticket) ;
280+ endpoint_ref. get ( request) . await
281+ } ;
282+
283+ // For each task, call do_get() for each partition except the last.
284+ for task_number in 0 ..num_tasks {
285+ for partition in 0 ..num_partitions_per_task - 1 {
286+ let result = do_get (
287+ partition as u64 ,
288+ task_number,
289+ task_keys[ task_number as usize ] . clone ( ) ,
290+ )
291+ . await ;
292+ assert ! ( result. is_ok( ) ) ;
293+ }
294+ }
295+
296+ // Check that the endpoint has not evicted any task states.
297+ assert_eq ! ( endpoint. stages. len( ) , num_tasks as usize ) ;
298+
299+ // Run the last partition of task 0. Any partition number works. Verify that the task state
300+ // is evicted because all partitions have been processed.
301+ let result = do_get ( 1 , 0 , task_keys[ 0 ] . clone ( ) ) . await ;
302+ assert ! ( result. is_ok( ) ) ;
303+ let stored_stage_keys = endpoint. stages . keys ( ) . collect :: < Vec < StageKey > > ( ) ;
304+ assert_eq ! ( stored_stage_keys. len( ) , 2 ) ;
305+ assert ! ( stored_stage_keys. contains( & task_keys[ 1 ] ) ) ;
306+ assert ! ( stored_stage_keys. contains( & task_keys[ 2 ] ) ) ;
307+
308+ // Run the last partition of task 1.
309+ let result = do_get ( 1 , 1 , task_keys[ 1 ] . clone ( ) ) . await ;
310+ assert ! ( result. is_ok( ) ) ;
311+ let stored_stage_keys = endpoint. stages . keys ( ) . collect :: < Vec < StageKey > > ( ) ;
312+ assert_eq ! ( stored_stage_keys. len( ) , 1 ) ;
313+ assert ! ( stored_stage_keys. contains( & task_keys[ 2 ] ) ) ;
314+
315+ // Run the last partition of the last task.
316+ let result = do_get ( 1 , 2 , task_keys[ 2 ] . clone ( ) ) . await ;
317+ assert ! ( result. is_ok( ) ) ;
318+ let stored_stage_keys = endpoint. stages . keys ( ) . collect :: < Vec < StageKey > > ( ) ;
319+ assert_eq ! ( stored_stage_keys. len( ) , 0 ) ;
320+ }
321+
322+ // Helper to create a mock physical plan proto
323+ fn create_mock_physical_plan_proto (
324+ partitions : usize ,
325+ ) -> datafusion_proto:: protobuf:: PhysicalPlanNode {
326+ use datafusion_proto:: protobuf:: partitioning:: PartitionMethod ;
327+ use datafusion_proto:: protobuf:: {
328+ Partitioning , PhysicalPlanNode , RepartitionExecNode , Schema ,
329+ } ;
330+
331+ // Create a repartition node that will have the desired partition count
332+ PhysicalPlanNode {
333+ physical_plan_type : Some (
334+ datafusion_proto:: protobuf:: physical_plan_node:: PhysicalPlanType :: Repartition (
335+ Box :: new ( RepartitionExecNode {
336+ input : Some ( Box :: new ( PhysicalPlanNode {
337+ physical_plan_type : Some (
338+ datafusion_proto:: protobuf:: physical_plan_node:: PhysicalPlanType :: Empty (
339+ datafusion_proto:: protobuf:: EmptyExecNode {
340+ schema : Some ( Schema {
341+ columns : vec ! [ ] ,
342+ metadata : std:: collections:: HashMap :: new ( ) ,
343+ } )
344+ }
345+ )
346+ ) ,
347+ } ) ) ,
348+ partitioning : Some ( Partitioning {
349+ partition_method : Some ( PartitionMethod :: RoundRobin ( partitions as u64 ) ) ,
350+ } ) ,
351+ } )
352+ )
353+ ) ,
354+ }
142355 }
143356}
0 commit comments