1313use kyron_foundation:: prelude:: * ;
1414use kyron_foundation:: { not_recoverable_error, prelude:: CommonErrors } ;
1515
16+ use crate :: scheduler:: task:: task_context:: TaskContext ;
1617use crate :: {
1718 futures:: { FutureInternalReturn , FutureState } ,
1819 TaskRef ,
@@ -74,34 +75,52 @@ impl<T: Send + 'static> Future for JoinHandle<T> {
7475 if was_set {
7576 FutureInternalReturn :: default ( )
7677 } else {
78+ // Check whether there is safety error for the completed task and this task is running on async worker
79+ // if this task is already running on safety worker/dedicated worker, do not set the flag to schedule on safety worker.
80+ if self . for_task . get_task_safety_error ( ) && TaskContext :: is_task_running_on_async_worker ( ) {
81+ // Set the flag to wake this task into safety worker
82+ TaskContext :: set_flag_to_wake_parent_task_into_safety ( ) ;
83+ waker. wake_by_ref ( ) ;
84+ FutureInternalReturn :: polled ( )
85+ } else {
86+ let mut ret: Result < T , CommonErrors > = Err ( CommonErrors :: NoData ) ;
87+ let ret_as_ptr = & mut ret as * mut _ ;
88+ self . for_task . get_return_val ( ret_as_ptr as * mut u8 ) ;
89+
90+ match ret {
91+ Ok ( v) => FutureInternalReturn :: ready ( Ok ( v) ) ,
92+ Err ( CommonErrors :: OperationAborted ) => FutureInternalReturn :: ready ( Err ( CommonErrors :: OperationAborted ) ) ,
93+ Err ( e) => {
94+ not_recoverable_error ! ( with e, "There has been an error in a task that is not recoverable ({})!" ) ;
95+ }
96+ }
97+ }
98+ }
99+ }
100+ FutureState :: Polled => {
101+ let waker = cx. waker ( ) ;
102+
103+ // Set the waker, return values tells what have happen and took care about correct synchronization
104+ let was_set = self . for_task . set_join_handle_waker ( waker. clone ( ) ) ;
105+
106+ if was_set {
107+ FutureInternalReturn :: default ( )
108+ } else {
109+ // Safety belows forms AqrRel so waker is really written before we do marking
77110 let mut ret: Result < T , CommonErrors > = Err ( CommonErrors :: NoData ) ;
78111 let ret_as_ptr = & mut ret as * mut _ ;
79112 self . for_task . get_return_val ( ret_as_ptr as * mut u8 ) ;
80113
81114 match ret {
82115 Ok ( v) => FutureInternalReturn :: ready ( Ok ( v) ) ,
116+ Err ( CommonErrors :: NoData ) => FutureInternalReturn :: polled ( ) ,
83117 Err ( CommonErrors :: OperationAborted ) => FutureInternalReturn :: ready ( Err ( CommonErrors :: OperationAborted ) ) ,
84118 Err ( e) => {
85119 not_recoverable_error ! ( with e, "There has been an error in a task that is not recoverable ({})!" ) ;
86120 }
87121 }
88122 }
89123 }
90- FutureState :: Polled => {
91- // Safety belows forms AqrRel so waker is really written before we do marking
92- let mut ret: Result < T , CommonErrors > = Err ( CommonErrors :: NoData ) ;
93- let ret_as_ptr = & mut ret as * mut _ ;
94- self . for_task . get_return_val ( ret_as_ptr as * mut u8 ) ;
95-
96- match ret {
97- Ok ( v) => FutureInternalReturn :: ready ( Ok ( v) ) ,
98- Err ( CommonErrors :: NoData ) => FutureInternalReturn :: polled ( ) ,
99- Err ( CommonErrors :: OperationAborted ) => FutureInternalReturn :: ready ( Err ( CommonErrors :: OperationAborted ) ) ,
100- Err ( e) => {
101- not_recoverable_error ! ( with e, "There has been an error in a task that is not recoverable ({})!" ) ;
102- }
103- }
104- }
105124 FutureState :: Finished => {
106125 not_recoverable_error ! ( "Future polled after it finished!" ) ;
107126 }
@@ -256,6 +275,40 @@ mod tests {
256275 assert_eq ! ( poller. poll( ) , :: core:: task:: Poll :: Ready ( Ok ( 0 ) ) ) ;
257276 }
258277 }
278+
279+ #[ test]
280+ fn test_join_handle_waker_is_set_in_polled_state_also ( ) {
281+ let scheduler = create_mock_scheduler ( ) ;
282+
283+ {
284+ // Data is present before first poll of join handle
285+ let task = ArcInternal :: new ( AsyncTask :: new ( box_future ( test_function :: < u32 > ( ) ) , 1 , scheduler. clone ( ) ) ) ;
286+
287+ let handle = JoinHandle :: < u32 > :: new ( TaskRef :: new ( task. clone ( ) ) ) ;
288+
289+ let mut poller = TestingFuturePoller :: new ( handle) ;
290+
291+ let waker_mock1 = TrackableWaker :: new ( ) ;
292+ let waker1 = waker_mock1. get_waker ( ) ;
293+
294+ let waker_mock2 = TrackableWaker :: new ( ) ;
295+ let waker2 = waker_mock2. get_waker ( ) ;
296+
297+ let _ = poller. poll_with_waker ( & waker1) ;
298+ // Now in polled state, poll again with waker2
299+ let _ = poller. poll_with_waker ( & waker2) ;
300+ {
301+ let waker = noop_waker ( ) ;
302+ let mut cx = Context :: from_waker ( & waker) ;
303+ task. poll ( & mut cx) ; // task done
304+ }
305+
306+ assert ! ( !waker_mock1. was_waked( ) ) ;
307+ // this should be TRUE
308+ assert ! ( waker_mock2. was_waked( ) ) ;
309+ assert_eq ! ( poller. poll( ) , :: core:: task:: Poll :: Ready ( Ok ( 0 ) ) ) ;
310+ }
311+ }
259312}
260313
261314#[ cfg( test) ]
@@ -277,8 +330,9 @@ mod tests {
277330
278331 #[ test]
279332 fn test_join_handler_mt_get_result ( ) {
280- let builder = Builder :: new ( ) ;
281-
333+ let mut builder = Builder :: new ( ) ;
334+ // Limit preemption to avoid loom error "Model exceeded maximum number of branches."
335+ builder. preemption_bound = Some ( 4 ) ;
282336 builder. check ( || {
283337 let scheduler = create_mock_scheduler ( ) ;
284338
@@ -299,22 +353,17 @@ mod tests {
299353
300354 let waker_mock = TrackableWaker :: new ( ) ;
301355 let waker = waker_mock. get_waker ( ) ;
302- let mut was_pending = false ;
303-
304356 loop {
305357 match poller. poll_with_waker ( & waker) {
306358 Poll :: Ready ( v) => {
307359 assert_eq ! ( v, Ok ( 1234 ) ) ;
308-
309- if was_pending {
310- assert ! ( waker_mock. was_waked( ) ) ;
311- }
360+ // Note:
361+ // Cannot check whether the waker was woken or not since the waker is set in the join handle poll every time if task is not yet done.
362+ // So depending on the interleaving, the task may finish before the waker is set.
312363
313364 break ;
314365 }
315- Poll :: Pending => {
316- was_pending = true ;
317- }
366+ Poll :: Pending => { }
318367 }
319368 loom:: hint:: spin_loop ( ) ;
320369 }
0 commit comments