@@ -14,7 +14,9 @@ use arrow_flight::Ticket;
1414use datafusion:: execution:: SessionState ;
1515use futures:: TryStreamExt ;
1616use prost:: Message ;
17+ use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
1718use std:: sync:: Arc ;
19+ use tokio:: sync:: OnceCell ;
1820use tonic:: metadata:: MetadataMap ;
1921use tonic:: { Request , Response , Status } ;
2022
@@ -37,6 +39,20 @@ pub struct DoGet {
3739 pub stage_key : Option < StageKey > ,
3840}
3941
/// TaskData stores state for a single task being executed by this Endpoint. It may be shared
/// by concurrent requests for the same task which execute separate partitions.
#[derive(Clone, Debug)]
pub struct TaskData {
    /// Session state under which this task's plan is executed.
    pub(super) state: SessionState,
    /// The decoded stage (plan + task layout) this task belongs to.
    pub(super) stage: Arc<ExecutionStage>,
    /// num_partitions_remaining is initialized to the total number of partitions in the task (not
    /// only the partitions in the partition group). It is decremented once for each request to
    /// the endpoint for this task. Once this count reaches zero, the task is *likely* complete.
    /// It may not actually be complete: a retried partition decrements the counter once per
    /// attempt, so the count can hit zero before every distinct partition has finished.
    num_partitions_remaining: Arc<AtomicUsize>,
}
55+
4056impl ArrowFlightEndpoint {
4157 pub ( super ) async fn get (
4258 & self ,
@@ -50,7 +66,10 @@ impl ArrowFlightEndpoint {
5066
5167 let partition = doget. partition as usize ;
5268 let task_number = doget. task_number as usize ;
53- let ( mut state, stage) = self . get_state_and_stage ( doget, metadata) . await ?;
69+ let task_data = self . get_state_and_stage ( doget, metadata) . await ?;
70+
71+ let stage = task_data. stage ;
72+ let mut state = task_data. state ;
5473
5574 // find out which partition group we are executing
5675 let task = stage
@@ -90,16 +109,15 @@ impl ArrowFlightEndpoint {
90109 & self ,
91110 doget : DoGet ,
92111 metadata_map : MetadataMap ,
93- ) -> Result < ( SessionState , Arc < ExecutionStage > ) , Status > {
112+ ) -> Result < TaskData , Status > {
94113 let key = doget
95114 . stage_key
96115 . ok_or ( Status :: invalid_argument ( "DoGet is missing the stage key" ) ) ?;
97- let once_stage = {
98- let entry = self . stages . entry ( key) . or_default ( ) ;
99- Arc :: clone ( & entry)
100- } ;
116+ let once_stage = self
117+ . stages
118+ . get_or_init ( key. clone ( ) , || Arc :: new ( OnceCell :: < TaskData > :: new ( ) ) ) ;
101119
102- let ( state , stage ) = once_stage
120+ let stage_data = once_stage
103121 . get_or_try_init ( || async {
104122 let stage_proto = doget
105123 . stage_proto
@@ -133,10 +151,183 @@ impl ArrowFlightEndpoint {
133151 config. set_extension ( stage. clone ( ) ) ;
134152 config. set_extension ( Arc :: new ( ContextGrpcMetadata ( headers) ) ) ;
135153
136- Ok :: < _ , Status > ( ( state, stage) )
154+ // Initialize partition count to the number of partitions in the stage
155+ let total_partitions = stage. plan . properties ( ) . partitioning . partition_count ( ) ;
156+ Ok :: < _ , Status > ( TaskData {
157+ state,
158+ stage,
159+ num_partitions_remaining : Arc :: new ( AtomicUsize :: new ( total_partitions) ) ,
160+ } )
137161 } )
138162 . await ?;
139163
140- Ok ( ( state. clone ( ) , stage. clone ( ) ) )
164+ // If all the partitions are done, remove the stage from the cache.
165+ let remaining_partitions = stage_data
166+ . num_partitions_remaining
167+ . fetch_sub ( 1 , Ordering :: SeqCst ) ;
168+ if remaining_partitions <= 1 {
169+ self . stages . remove ( key. clone ( ) ) ;
170+ }
171+
172+ Ok ( stage_data. clone ( ) )
173+ }
174+ }
175+
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Verifies the partition-counting eviction logic: a task's `TaskData` entry stays cached in
    /// `endpoint.stages` until one request has been served for each of its partitions, and is
    /// evicted exactly when the request for its last outstanding partition arrives.
    #[tokio::test]
    async fn test_task_data_partition_counting() {
        use crate::flight_service::session_builder::DefaultSessionBuilder;
        use crate::task::ExecutionTask;
        use arrow_flight::Ticket;
        use prost::{bytes::Bytes, Message};
        use tonic::Request;

        // Create ArrowFlightEndpoint with DefaultSessionBuilder
        let endpoint =
            ArrowFlightEndpoint::new(DefaultSessionBuilder).expect("Failed to create endpoint");

        // Create 3 tasks with 3 partitions each.
        let num_tasks = 3;
        let num_partitions_per_task = 3;
        let stage_id = 1;
        let query_id_uuid = Uuid::new_v4();
        let query_id = query_id_uuid.as_bytes().to_vec();

        // Set up protos.
        let mut tasks = Vec::new();
        for i in 0..num_tasks {
            tasks.push(ExecutionTask {
                url_str: None,
                // Put a single (arbitrary) partition in each task's partition group; eviction is
                // driven by the plan's total partition count, not by the group's contents.
                partition_group: vec![i],
            });
        }

        let stage_proto = ExecutionStageProto {
            query_id: query_id.clone(),
            num: 1,
            name: format!("test_stage_{}", 1),
            // The plan determines the partition count used to initialize the eviction counter.
            plan: Some(Box::new(create_mock_physical_plan_proto(
                num_partitions_per_task,
            ))),
            inputs: vec![],
            tasks,
        };

        // One stage key per task; these are the cache keys stored in `endpoint.stages`.
        let task_keys = vec![
            StageKey {
                query_id: query_id_uuid.to_string(),
                stage_id,
                task_number: 0,
            },
            StageKey {
                query_id: query_id_uuid.to_string(),
                stage_id,
                task_number: 1,
            },
            StageKey {
                query_id: query_id_uuid.to_string(),
                stage_id,
                task_number: 2,
            },
        ];

        let stage_proto_for_closure = stage_proto.clone();
        let endpoint_ref = &endpoint;
        // Issues one DoGet request for the given (partition, task_number, stage_key).
        let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| {
            let stage_proto = stage_proto_for_closure.clone();
            // Create DoGet message
            let doget = DoGet {
                stage_proto: Some(stage_proto),
                task_number,
                partition,
                stage_key: Some(stage_key),
            };

            // Create Flight ticket
            let ticket = Ticket {
                ticket: Bytes::from(doget.encode_to_vec()),
            };

            // Call the actual get() method
            let request = Request::new(ticket);
            endpoint_ref.get(request).await
        };

        // For each task, call do_get() for each partition except the last.
        for task_number in 0..num_tasks {
            for partition in 0..num_partitions_per_task - 1 {
                let result = do_get(
                    partition as u64,
                    task_number,
                    task_keys[task_number as usize].clone(),
                )
                .await;
                assert!(result.is_ok());
            }
        }

        // Check that the endpoint has not evicted any task states: no task has seen all of its
        // partitions yet.
        assert_eq!(endpoint.stages.len(), num_tasks as usize);

        // Run the last partition of task 0. Any partition number works (only the remaining count
        // matters, not which partition is requested). Verify that the task state is evicted
        // because all partitions have been processed.
        let result = do_get(1, 0, task_keys[0].clone()).await;
        assert!(result.is_ok());
        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
        assert_eq!(stored_stage_keys.len(), 2);
        assert!(stored_stage_keys.contains(&task_keys[1]));
        assert!(stored_stage_keys.contains(&task_keys[2]));

        // Run the last partition of task 1.
        let result = do_get(1, 1, task_keys[1].clone()).await;
        assert!(result.is_ok());
        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
        assert_eq!(stored_stage_keys.len(), 1);
        assert!(stored_stage_keys.contains(&task_keys[2]));

        // Run the last partition of the last task; the cache must now be empty.
        let result = do_get(1, 2, task_keys[2].clone()).await;
        assert!(result.is_ok());
        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
        assert_eq!(stored_stage_keys.len(), 0);
    }

    /// Helper to create a mock physical plan proto: a round-robin RepartitionExec over an empty
    /// input, so the decoded plan reports exactly `partitions` output partitions.
    fn create_mock_physical_plan_proto(
        partitions: usize,
    ) -> datafusion_proto::protobuf::PhysicalPlanNode {
        use datafusion_proto::protobuf::partitioning::PartitionMethod;
        use datafusion_proto::protobuf::{
            Partitioning, PhysicalPlanNode, RepartitionExecNode, Schema,
        };

        // Create a repartition node that will have the desired partition count
        PhysicalPlanNode {
            physical_plan_type: Some(
                datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Repartition(
                    Box::new(RepartitionExecNode {
                        // Empty, schema-less leaf input; only the partitioning below matters.
                        input: Some(Box::new(PhysicalPlanNode {
                            physical_plan_type: Some(
                                datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType::Empty(
                                    datafusion_proto::protobuf::EmptyExecNode {
                                        schema: Some(Schema {
                                            columns: vec![],
                                            metadata: std::collections::HashMap::new(),
                                        })
                                    }
                                )
                            ),
                        })),
                        partitioning: Some(Partitioning {
                            partition_method: Some(PartitionMethod::RoundRobin(partitions as u64)),
                        }),
                    })
                )
            ),
        }
    }
}
0 commit comments