@@ -14,7 +14,9 @@ use arrow_flight::Ticket;
1414use datafusion:: execution:: SessionState ;
1515use futures:: TryStreamExt ;
1616use prost:: Message ;
17+ use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
1718use std:: sync:: Arc ;
19+ use tokio:: sync:: OnceCell ;
1820use tonic:: metadata:: MetadataMap ;
1921use tonic:: { Request , Response , Status } ;
2022
@@ -37,6 +39,18 @@ pub struct DoGet {
3739 pub stage_key : Option < StageKey > ,
3840}
3941
42+ #[ derive( Clone , Debug ) ]
43+ /// TaskData stores state for a single task being executed by this Endpoint. It may be shared
44+ /// by concurrent requests for the same task which execute separate partitions.
45+ pub struct TaskData {
46+ pub ( super ) state : SessionState ,
47+ pub ( super ) stage : Arc < ExecutionStage > ,
48+ /// Initialized to the total number of partitions in the task (not only tasks in the partition
49+ /// group). This is decremented for each request to the endpoint for this task. Once this count
50+ /// is zero, the task is likely complete.
51+ approx_partitions_remaining : Arc < AtomicUsize > ,
52+ }
53+
4054impl ArrowFlightEndpoint {
4155 pub ( super ) async fn get (
4256 & self ,
@@ -50,7 +64,10 @@ impl ArrowFlightEndpoint {
5064
5165 let partition = doget. partition as usize ;
5266 let task_number = doget. task_number as usize ;
53- let ( mut state, stage) = self . get_state_and_stage ( doget, metadata) . await ?;
67+ let task_data = self . get_state_and_stage ( doget, metadata) . await ?;
68+
69+ let stage = task_data. stage ;
70+ let mut state = task_data. state ;
5471
5572 // find out which partition group we are executing
5673 let task = stage
@@ -90,16 +107,15 @@ impl ArrowFlightEndpoint {
90107 & self ,
91108 doget : DoGet ,
92109 metadata_map : MetadataMap ,
93- ) -> Result < ( SessionState , Arc < ExecutionStage > ) , Status > {
110+ ) -> Result < TaskData , Status > {
94111 let key = doget
95112 . stage_key
96113 . ok_or ( Status :: invalid_argument ( "DoGet is missing the stage key" ) ) ?;
97- let once_stage = {
98- let entry = self . stages . entry ( key) . or_default ( ) ;
99- Arc :: clone ( & entry)
100- } ;
114+ let once_stage = self
115+ . stages
116+ . get_or_init ( key. clone ( ) , || Arc :: new ( OnceCell :: < TaskData > :: new ( ) ) ) ;
101117
102- let ( state , stage ) = once_stage
118+ let stage_data = once_stage
103119 . get_or_try_init ( || async {
104120 let stage_proto = doget
105121 . stage_proto
@@ -134,10 +150,206 @@ impl ArrowFlightEndpoint {
134150 config. set_extension ( stage. clone ( ) ) ;
135151 config. set_extension ( Arc :: new ( ContextGrpcMetadata ( headers) ) ) ;
136152
137- Ok :: < _ , Status > ( ( state, stage) )
153+ // Initialize partition count to the number of partitions in the stage
154+ let total_partitions = stage. plan . properties ( ) . partitioning . partition_count ( ) ;
155+ Ok :: < _ , Status > ( TaskData {
156+ state,
157+ stage,
158+ approx_partitions_remaining : Arc :: new ( AtomicUsize :: new ( total_partitions) ) ,
159+ } )
138160 } )
139161 . await ?;
140162
141- Ok ( ( state. clone ( ) , stage. clone ( ) ) )
163+ // If all the partitions are done, remove the stage from the cache.
164+ let remaining_partitions = stage_data
165+ . approx_partitions_remaining
166+ . fetch_sub ( 1 , Ordering :: SeqCst ) ;
167+ if remaining_partitions <= 1 {
168+ self . stages . remove ( key. clone ( ) ) ;
169+ }
170+
171+ Ok ( stage_data. clone ( ) )
172+ }
173+ }
174+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::ExecutionTask;
    use arrow_flight::Ticket;
    use prost::{bytes::Bytes, Message};
    use tonic::Request;
    use url::Url;
    use uuid::Uuid;

    /// Verifies the partition-counting eviction in `get_state_and_stage`: a task's cached
    /// `TaskData` stays in `endpoint.stages` until every one of its partitions has been
    /// requested once, and is removed exactly when the last partition request arrives.
    #[tokio::test]
    async fn test_task_data_partition_counting() {
        // Channel resolver stub for ArrowFlightEndpoint: this test never resolves a channel,
        // so both methods are inert.
        #[derive(Clone)]
        struct MockChannelResolver;

        #[async_trait::async_trait]
        impl crate::ChannelResolver for MockChannelResolver {
            fn get_urls(&self) -> Result<Vec<Url>, datafusion::error::DataFusionError> {
                Ok(vec![])
            }

            async fn get_channel_for_url(
                &self,
                _url: &Url,
            ) -> Result<crate::BoxCloneSyncChannel, datafusion::error::DataFusionError> {
                Err(datafusion::error::DataFusionError::NotImplemented(
                    "Mock resolver".to_string(),
                ))
            }
        }

        // Create ArrowFlightEndpoint
        let endpoint =
            ArrowFlightEndpoint::new(MockChannelResolver).expect("Failed to create endpoint");

        // Create 3 tasks with 3 partitions each.
        let num_tasks = 3;
        let num_partitions_per_task = 3;
        let stage_id = 1;
        let query_id_uuid = Uuid::new_v4();
        let query_id = query_id_uuid.as_bytes().to_vec();

        // Set up protos.
        let mut tasks = Vec::new();
        for i in 0..num_tasks {
            tasks.push(ExecutionTask {
                url_str: None,
                partition_group: vec![i], // Set a random partition in the partition group.
            });
        }

        let stage_proto = ExecutionStageProto {
            query_id: query_id.clone(),
            num: 1,
            name: format!("test_stage_{}", 1),
            plan: Some(Box::new(create_mock_physical_plan_proto(
                num_partitions_per_task,
            ))),
            inputs: vec![],
            tasks,
        };

        // One StageKey per task; they differ only in task_number.
        let task_keys: Vec<StageKey> = (0..num_tasks)
            .map(|task_number| StageKey {
                query_id: query_id_uuid.to_string(),
                stage_id,
                task_number,
            })
            .collect();

        let stage_proto_for_closure = stage_proto.clone();
        let endpoint_ref = &endpoint;
        let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| {
            let stage_proto = stage_proto_for_closure.clone();
            // Create DoGet message
            let doget = DoGet {
                stage_proto: Some(stage_proto),
                task_number,
                partition,
                stage_key: Some(stage_key),
            };

            // Create Flight ticket
            let ticket = Ticket {
                ticket: Bytes::from(doget.encode_to_vec()),
            };

            // Call the actual get() method
            let request = Request::new(ticket);
            endpoint_ref.get(request).await
        };

        // For each task, call do_get() for each partition except the last.
        for task_number in 0..num_tasks {
            for partition in 0..num_partitions_per_task - 1 {
                let result = do_get(
                    partition as u64,
                    task_number,
                    task_keys[task_number as usize].clone(),
                )
                .await;
                assert!(result.is_ok());
            }
        }

        // Check that the endpoint has not evicted any task states.
        assert_eq!(endpoint.stages.len(), num_tasks as usize);

        // Run the last partition of task 0. Any partition number works. Verify that the task state
        // is evicted because all partitions have been processed.
        let result = do_get(1, 0, task_keys[0].clone()).await;
        assert!(result.is_ok());
        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
        assert_eq!(stored_stage_keys.len(), 2);
        assert!(stored_stage_keys.contains(&task_keys[1]));
        assert!(stored_stage_keys.contains(&task_keys[2]));

        // Run the last partition of task 1.
        let result = do_get(1, 1, task_keys[1].clone()).await;
        assert!(result.is_ok());
        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
        assert_eq!(stored_stage_keys.len(), 1);
        assert!(stored_stage_keys.contains(&task_keys[2]));

        // Run the last partition of the last task.
        let result = do_get(1, 2, task_keys[2].clone()).await;
        assert!(result.is_ok());
        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
        assert_eq!(stored_stage_keys.len(), 0);
    }

    /// Builds a physical-plan proto whose root is a round-robin RepartitionExec over an empty
    /// input, so the decoded stage reports exactly `partitions` output partitions.
    fn create_mock_physical_plan_proto(
        partitions: usize,
    ) -> datafusion_proto::protobuf::PhysicalPlanNode {
        use datafusion_proto::protobuf::partitioning::PartitionMethod;
        use datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType;
        use datafusion_proto::protobuf::{
            EmptyExecNode, Partitioning, PhysicalPlanNode, RepartitionExecNode, Schema,
        };

        PhysicalPlanNode {
            physical_plan_type: Some(PhysicalPlanType::Repartition(Box::new(
                RepartitionExecNode {
                    input: Some(Box::new(PhysicalPlanNode {
                        physical_plan_type: Some(PhysicalPlanType::Empty(EmptyExecNode {
                            schema: Some(Schema {
                                columns: vec![],
                                metadata: std::collections::HashMap::new(),
                            }),
                        })),
                    })),
                    partitioning: Some(Partitioning {
                        partition_method: Some(PartitionMethod::RoundRobin(partitions as u64)),
                    }),
                },
            ))),
        }
    }
}
0 commit comments