1313
1414use super :: task_state:: * ;
1515use crate :: core:: types:: * ;
16- use crate :: scheduler:: safety_waker:: create_safety_waker;
1716use crate :: scheduler:: scheduler_mt:: SchedulerTrait ;
1817use crate :: scheduler:: workers:: worker_types:: WorkerId ;
18+ use :: core:: cell:: Cell ;
1919use :: core:: future:: Future ;
2020use :: core:: mem;
2121use :: core:: ops:: { Deref , DerefMut } ;
@@ -82,7 +82,7 @@ pub(crate) enum TaskStage<T, ResultType> {
8282pub ( crate ) struct TaskHeader {
8383 pub ( in crate :: scheduler) state : TaskState ,
8484 id : TaskId ,
85-
85+ is_safety_error : Cell < bool > , // Flag to indicate whether task resulted in safety error
8686 vtable : & ' static TaskVTable , // API entrypoint to typed task
8787}
8888
@@ -97,6 +97,7 @@ impl TaskHeader {
9797 Self {
9898 state : TaskState :: new ( ) ,
9999 id : TaskId :: new ( worker_id) ,
100+ is_safety_error : Cell :: new ( false ) ,
100101 vtable : create_task_vtable :: < T , AllocatedFuture , SchedulerType > ( ) ,
101102 }
102103 }
@@ -111,9 +112,18 @@ impl TaskHeader {
111112 Self {
112113 state : TaskState :: new ( ) ,
113114 id : TaskId :: new ( worker_id) ,
115+ is_safety_error : Cell :: new ( false ) ,
114116 vtable : create_task_s_vtable :: < T , E , AllocatedFuture , SchedulerType > ( ) ,
115117 }
116118 }
119+
120+ pub ( crate ) fn set_safety_error ( & self ) {
121+ self . is_safety_error . set ( true ) ;
122+ }
123+
124+ pub ( crate ) fn get_safety_error ( & self ) -> bool {
125+ self . is_safety_error . get ( )
126+ }
117127}
118128
119129#[ derive( PartialEq , Debug ) ]
@@ -203,13 +213,20 @@ where
203213 }
204214
205215 pub ( crate ) fn set_waker ( & self , waker : Waker ) -> bool {
206- unsafe {
207- self . handle_waker . with_mut ( |ptr| {
208- * ptr = Some ( waker) ;
209- } )
216+ // Safety: Unset join handle flag before setting waker, the flag would have been set previously in the first poll of join handle.
217+ // If flag is not cleared, another worker finishing the task will see the flag set and
218+ // read the waker to call wake() while it is written here
219+ if self . header . state . unset_join_handle ( ) {
220+ unsafe {
221+ self . handle_waker . with_mut ( |ptr| {
222+ * ptr = Some ( waker) ;
223+ } )
224+ } ;
225+
226+ return self . header . state . set_join_handle ( ) ;
210227 }
211228
212- self . header . state . set_waker ( ) // Safety: makes sure storing waker is not reordered behind this operation
229+ false
213230 }
214231
215232 ///
@@ -308,12 +325,10 @@ where
308325 self . handle_waker . with_mut ( |ptr : * mut Option < Waker > | match unsafe { ( * ptr) . take ( ) } {
309326 Some ( v) => {
310327 if is_safety_err && self . is_with_safety {
311- unsafe {
312- create_safety_waker ( v) . wake ( ) ;
313- }
314- } else {
315- v. wake ( ) ;
328+ // Set saftey error flag which will be checked in wake()/wkae_by_ref() to schedule parent task into safety worker
329+ self . header . set_safety_error ( ) ;
316330 }
331+ v. wake ( ) ;
317332 }
318333 None => not_recoverable_error ! ( "We shall never be here if we have HadConnectedJoinHandle set!" ) ,
319334 } )
@@ -614,6 +629,10 @@ impl TaskRef {
614629 snapshot. is_completed ( ) || snapshot. is_canceled ( )
615630 }
616631
632+ pub ( crate ) fn get_task_safety_error ( & self ) -> bool {
633+ unsafe { self . header . as_ref ( ) . get_safety_error ( ) }
634+ }
635+
617636 pub ( crate ) fn id ( & self ) -> TaskId {
618637 unsafe { self . header . as_ref ( ) . id }
619638 }
@@ -641,7 +660,10 @@ mod tests {
641660 safety:: SafetyResult ,
642661 scheduler:: {
643662 scheduler_mt:: SchedulerTrait ,
644- task:: async_task:: { TaskId , TaskRef } ,
663+ task:: {
664+ async_task:: { TaskId , TaskRef } ,
665+ task_context:: TaskContextGuard ,
666+ } ,
645667 } ,
646668 testing:: * ,
647669 } ;
@@ -778,6 +800,7 @@ mod tests {
778800 safety_task_ref. set_join_handle_waker ( waker. clone ( ) ) ; // Mimic that JoinHandler is set
779801
780802 let mut ctx = Context :: from_waker ( & waker) ;
803+ let _guard = TaskContextGuard :: new ( safety_task_ref. clone ( ) ) ;
781804 assert_eq ! ( safety_task_ref. poll( & mut ctx) , TaskPollResult :: Done ) ;
782805
783806 let mut result: SafetyResult < bool , bool > = Ok ( true ) ;
@@ -801,6 +824,7 @@ mod tests {
801824 safety_task_ref. set_join_handle_waker ( waker. clone ( ) ) ; // Mimic that JoinHandler is set
802825
803826 let mut ctx = Context :: from_waker ( & waker) ;
827+ let _guard = TaskContextGuard :: new ( safety_task_ref. clone ( ) ) ;
804828 assert_eq ! ( safety_task_ref. poll( & mut ctx) , TaskPollResult :: Done ) ;
805829
806830 let mut result: SafetyResult < bool , bool > = Ok ( true ) ;
@@ -829,6 +853,7 @@ mod tests {
829853 safety_task_ref. set_join_handle_waker ( waker. clone ( ) ) ; // Mimic that JoinHandler is set
830854
831855 let mut ctx = Context :: from_waker ( & waker) ;
856+ let _guard = TaskContextGuard :: new ( safety_task_ref. clone ( ) ) ;
832857 assert_eq ! ( safety_task_ref. poll( & mut ctx) , TaskPollResult :: Done ) ;
833858
834859 let mut result: SafetyResult < bool , bool > = Ok ( true ) ;
@@ -857,6 +882,7 @@ mod tests {
857882 safety_task_ref. set_join_handle_waker ( waker. clone ( ) ) ; // Mimic that JoinHandler is set
858883
859884 let mut ctx = Context :: from_waker ( & waker) ;
885+ let _guard = TaskContextGuard :: new ( safety_task_ref. clone ( ) ) ;
860886 assert_eq ! ( safety_task_ref. poll( & mut ctx) , TaskPollResult :: Done ) ;
861887
862888 let mut result: SafetyResult < bool , bool > = Ok ( true ) ;
0 commit comments