Skip to content

Commit 64b0a38

Browse files
committed
Improve scoped tasks panic on drop handling
1 parent 5d1ac9d commit 64b0a38

File tree

1 file changed

+115
-108
lines changed

1 file changed

+115
-108
lines changed

crates/utils/src/multithreading/scoped_tasks.rs

Lines changed: 115 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ use std::any::Any;
6969
use std::collections::VecDeque;
7070
use std::marker::PhantomData;
7171
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
72-
use std::sync::atomic::{AtomicU32, Ordering};
72+
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
7373
use std::sync::mpsc::{SyncSender, TrySendError};
7474
use std::sync::{Arc, Condvar, Mutex};
7575

@@ -87,25 +87,19 @@ where
8787
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
8888
{
8989
let scope = Scope {
90-
data: Arc::new(ScopeData {
91-
mutex: Mutex::new(ScopeCounts::default()),
92-
condvar: Condvar::new(),
93-
}),
90+
running: Arc::default(),
91+
panicked: Arc::default(),
9492
_scope: PhantomData,
9593
_env: PhantomData,
9694
};
9795

9896
let result = catch_unwind(AssertUnwindSafe(|| f(&scope)));
9997

100-
// Wait for tasks to finish
101-
let mut guard = scope.data.mutex.lock().unwrap();
102-
while guard.running > 0 {
103-
guard = scope.data.condvar.wait(guard).unwrap();
104-
}
98+
scope.running.wait_for_tasks();
10599

106100
match result {
107101
Err(e) => resume_unwind(e),
108-
Ok(_) if guard.unhandled_panics > 0 => panic!("scoped task panicked"),
102+
Ok(_) if scope.panicked.did_panic() => panic!("scoped task panicked"),
109103
Ok(x) => x,
110104
}
111105
}
@@ -124,7 +118,8 @@ where
124118
#[derive(Debug)]
125119
#[expect(clippy::struct_field_names)]
126120
pub struct Scope<'scope, 'env: 'scope> {
127-
data: Arc<ScopeData>,
121+
running: Arc<ScopeRunning>,
122+
panicked: Arc<ScopePanicked>,
128123
// &'scope mut &'scope is needed to prevent lifetimes from shrinking
129124
_scope: PhantomData<&'scope mut &'scope ()>,
130125
_env: PhantomData<&'env mut &'env ()>,
@@ -162,7 +157,7 @@ impl<'scope> Scope<'scope, '_> {
162157
Some(handle)
163158
} else {
164159
// Closure will never be run
165-
self.data.task_end();
160+
self.running.task_finished();
166161

167162
None
168163
}
@@ -205,109 +200,144 @@ impl<'scope> Scope<'scope, '_> {
205200
F: FnOnce() -> T + Send + 'scope,
206201
T: Send + 'scope,
207202
{
208-
self.data.task_start();
203+
self.running.task_created();
209204

210205
let handle = ScopedJoinHandle {
211-
data: Arc::new(HandleData {
206+
data: Arc::new(TaskResult {
212207
mutex: Mutex::new(None),
213208
condvar: Condvar::new(),
209+
scope_panicked: self.panicked.clone(),
214210
}),
215-
scope_data: self.data.clone(),
216211
_scope: PhantomData,
217212
};
218213

219-
let handle_data = handle.data.clone();
220-
let scope_data = self.data.clone();
221-
let closure: Box<dyn FnOnce() + Send + 'scope> = Box::new(
222-
#[inline(never)]
223-
move || {
224-
let result = catch_unwind(AssertUnwindSafe(f));
225-
226-
if result.is_err() {
227-
// Updating the panic count must happen before updating the handle data, to
228-
// avoid the handle being joined in another thread which then tries to decrement
229-
// the unhandled panic count before it is incremented
230-
scope_data.task_panicked();
231-
}
232-
233-
// Send the result to ScopedJoinHandle and wake any blocked threads
234-
let HandleData { mutex, condvar } = handle_data.as_ref();
235-
let mut guard = mutex.lock().unwrap();
236-
*guard = Some(result);
237-
condvar.notify_all();
238-
},
239-
);
214+
let task_result = handle.data.clone();
215+
let scope_running = self.running.clone();
216+
let closure: Box<dyn FnOnce() + Send + 'scope> = Box::new(move || {
217+
task_result.store(catch_unwind(AssertUnwindSafe(f)));
218+
219+
// If the JoinHandle has already been dropped, this will drop the TaskResult inside the
220+
// Arc, dropping the result and storing if an unhandled panic occurred.
221+
drop(task_result);
222+
223+
// Mark the task as finished after all the borrows from the environment are dropped.
224+
scope_running.task_finished();
225+
});
240226

