@@ -39,7 +39,7 @@ pub struct DoGet {
     pub stage_key: Option<StageKey>,
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct TaskData {
     pub(super) state: SessionState,
     pub(super) stage: Arc<ExecutionStage>,
@@ -108,9 +108,9 @@ impl ArrowFlightEndpoint {
             .ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
         let once_stage = self
             .stages
-            .get_or_init(key.clone(), || OnceCell::<TaskData>::new());
+            .get_or_init(key.clone(), || Arc::new(OnceCell::<TaskData>::new()));
 
-        let mut stage_data = once_stage
+        let stage_data = once_stage
             .get_or_try_init(|| async {
                 let stage_proto = doget
                     .stage_proto
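
Aside: wrapping the cell in an `Arc` lets a clone of it be handed out of the stage map while `get_or_try_init` still deduplicates concurrent initialization. A minimal sketch of that pattern, assuming `tokio::sync::OnceCell` (implied by the async initializer in this diff); `expensive_init` is a hypothetical stand-in for building `TaskData`:

    use std::sync::Arc;
    use tokio::sync::OnceCell;

    // Hypothetical stand-in for the real initialization work.
    async fn expensive_init() -> Result<u64, String> {
        Ok(42)
    }

    #[tokio::main]
    async fn main() -> Result<(), String> {
        let cell: Arc<OnceCell<u64>> = Arc::new(OnceCell::new());
        let cell2 = Arc::clone(&cell); // e.g. a second request holding the same entry

        // Only one initializer future runs at a time; once a value is stored,
        // every caller (through any clone of the Arc) observes that same value.
        let v = cell.get_or_try_init(|| async { expensive_init().await }).await?;
        let w = cell2.get_or_try_init(|| async { expensive_init().await }).await?;
        assert_eq!((*v, *w), (42, 42));
        Ok(())
    }
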
@@ -145,19 +145,203 @@ impl ArrowFlightEndpoint {
                 config.set_extension(stage.clone());
                 config.set_extension(Arc::new(ContextGrpcMetadata(headers)));
 
+                // Initialize the partition count to the number of partitions in the stage.
+                let total_partitions = stage.plan.properties().partitioning.partition_count();
                 Ok::<_, Status>(TaskData {
                     state,
                     stage,
-                    partition_count: Arc::new(AtomicUsize::new(0)),
+                    partition_count: Arc::new(AtomicUsize::new(total_partitions)),
                 })
             })
             .await?;
 
-        stage_data.partition_count.fetch_sub(1, Ordering::SeqCst);
-        if stage_data.partition_count.load(Ordering::SeqCst) <= 0 {
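+        // Note: fetch_sub returns the counter's value *before* the decrement,
+        // so seeing 1 here means this call served the last outstanding partition.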
+        let remaining_partitions = stage_data.partition_count.fetch_sub(1, Ordering::SeqCst);
+        if remaining_partitions <= 1 {
             self.stages.remove(key.clone());
         }
 
         Ok(stage_data.clone())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow_flight::Ticket;
+    use prost::{bytes::Bytes, Message};
+    use uuid::Uuid;
+
+    #[tokio::test]
+    async fn test_task_data_partition_counting() {
+        use crate::task::ExecutionTask;
+        use tonic::Request;
+        use url::Url;
+
+        // Create a mock channel resolver for the ArrowFlightEndpoint.
+        #[derive(Clone)]
+        struct MockChannelResolver;
+
+        #[async_trait::async_trait]
+        impl crate::ChannelResolver for MockChannelResolver {
+            fn get_urls(&self) -> Result<Vec<Url>, datafusion::error::DataFusionError> {
+                Ok(vec![])
+            }
+
+            async fn get_channel_for_url(
+                &self,
+                _url: &Url,
+            ) -> Result<crate::BoxCloneSyncChannel, datafusion::error::DataFusionError> {
+                Err(datafusion::error::DataFusionError::NotImplemented(
+                    "Mock resolver".to_string(),
+                ))
+            }
+        }
+
+        // Create the ArrowFlightEndpoint under test.
+        let endpoint =
+            ArrowFlightEndpoint::new(MockChannelResolver).expect("Failed to create endpoint");
+
+        // Test parameters: one stage with 3 tasks, each task running 3 partitions.
+        let num_tasks = 3;
+        let num_partitions_per_task = 3;
+        let stage_id = 1;
+        let query_id_uuid = Uuid::new_v4();
+        let query_id = query_id_uuid.as_bytes().to_vec();
+
+        // Create an ExecutionStage proto with multiple tasks.
+        let mut tasks = Vec::new();
+        for i in 0..num_tasks {
+            tasks.push(ExecutionTask {
+                url_str: None,
+                partition_group: vec![i], // A different partition group for each task.
+            });
+        }
+
+        let stage_proto = ExecutionStageProto {
+            query_id: query_id.clone(),
+            num: 1,
+            name: format!("test_stage_{}", 1),
+            plan: Some(Box::new(create_mock_physical_plan_proto(
+                num_partitions_per_task,
+            ))),
+            inputs: vec![],
+            tasks,
+        };
+
+        let task_keys = vec![
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 0,
+            },
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 1,
+            },
+            StageKey {
+                query_id: query_id_uuid.to_string(),
+                stage_id,
+                task_number: 2,
+            },
+        ];
+
+        let stage_proto_for_closure = stage_proto.clone();
+        let endpoint_ref = &endpoint;
+        let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| {
+            let stage_proto = stage_proto_for_closure.clone();
+            // Create the DoGet message.
+            let doget = DoGet {
+                stage_proto: Some(stage_proto),
+                task_number,
+                partition,
+                stage_key: Some(stage_key),
+            };
+
+            // Create the Flight ticket.
+            let ticket = Ticket {
+                ticket: Bytes::from(doget.encode_to_vec()),
+            };
+
+            // Call the actual get() method.
+            let request = Request::new(ticket);
+            endpoint_ref.get(request).await
+        };
+
+        // For each task, call do_get() for every partition except the last.
+        for task_number in 0..num_tasks {
+            for partition in 0..num_partitions_per_task - 1 {
+                let result = do_get(
+                    partition as u64,
+                    task_number,
+                    task_keys[task_number as usize].clone(),
+                )
+                .await;
+                assert!(result.is_ok());
+            }
+        }
+
+        // Check that the endpoint has not evicted any task state yet.
+        assert_eq!(endpoint.stages.len(), num_tasks as usize);
+
+        // Run the last partition of task 0. Any partition number works. Verify that
+        // the task state is evicted because all of its partitions have been processed.
+        let result = do_get(1, 0, task_keys[0].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 2);
+        assert!(stored_stage_keys.contains(&task_keys[1]));
+        assert!(stored_stage_keys.contains(&task_keys[2]));
+
+        // Run the last partition of task 1.
+        let result = do_get(1, 1, task_keys[1].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 1);
+        assert!(stored_stage_keys.contains(&task_keys[2]));
+
+        // Run the last partition of the last task.
+        let result = do_get(1, 2, task_keys[2].clone()).await;
+        assert!(result.is_ok());
+        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        assert_eq!(stored_stage_keys.len(), 0);
+    }
+
+    // Helper to create a mock physical plan proto with a fixed partition count.
+    fn create_mock_physical_plan_proto(
+        partitions: usize,
+    ) -> datafusion_proto::protobuf::PhysicalPlanNode {
+        use datafusion_proto::protobuf::partitioning::PartitionMethod;
+        use datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType;
+        use datafusion_proto::protobuf::{
+            EmptyExecNode, Partitioning, PhysicalPlanNode, RepartitionExecNode, Schema,
+        };
+
+        // A repartition node wrapping an empty leaf; its round-robin partitioning
+        // yields the desired partition count.
+        PhysicalPlanNode {
+            physical_plan_type: Some(PhysicalPlanType::Repartition(Box::new(
+                RepartitionExecNode {
+                    input: Some(Box::new(PhysicalPlanNode {
+                        physical_plan_type: Some(PhysicalPlanType::Empty(EmptyExecNode {
+                            schema: Some(Schema {
+                                columns: vec![],
+                                metadata: std::collections::HashMap::new(),
+                            }),
+                        })),
+                    })),
+                    partitioning: Some(Partitioning {
+                        partition_method: Some(PartitionMethod::RoundRobin(partitions as u64)),
+                    }),
+                },
+            ))),
+        }
+    }
+}
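
For reference, the eviction rule above reduces to an atomic countdown: seed the counter with the stage's partition count, decrement once per served partition, and evict when the pre-decrement value reaches 1. A self-contained sketch of just that pattern; `TaskSlot` and `finish_partition` are illustrative names, not part of this PR:

    use std::sync::atomic::{AtomicUsize, Ordering};

    struct TaskSlot {
        partition_count: AtomicUsize,
    }

    /// Returns true when the caller consumed the last outstanding partition.
    fn finish_partition(slot: &TaskSlot) -> bool {
        // fetch_sub returns the value held *before* the decrement; `<= 1`
        // mirrors the endpoint's check and also tolerates stray extra calls.
        slot.partition_count.fetch_sub(1, Ordering::SeqCst) <= 1
    }

    fn main() {
        let slot = TaskSlot {
            partition_count: AtomicUsize::new(3),
        };
        assert!(!finish_partition(&slot)); // two partitions still outstanding
        assert!(!finish_partition(&slot)); // one partition still outstanding
        assert!(finish_partition(&slot)); // last partition: evict the entry
    }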