@@ -14,7 +14,9 @@ use arrow_flight::Ticket;
 use datafusion::execution::SessionState;
 use futures::TryStreamExt;
 use prost::Message;
+use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
+use tokio::sync::OnceCell;
 use tonic::metadata::MetadataMap;
 use tonic::{Request, Response, Status};
 
@@ -37,6 +39,13 @@ pub struct DoGet {
     pub stage_key: Option<StageKey>,
 }
 
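+/// State shared by all partitions of one task: every partition reuses the same
+/// `SessionState` and `ExecutionStage`, while `partition_count` tracks how many
+/// partitions have not yet been served so the entry can be evicted afterwards.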
+#[derive(Clone, Debug)]
+pub struct TaskData {
+    pub(super) state: SessionState,
+    pub(super) stage: Arc<ExecutionStage>,
+    partition_count: Arc<AtomicUsize>,
+}
+
 impl ArrowFlightEndpoint {
     pub(super) async fn get(
         &self,
@@ -50,7 +59,10 @@ impl ArrowFlightEndpoint {
 
         let partition = doget.partition as usize;
         let task_number = doget.task_number as usize;
-        let (mut state, stage) = self.get_state_and_stage(doget, metadata).await?;
+        let task_data = self.get_state_and_stage(doget, metadata).await?;
+
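+        // The returned TaskData is a clone, so this partition owns its copies
+        // of the session state and the stage.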
+        let stage = task_data.stage;
+        let mut state = task_data.state;
 
         // find out which partition group we are executing
         let task = stage
@@ -90,16 +102,15 @@ impl ArrowFlightEndpoint {
         &self,
         doget: DoGet,
         metadata_map: MetadataMap,
-    ) -> Result<(SessionState, Arc<ExecutionStage>), Status> {
+    ) -> Result<TaskData, Status> {
         let key = doget
             .stage_key
             .ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
-        let once_stage = {
-            let entry = self.stages.entry(key).or_default();
-            Arc::clone(&entry)
-        };
+        let once_stage = self
+            .stages
+            .get_or_init(key.clone(), || Arc::new(OnceCell::<TaskData>::new()));
 
-        let (state, stage) = once_stage
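+        // `get_or_try_init` runs the initializer at most once per task; all
+        // other partitions of the task await it and share the same `TaskData`.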
+        let stage_data = once_stage
             .get_or_try_init(|| async {
                 let stage_proto = doget
                     .stage_proto
@@ -134,10 +145,203 @@ impl ArrowFlightEndpoint {
                 config.set_extension(stage.clone());
                 config.set_extension(Arc::new(ContextGrpcMetadata(headers)));
 
-                Ok::<_, Status>((state, stage))
+                // Initialize the partition count to the number of partitions
+                // in the stage.
+                let total_partitions = stage.plan.properties().partitioning.partition_count();
+                Ok::<_, Status>(TaskData {
+                    state,
+                    stage,
+                    partition_count: Arc::new(AtomicUsize::new(total_partitions)),
+                })
             })
             .await?;
 
-        Ok((state.clone(), stage.clone()))
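+        // `fetch_sub` returns the previous count: a value of 1 means this call
+        // is serving the last outstanding partition, so the entry can be dropped.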
+        let remaining_partitions = stage_data.partition_count.fetch_sub(1, Ordering::SeqCst);
+        if remaining_partitions <= 1 {
+            self.stages.remove(key.clone());
+        }
+
+        Ok(stage_data.clone())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow_flight::Ticket;
+    use prost::{bytes::Bytes, Message};
+    use uuid::Uuid;
+
+    #[tokio::test]
+    async fn test_task_data_partition_counting() {
+        use crate::task::ExecutionTask;
+        use tonic::Request;
+        use url::Url;
+
+        // Create a mock channel resolver for ArrowFlightEndpoint
+        #[derive(Clone)]
+        struct MockChannelResolver;
+
+        #[async_trait::async_trait]
+        impl crate::ChannelResolver for MockChannelResolver {
+            fn get_urls(&self) -> Result<Vec<Url>, datafusion::error::DataFusionError> {
+                Ok(vec![])
+            }
+
+            async fn get_channel_for_url(
+                &self,
+                _url: &Url,
+            ) -> Result<crate::BoxCloneSyncChannel, datafusion::error::DataFusionError> {
+                Err(datafusion::error::DataFusionError::NotImplemented(
+                    "Mock resolver".to_string(),
+                ))
+            }
+        }
+
+        // Create the ArrowFlightEndpoint under test
+        let endpoint =
+            ArrowFlightEndpoint::new(MockChannelResolver).expect("Failed to create endpoint");
+
+        // Test parameters: one stage with 3 tasks, each task has 3 partitions
+        let num_tasks = 3;
+        let num_partitions_per_task = 3;
+        let stage_id = 1;
+        let query_id_uuid = Uuid::new_v4();
+        let query_id = query_id_uuid.as_bytes().to_vec();
+
+        // Create an ExecutionStage proto with multiple tasks
+        let mut tasks = Vec::new();
+        for i in 0..num_tasks {
+            tasks.push(ExecutionTask {
+                url_str: None,
+                partition_group: vec![i], // Different partition group for each task
+            });
+        }
+
+        let stage_proto = ExecutionStageProto {
+            query_id: query_id.clone(),
+            num: 1,
+            name: format!("test_stage_{}", 1),
+            plan: Some(Box::new(create_mock_physical_plan_proto(
+                num_partitions_per_task,
+            ))),
+            inputs: vec![],
+            tasks,
+        };
+
+        let task_keys = vec![
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 0,
+            },
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 1,
+            },
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 2,
+            },
+        ];
+
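+        // Helper that drives a single partition of a single task through the
+        // endpoint's get() method, the same way a Flight client would.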
+        let stage_proto_for_closure = stage_proto.clone();
+        let endpoint_ref = &endpoint;
+        let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| {
+            let stage_proto = stage_proto_for_closure.clone();
+            // Create the DoGet message
+            let doget = DoGet {
+                stage_proto: Some(stage_proto),
+                task_number,
+                partition,
+                stage_key: Some(stage_key),
+            };
+
+            // Create the Flight ticket
+            let ticket = Ticket {
+                ticket: Bytes::from(doget.encode_to_vec()),
+            };
+
+            // Call the actual get() method
+            let request = Request::new(ticket);
+            endpoint_ref.get(request).await
+        };
+
+        // For each task, call do_get() for every partition except the last.
+        for task_number in 0..num_tasks {
+            for partition in 0..num_partitions_per_task - 1 {
+                let result = do_get(
+                    partition as u64,
+                    task_number,
+                    task_keys[task_number as usize].clone(),
+                )
+                .await;
+                assert!(result.is_ok());
+            }
+        }
+
+        // Check that the endpoint has not evicted any task state yet.
+        assert_eq!(endpoint.stages.len(), num_tasks as usize);
+
+        // Run the last partition of task 0. Any partition number works. Verify
+        // that the task state is evicted because all partitions have been served.
+        let result = do_get(1, 0, task_keys[0].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 2);
+        assert!(stored_stage_keys.contains(&task_keys[1]));
+        assert!(stored_stage_keys.contains(&task_keys[2]));
+
+        // Run the last partition of task 1.
+        let result = do_get(1, 1, task_keys[1].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 1);
+        assert!(stored_stage_keys.contains(&task_keys[2]));
+
+        // Run the last partition of the last task.
+        let result = do_get(1, 2, task_keys[2].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 0);
+    }
+
+    // Helper to create a mock physical plan proto with the given partition count
+    fn create_mock_physical_plan_proto(
+        partitions: usize,
+    ) -> datafusion_proto::protobuf::PhysicalPlanNode {
+        use datafusion_proto::protobuf::partitioning::PartitionMethod;
+        use datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType;
+        use datafusion_proto::protobuf::{
+            EmptyExecNode, Partitioning, PhysicalPlanNode, RepartitionExecNode, Schema,
+        };
+
+        // A repartition node over an empty source yields the desired partition count
+        PhysicalPlanNode {
+            physical_plan_type: Some(PhysicalPlanType::Repartition(Box::new(
+                RepartitionExecNode {
+                    input: Some(Box::new(PhysicalPlanNode {
+                        physical_plan_type: Some(PhysicalPlanType::Empty(EmptyExecNode {
+                            schema: Some(Schema {
+                                columns: vec![],
+                                metadata: std::collections::HashMap::new(),
+                            }),
+                        })),
+                    })),
+                    partitioning: Some(Partitioning {
+                        partition_method: Some(PartitionMethod::RoundRobin(partitions as u64)),
+                    }),
+                },
+            ))),
+        }
     }
 }