241227
// SAFETY: The `scope` function ensures all closures are finished before returning
242228
let closure = unsafe {
243229
#[expect(clippy::unnecessary_cast, reason = "casting lifetimes")]
244230
Box::from_raw(Box::into_raw(closure) as *mut (dyn FnOnce() + Send + 'static))
245231
};
246232

247-
let scope_data = self.data.clone();
248-
let task_closure = Box::new(move || {
249-
// Use a second closure to ensure that the closure which borrows from 'scope is dropped
250-
// before `ScopeData::task_end` is called. This prevents `scope()` from returning while
251-
// the inner closure still exists, which causes UB as detected by Miri.
252-
closure();
253-
scope_data.task_end();
254-
});
255-
256-
(task_closure, handle)
233+
(closure, handle)
257234
}
258235
}
259236

260-
// Stores the number of currently running tasks and unhandled panics.
261-
#[derive(Debug)]
262-
struct ScopeData {
263-
mutex: Mutex<ScopeCounts>,
264-
condvar: Condvar,
237+
/// Stores the number of currently running tasks.
///
/// The count lives in an atomic so that creating a task never takes a lock. The mutex/condvar
/// pair exists only so that [`ScopeRunning::wait_for_tasks`] can block until the count reaches
/// zero.
#[derive(Debug, Default)]
struct ScopeRunning {
    /// Number of tasks that have been created but have not yet finished.
    counter: AtomicUsize,
    /// Protects no data of its own; it serialises the "check counter, then sleep" step of the
    /// waiter against the "counter hit zero, now notify" step of the last finishing task.
    wait_mutex: Mutex<()>,
    /// Notified when `counter` drops to zero.
    wait_condvar: Condvar,
}

impl ScopeRunning {
    /// Registers one more running task.
    fn task_created(&self) {
        self.counter.fetch_add(1, Ordering::AcqRel);
    }

    /// Marks one task as finished, waking any waiter if it was the last one.
    ///
    /// # Panics
    ///
    /// Panics if no task is currently registered. The counter is left at zero in that case
    /// instead of wrapping around.
    fn task_finished(&self) {
        // A checked decrement keeps the counter intact when the invariant is violated; a plain
        // `fetch_sub` would wrap it to `usize::MAX` before the panic is raised.
        let Ok(prev) = self
            .counter
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |running| {
                running.checked_sub(1)
            })
        else {
            panic!("more tasks finished than started?")
        };

        if prev == 1 {
            // The mutex must be held (or at least acquired once) between the decrement and the
            // notification. Otherwise the notification can fire after the waiter has observed a
            // non-zero counter but before it has started waiting on the condvar, and the wakeup
            // is lost, leaving `wait_for_tasks` blocked forever.
            let _guard = self.wait_mutex.lock().unwrap();
            self.wait_condvar.notify_all();
        }
    }

    /// Blocks the calling thread until no tasks are running.
    ///
    /// Returns immediately if the counter is already zero.
    fn wait_for_tasks(&self) {
        let mut guard = self.wait_mutex.lock().unwrap();
        // The counter is re-checked after every wakeup to tolerate spurious wakeups.
        while self.counter.load(Ordering::Acquire) > 0 {
            guard = self.wait_condvar.wait(guard).unwrap();
        }
    }
}
266+
267+
/// Stores whether any of the tasks panicked.
267268
#[derive(Debug, Default)]
268-
struct ScopeCounts {
269-
running: usize,
270-
unhandled_panics: usize,
269+
struct ScopePanicked {
    /// `true` once a panic has been recorded; nothing in this impl ever clears it again.
    value: AtomicBool,
}

impl ScopePanicked {
    /// Records that a task panic went unhandled.
    ///
    /// Calling this more than once has no additional effect.
    fn store_panic(&self) {
        let flag = &self.value;
        flag.store(true, Ordering::Release);
    }

