@@ -69,7 +69,7 @@ use std::any::Any;
6969use std:: collections:: VecDeque ;
7070use std:: marker:: PhantomData ;
7171use std:: panic:: { AssertUnwindSafe , catch_unwind, resume_unwind} ;
72- use std:: sync:: atomic:: { AtomicU32 , Ordering } ;
72+ use std:: sync:: atomic:: { AtomicBool , AtomicU32 , AtomicUsize , Ordering } ;
7373use std:: sync:: mpsc:: { SyncSender , TrySendError } ;
7474use std:: sync:: { Arc , Condvar , Mutex } ;
7575
@@ -87,25 +87,19 @@ where
8787 F : for < ' scope > FnOnce ( & ' scope Scope < ' scope , ' env > ) -> T ,
8888{
8989 let scope = Scope {
90- data : Arc :: new ( ScopeData {
91- mutex : Mutex :: new ( ScopeCounts :: default ( ) ) ,
92- condvar : Condvar :: new ( ) ,
93- } ) ,
90+ running : Arc :: default ( ) ,
91+ panicked : Arc :: default ( ) ,
9492 _scope : PhantomData ,
9593 _env : PhantomData ,
9694 } ;
9795
9896 let result = catch_unwind ( AssertUnwindSafe ( || f ( & scope) ) ) ;
9997
100- // Wait for tasks to finish
101- let mut guard = scope. data . mutex . lock ( ) . unwrap ( ) ;
102- while guard. running > 0 {
103- guard = scope. data . condvar . wait ( guard) . unwrap ( ) ;
104- }
98+ scope. running . wait_for_tasks ( ) ;
10599
106100 match result {
107101 Err ( e) => resume_unwind ( e) ,
108- Ok ( _) if guard . unhandled_panics > 0 => panic ! ( "scoped task panicked" ) ,
102+ Ok ( _) if scope . panicked . did_panic ( ) => panic ! ( "scoped task panicked" ) ,
109103 Ok ( x) => x,
110104 }
111105}
@@ -124,7 +118,8 @@ where
124118#[ derive( Debug ) ]
125119#[ expect( clippy:: struct_field_names) ]
126120pub struct Scope < ' scope , ' env : ' scope > {
127- data : Arc < ScopeData > ,
121+ running : Arc < ScopeRunning > ,
122+ panicked : Arc < ScopePanicked > ,
128123 // &'scope mut &'scope is needed to prevent lifetimes from shrinking
129124 _scope : PhantomData < & ' scope mut & ' scope ( ) > ,
130125 _env : PhantomData < & ' env mut & ' env ( ) > ,
@@ -162,7 +157,7 @@ impl<'scope> Scope<'scope, '_> {
162157 Some ( handle)
163158 } else {
164159 // Closure will never be run
165- self . data . task_end ( ) ;
160+ self . running . task_finished ( ) ;
166161
167162 None
168163 }
@@ -205,109 +200,144 @@ impl<'scope> Scope<'scope, '_> {
205200 F : FnOnce ( ) -> T + Send + ' scope ,
206201 T : Send + ' scope ,
207202 {
208- self . data . task_start ( ) ;
203+ self . running . task_created ( ) ;
209204
210205 let handle = ScopedJoinHandle {
211- data : Arc :: new ( HandleData {
206+ data : Arc :: new ( TaskResult {
212207 mutex : Mutex :: new ( None ) ,
213208 condvar : Condvar :: new ( ) ,
209+ scope_panicked : self . panicked . clone ( ) ,
214210 } ) ,
215- scope_data : self . data . clone ( ) ,
216211 _scope : PhantomData ,
217212 } ;
218213
219- let handle_data = handle. data . clone ( ) ;
220- let scope_data = self . data . clone ( ) ;
221- let closure: Box < dyn FnOnce ( ) + Send + ' scope > = Box :: new (
222- #[ inline( never) ]
223- move || {
224- let result = catch_unwind ( AssertUnwindSafe ( f) ) ;
225-
226- if result. is_err ( ) {
227- // Updating the panic count must happen before updating the handle data, to
228- // avoid the handle being joined in another thread which then tries to decrement
229- // the unhandled panic count before it is incremented
230- scope_data. task_panicked ( ) ;
231- }
232-
233- // Send the result to ScopedJoinHandle and wake any blocked threads
234- let HandleData { mutex, condvar } = handle_data. as_ref ( ) ;
235- let mut guard = mutex. lock ( ) . unwrap ( ) ;
236- * guard = Some ( result) ;
237- condvar. notify_all ( ) ;
238- } ,
239- ) ;
214+ let task_result = handle. data . clone ( ) ;
215+ let scope_running = self . running . clone ( ) ;
216+ let closure: Box < dyn FnOnce ( ) + Send + ' scope > = Box :: new ( move || {
217+ task_result. store ( catch_unwind ( AssertUnwindSafe ( f) ) ) ;
218+
219+ // If the JoinHandle has already been dropped, this will drop the TaskResult inside the
220+ // Arc, dropping the result and storing if an unhandled panic occurred.
221+ drop ( task_result) ;
222+
223+ // Mark the task as finished after all the borrows from the environment are dropped.
224+ scope_running. task_finished ( ) ;
225+ } ) ;
240226
241227 // SAFETY: The `scope` function ensures all closures are finished before returning
242228 let closure = unsafe {
243229 #[ expect( clippy:: unnecessary_cast, reason = "casting lifetimes" ) ]
244230 Box :: from_raw ( Box :: into_raw ( closure) as * mut ( dyn FnOnce ( ) + Send + ' static ) )
245231 } ;
246232
247- let scope_data = self . data . clone ( ) ;
248- let task_closure = Box :: new ( move || {
249- // Use a second closure to ensure that the closure which borrows from 'scope is dropped
250- // before `ScopeData::task_end` is called. This prevents `scope()` from returning while
251- // the inner closure still exists, which causes UB as detected by Miri.
252- closure ( ) ;
253- scope_data. task_end ( ) ;
254- } ) ;
255-
256- ( task_closure, handle)
233+ ( closure, handle)
257234 }
258235}
259236
260- // Stores the number of currently running tasks and unhandled panics.
261- #[ derive( Debug ) ]
262- struct ScopeData {
263- mutex : Mutex < ScopeCounts > ,
264- condvar : Condvar ,
237+ /// Stores the number of currently running tasks.
238+ #[ derive( Debug , Default ) ]
239+ struct ScopeRunning {
240+ counter : AtomicUsize ,
241+ wait_mutex : Mutex < ( ) > ,
242+ wait_condvar : Condvar ,
265243}
266244
245+ impl ScopeRunning {
246+ fn task_created ( & self ) {
247+ self . counter . fetch_add ( 1 , Ordering :: AcqRel ) ;
248+ }
249+
250+ fn task_finished ( & self ) {
251+ let prev = self . counter . fetch_sub ( 1 , Ordering :: AcqRel ) ;
252+ if prev == 1 {
253+ self . wait_condvar . notify_all ( ) ;
254+ } else if prev == 0 {
255+ panic ! ( "more tasks finished than started?" )
256+ }
257+ }
258+
259+ fn wait_for_tasks ( & self ) {
260+ let mut guard = self . wait_mutex . lock ( ) . unwrap ( ) ;
261+ while self . counter . load ( Ordering :: Acquire ) > 0 {
262+ guard = self . wait_condvar . wait ( guard) . unwrap ( ) ;
263+ }
264+ }
265+ }
266+
267+ /// Stores whether any of the tasks panicked.
267268#[ derive( Debug , Default ) ]
268- struct ScopeCounts {
269- running : usize ,
270- unhandled_panics : usize ,
269+ struct ScopePanicked {
270+ value : AtomicBool ,
271+ }
272+
273+ impl ScopePanicked {
274+ fn store_panic ( & self ) {
275+ self . value . store ( true , Ordering :: Release ) ;
276+ }
277+
278+ fn did_panic ( & self ) -> bool {
279+ self . value . load ( Ordering :: Acquire )
280+ }
271281}
272282
273- impl ScopeData {
274- fn task_start ( & self ) {
283+ /// Stores the result of a task, ensuring the result is dropped safely and [`ScopePanicked`] is
284+ /// updated.
285+ #[ derive( Debug ) ]
286+ struct TaskResult < T > {
287+ mutex : Mutex < Option < Result < T , Box < dyn Any + Send + ' static > > > > ,
288+ condvar : Condvar ,
289+ scope_panicked : Arc < ScopePanicked > ,
290+ }
291+
292+ impl < T > TaskResult < T > {
293+ fn store ( & self , result : Result < T , Box < dyn Any + Send + ' static > > ) {
275294 let mut guard = self . mutex . lock ( ) . unwrap ( ) ;
276- if let Some ( new_running) = guard. running . checked_add ( 1 ) {
277- guard. running = new_running;
278- } else {
279- panic ! ( "too many running tasks in scope" ) ;
280- }
295+ * guard = Some ( result) ;
296+ self . condvar . notify_all ( ) ;
281297 }
282298
283- fn task_end ( & self ) {
299+ fn wait_and_take ( & self ) -> Result < T , Box < dyn Any + Send + ' static > > {
284300 let mut guard = self . mutex . lock ( ) . unwrap ( ) ;
285- if let Some ( new_running) = guard. running . checked_sub ( 1 ) {
286- guard. running = new_running;
287- if new_running == 0 {
288- self . condvar . notify_all ( ) ;
301+ loop {
302+ if let Some ( result) = guard. take ( ) {
303+ return result;
289304 }
290- } else {
291- panic ! ( "more tasks finished than started?" )
305+ guard = self . condvar . wait ( guard) . unwrap ( ) ;
292306 }
293307 }
294308
295- fn task_panicked ( & self ) {
296- let mut guard = self . mutex . lock ( ) . unwrap ( ) ;
297- if let Some ( new_panicked) = guard. unhandled_panics . checked_add ( 1 ) {
298- guard. unhandled_panics = new_panicked;
299- } else {
300- panic ! ( "too many panicking tasks in scope" ) ;
301- }
309+ fn is_finished ( & self ) -> bool {
310+ self . mutex . lock ( ) . unwrap ( ) . is_some ( )
302311 }
312+ }
303313
304- fn panic_joined ( & self ) {
305- let mut guard = self . mutex . lock ( ) . unwrap ( ) ;
306- if let Some ( new_panicked) = guard. unhandled_panics . checked_sub ( 1 ) {
307- guard. unhandled_panics = new_panicked;
308- } else {
309- panic ! ( "more panics joined than tasks panicked?" )
314+ impl < T > Drop for TaskResult < T > {
315+ #[ expect( clippy:: print_stderr) ]
316+ fn drop ( & mut self ) {
317+ let Some ( result) = self
318+ . mutex
319+ . get_mut ( )
320+ . expect ( "worker panicked while storing result" )
321+ . take ( )
322+ else {
323+ return ; // Result was already taken and handled
324+ } ;
325+
326+ let panic;
327+ match result {
328+ Ok ( v) => match catch_unwind ( AssertUnwindSafe ( || drop ( v) ) ) {
329+ Ok ( ( ) ) => return ,
330+ Err ( e) => panic = e,
331+ } ,
332+ Err ( e) => panic = e,
333+ }
334+
335+ if let Err ( _panic) = catch_unwind ( AssertUnwindSafe ( || drop ( panic) ) ) {
336+ eprintln ! ( "panic while dropping scoped task panic" ) ;
337+ std:: process:: abort ( ) ;
310338 }
339+
340+ self . scope_panicked . store_panic ( ) ;
311341 }
312342}
313343
@@ -318,44 +348,22 @@ impl ScopeData {
318348/// threads.
319349#[ derive( Debug ) ]
320350pub struct ScopedJoinHandle < ' scope , T > {
321- data : Arc < HandleData < T > > ,
322- scope_data : Arc < ScopeData > ,
351+ data : Arc < TaskResult < T > > ,
323352 _scope : PhantomData < & ' scope mut & ' scope ( ) > ,
324353}
325354
326- #[ derive( Debug ) ]
327- struct HandleData < T > {
328- mutex : Mutex < Option < Result < T , Box < dyn Any + Send + ' static > > > > ,
329- condvar : Condvar ,
330- }
331-
332355impl < T > ScopedJoinHandle < ' _ , T > {
333356 /// Wait for the task to finish.
334357 ///
335358 /// The [`Err`] variant contains the panic value if the task panicked.
336359 pub fn join ( self ) -> Result < T , Box < dyn Any + Send + ' static > > {
337- let result = {
338- let HandleData { mutex, condvar } = self . data . as_ref ( ) ;
339- let mut guard = mutex. lock ( ) . unwrap ( ) ;
340- loop {
341- if let Some ( result) = guard. take ( ) {
342- break result;
343- }
344- guard = condvar. wait ( guard) . unwrap ( ) ;
345- }
346- } ;
347-
348- if result. is_err ( ) {
349- self . scope_data . panic_joined ( ) ;
350- }
351-
352- result
360+ self . data . wait_and_take ( )
353361 }
354362
355363 /// Check if the task is finished.
356364 #[ must_use]
357365 pub fn is_finished ( & self ) -> bool {
358- self . data . mutex . lock ( ) . unwrap ( ) . is_some ( )
366+ self . data . is_finished ( )
359367 }
360368}
361369
@@ -386,7 +394,6 @@ fn try_queue_task(mut closure: Box<dyn FnOnce() + Send>) -> Result<(), Box<dyn F
386394 }
387395 }
388396 }
389- drop ( guard) ;
390397
391398 Err ( closure)
392399}
0 commit comments