@@ -14,6 +14,7 @@ use arrow_flight::Ticket;
1414use datafusion:: execution:: SessionState ;
1515use futures:: TryStreamExt ;
1616use prost:: Message ;
17+ use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
1718use std:: sync:: Arc ;
1819use tokio:: sync:: OnceCell ;
1920use tonic:: metadata:: MetadataMap ;
@@ -38,6 +39,13 @@ pub struct DoGet {
3839 pub stage_key : Option < StageKey > ,
3940}
4041
42+ #[ derive( Clone ) ]
43+ pub struct TaskData {
44+ pub ( super ) state : SessionState ,
45+ pub ( super ) stage : Arc < ExecutionStage > ,
46+ partition_count : Arc < AtomicUsize > ,
47+ }
48+
4149impl ArrowFlightEndpoint {
4250 pub ( super ) async fn get (
4351 & self ,
@@ -51,7 +59,10 @@ impl ArrowFlightEndpoint {
5159
5260 let partition = doget. partition as usize ;
5361 let task_number = doget. task_number as usize ;
54- let ( mut state, stage) = self . get_state_and_stage ( doget, metadata) . await ?;
62+ let task_data = self . get_state_and_stage ( doget, metadata) . await ?;
63+
64+ let stage = task_data. stage ;
65+ let mut state = task_data. state ;
5566
5667 // find out which partition group we are executing
5768 let task = stage
@@ -91,15 +102,15 @@ impl ArrowFlightEndpoint {
91102 & self ,
92103 doget : DoGet ,
93104 metadata_map : MetadataMap ,
94- ) -> Result < ( SessionState , Arc < ExecutionStage > ) , Status > {
105+ ) -> Result < TaskData , Status > {
95106 let key = doget
96107 . stage_key
97108 . ok_or ( Status :: invalid_argument ( "DoGet is missing the stage key" ) ) ?;
98- let once_stage = self . stages . get_or_init ( key , || {
99- OnceCell :: < ( SessionState , Arc < ExecutionStage > ) > :: new ( )
100- } ) ;
109+ let once_stage = self
110+ . stages
111+ . get_or_init ( key . clone ( ) , || OnceCell :: < TaskData > :: new ( ) ) ;
101112
102- let stage_data = once_stage
113+ let mut stage_data = once_stage
103114 . get_or_try_init ( || async {
104115 let stage_proto = doget
105116 . stage_proto
@@ -134,10 +145,19 @@ impl ArrowFlightEndpoint {
134145 config. set_extension ( stage. clone ( ) ) ;
135146 config. set_extension ( Arc :: new ( ContextGrpcMetadata ( headers) ) ) ;
136147
137- Ok :: < _ , Status > ( ( state, stage) )
148+ Ok :: < _ , Status > ( TaskData {
149+ state,
150+ stage,
151+ partition_count : Arc :: new ( AtomicUsize :: new ( 0 ) ) ,
152+ } )
138153 } )
139154 . await ?;
140155
156+ stage_data. partition_count . fetch_sub ( 1 , Ordering :: SeqCst ) ;
157+ if stage_data. partition_count . load ( Ordering :: SeqCst ) <= 0 {
158+ self . stages . remove ( key. clone ( ) ) ;
159+ }
160+
141161 Ok ( stage_data. clone ( ) )
142162 }
143163}
0 commit comments