    /// Returns `true` if [`ScopePanicked::store_panic`] has been called.
    fn did_panic(&self) -> bool {
        let flag = &self.value;
        flag.load(Ordering::Acquire)
    }
}
272282

273-
impl ScopeData {
274-
fn task_start(&self) {
283+
/// Stores the result of a task, ensuring the result is dropped safely and [`ScopePanicked`] is
284+
/// updated.
285+
#[derive(Debug)]
286+
struct TaskResult<T> {
287+
mutex: Mutex<Option<Result<T, Box<dyn Any + Send + 'static>>>>,
288+
condvar: Condvar,
289+
scope_panicked: Arc<ScopePanicked>,
290+
}
291+
292+
impl<T> TaskResult<T> {
293+
fn store(&self, result: Result<T, Box<dyn Any + Send + 'static>>) {
275294
let mut guard = self.mutex.lock().unwrap();
276-
if let Some(new_running) = guard.running.checked_add(1) {
277-
guard.running = new_running;
278-
} else {
279-
panic!("too many running tasks in scope");
280-
}
295+
*guard = Some(result);
296+
self.condvar.notify_all();
281297
}
282298

283-
fn task_end(&self) {
299+
fn wait_and_take(&self) -> Result<T, Box<dyn Any + Send + 'static>> {
284300
let mut guard = self.mutex.lock().unwrap();
285-
if let Some(new_running) = guard.running.checked_sub(1) {
286-
guard.running = new_running;
287-
if new_running == 0 {
288-
self.condvar.notify_all();
301+
loop {
302+
if let Some(result) = guard.take() {
303+
return result;
289304
}
290-
} else {
291-
panic!("more tasks finished than started?")
305+
guard = self.condvar.wait(guard).unwrap();
292306
}
293307
}
294308

295-
fn task_panicked(&self) {
296-
let mut guard = self.mutex.lock().unwrap();
297-
if let Some(new_panicked) = guard.unhandled_panics.checked_add(1) {
298-
guard.unhandled_panics = new_panicked;
299-
} else {
300-
panic!("too many panicking tasks in scope");
301-
}
309+
fn is_finished(&self) -> bool {
310+
self.mutex.lock().unwrap().is_some()
302311
}
312+
}
303313

304-
fn panic_joined(&self) {
305-
let mut guard = self.mutex.lock().unwrap();
306-
if let Some(new_panicked) = guard.unhandled_panics.checked_sub(1) {
307-
guard.unhandled_panics = new_panicked;
308-
} else {
309-
panic!("more panics joined than tasks panicked?")
314+
impl<T> Drop for TaskResult<T> {
315+
#[expect(clippy::print_stderr)]
316+
fn drop(&mut self) {
317+
let Some(result) = self
318+
.mutex
319+
.get_mut()
320+
.expect("worker panicked while storing result")
321+
.take()
322+
else {
323+
return; // Result was already taken and handled
324+
};
325+
326+
let panic;
327+
match result {
328+
Ok(v) => match catch_unwind(AssertUnwindSafe(|| drop(v))) {
329+
Ok(()) => return,
330+
Err(e) => panic = e,
331+
},
332+
Err(e) => panic = e,
333+
}
334+
335+
if let Err(_panic) = catch_unwind(AssertUnwindSafe(|| drop(panic))) {
336+
eprintln!("panic while dropping scoped task panic");
337+
std::process::abort();
310338
}
339+
340+
self.scope_panicked.store_panic();
311341
}
312342
}
313343

@@ -318,44 +348,22 @@ impl ScopeData {
318348
/// threads.
319349
#[derive(Debug)]
320350
pub struct ScopedJoinHandle<'scope, T> {
321-
data: Arc<HandleData<T>>,
322-
scope_data: Arc<ScopeData>,
351+
data: Arc<TaskResult<T>>,
323352
_scope: PhantomData<&'scope mut &'scope ()>,
324353
}
325354

326-
#[derive(Debug)]
327-
struct HandleData<T> {
328-
mutex: Mutex<Option<Result<T, Box<dyn Any + Send + 'static>>>>,
329-
condvar: Condvar,
330-
}
331-
332355
impl<T> ScopedJoinHandle<'_, T> {
333356
/// Wait for the task to finish.
334357
///
335358
/// The [`Err`] variant contains the panic value if the task panicked.
336359
pub fn join(self) -> Result<T, Box<dyn Any + Send + 'static>> {
337-
let result = {
338-
let HandleData { mutex, condvar } = self.data.as_ref();
339-
let mut guard = mutex.lock().unwrap();
340-
loop {
341-
if let Some(result) = guard.take() {
342-
break result;
343-
}
344-
guard = condvar.wait(guard).unwrap();
345-
}
346-
};
347-
348-
if result.is_err() {
349-
self.scope_data.panic_joined();
350-
}
351-
352-
result
360+
self.data.wait_and_take()
353361
}
354362

355363
/// Check if the task is finished.
356364
#[must_use]
357365
pub fn is_finished(&self) -> bool {
358-
self.data.mutex.lock().unwrap().is_some()
366+
self.data.is_finished()
359367
}
360368
}
361369

@@ -386,7 +394,6 @@ fn try_queue_task(mut closure: Box<dyn FnOnce() + Send>) -> Result<(), Box<dyn F
386394
}
387395
}
388396
}
389-
drop(guard);
390397

391398
Err(closure)
392399
}

0 commit comments

Comments
 (0)