Skip to content

Commit 7d4c109

Browse files
committed
Fix scope panicking when a task's panic was already handled by joining its handle
1 parent 9c86551 commit 7d4c109

File tree

1 file changed

+65
-23
lines changed

1 file changed

+65
-23
lines changed

crates/utils/src/multithreading/scoped_tasks.rs

Lines changed: 65 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ where
8787
{
8888
let scope = Scope {
8989
data: Arc::new(ScopeData {
90-
mutex: Mutex::new((0, false)),
90+
mutex: Mutex::new(ScopeCounts::default()),
9191
condvar: Condvar::new(),
9292
}),
9393
_scope: PhantomData,
@@ -98,13 +98,13 @@ where
9898

9999
// Wait for tasks to finish
100100
let mut guard = scope.data.mutex.lock().unwrap();
101-
while guard.0 > 0 {
101+
while guard.running > 0 {
102102
guard = scope.data.condvar.wait(guard).unwrap();
103103
}
104104

105105
match result {
106106
Err(e) => resume_unwind(e),
107-
Ok(_) if guard.1 => panic!("scoped task panicked"),
107+
Ok(_) if guard.unhandled_panics > 0 => panic!("scoped task panicked"),
108108
Ok(x) => x,
109109
}
110110
}
@@ -162,74 +162,103 @@ impl<'scope> Scope<'scope, '_> {
162162
mutex: Mutex::new(None),
163163
condvar: Condvar::new(),
164164
}),
165+
scope_data: self.data.clone(),
165166
_scope: PhantomData,
166167
};
167168

168169
let handle_data = handle.data.clone();
169-
let closure: Box<dyn FnOnce() -> bool + Send + 'scope> = Box::new(
170+
let scope_data = self.data.clone();
171+
let closure: Box<dyn FnOnce() + Send + 'scope> = Box::new(
170172
#[inline(never)]
171173
move || {
172174
let result = catch_unwind(AssertUnwindSafe(f));
173-
let panicked = result.is_err();
175+
176+
if result.is_err() {
177+
// Updating the panic count must happen before updating the handle data, to
178+
// avoid the handle being joined in another thread which then tries to decrement
179+
// the unhandled panic count before it is incremented
180+
scope_data.task_panicked();
181+
}
174182

175183
// Send the result to ScopedJoinHandle and wake any blocked threads
176184
let HandleData { mutex, condvar } = handle_data.as_ref();
177185
let mut guard = mutex.lock().unwrap();
178186
*guard = Some(result);
179187
condvar.notify_all();
180-
181-
panicked
182188
},
183189
);
184190

185191
// SAFETY: The `scope` function ensures all closures are finished before returning
186192
let closure = unsafe {
187193
#[expect(clippy::unnecessary_cast, reason = "casting lifetimes")]
188-
Box::from_raw(Box::into_raw(closure) as *mut (dyn FnOnce() -> bool + Send + 'static))
194+
Box::from_raw(Box::into_raw(closure) as *mut (dyn FnOnce() + Send + 'static))
189195
};
190196

191197
let scope_data = self.data.clone();
192198
let task_closure = Box::new(move || {
193199
// Use a second closure to ensure that the closure which borrows from 'scope is dropped
194200
// before `ScopeData::task_end` is called. This prevents `scope()` from returning while
195201
// the inner closure still exists, which causes UB as detected by Miri.
196-
let panicked = closure();
197-
scope_data.task_end(panicked);
202+
closure();
203+
scope_data.task_end();
198204
});
199205

200206
(task_closure, handle)
201207
}
202208
}
203209

204-
// Stores the number of currently running tasks, and if a panic occurred.
210+
// Stores the number of currently running tasks and unhandled panics.
205211
#[derive(Debug)]
206212
struct ScopeData {
207-
mutex: Mutex<(usize, bool)>,
213+
mutex: Mutex<ScopeCounts>,
208214
condvar: Condvar,
209215
}
210216

217+
#[derive(Debug, Default)]
218+
struct ScopeCounts {
219+
running: usize,
220+
unhandled_panics: usize,
221+
}
222+
211223
impl ScopeData {
212224
fn task_start(&self) {
213225
let mut guard = self.mutex.lock().unwrap();
214-
if let Some(new_running) = guard.0.checked_add(1) {
215-
guard.0 = new_running;
226+
if let Some(new_running) = guard.running.checked_add(1) {
227+
guard.running = new_running;
216228
} else {
217229
panic!("too many running tasks in scope");
218230
}
219231
}
220232

221-
fn task_end(&self, panicked: bool) {
233+
fn task_end(&self) {
222234
let mut guard = self.mutex.lock().unwrap();
223-
guard.1 |= panicked;
224-
if let Some(new_running) = guard.0.checked_sub(1) {
225-
guard.0 = new_running;
235+
if let Some(new_running) = guard.running.checked_sub(1) {
236+
guard.running = new_running;
226237
if new_running == 0 {
227238
self.condvar.notify_all();
228239
}
229240
} else {
230241
panic!("more tasks finished than started?")
231242
}
232243
}
244+
245+
fn task_panicked(&self) {
246+
let mut guard = self.mutex.lock().unwrap();
247+
if let Some(new_panicked) = guard.unhandled_panics.checked_add(1) {
248+
guard.unhandled_panics = new_panicked;
249+
} else {
250+
panic!("too many panicking tasks in scope");
251+
}
252+
}
253+
254+
fn panic_joined(&self) {
255+
let mut guard = self.mutex.lock().unwrap();
256+
if let Some(new_panicked) = guard.unhandled_panics.checked_sub(1) {
257+
guard.unhandled_panics = new_panicked;
258+
} else {
259+
panic!("more panics joined than tasks panicked?")
260+
}
261+
}
233262
}
234263

235264
/// Handle to block on a task's termination.
@@ -240,6 +269,7 @@ impl ScopeData {
240269
#[derive(Debug)]
241270
pub struct ScopedJoinHandle<'scope, T> {
242271
data: Arc<HandleData<T>>,
272+
scope_data: Arc<ScopeData>,
243273
_scope: PhantomData<&'scope mut &'scope ()>,
244274
}
245275

@@ -251,13 +281,25 @@ struct HandleData<T> {
251281

252282
impl<T> ScopedJoinHandle<'_, T> {
253283
/// Wait for the task to finish.
284+
///
285+
/// The [`Err`] variant contains the panic value if the task panicked.
254286
pub fn join(self) -> Result<T, Box<dyn Any + Send + 'static>> {
255-
let HandleData { mutex, condvar } = self.data.as_ref();
256-
let mut guard = mutex.lock().unwrap();
257-
while guard.is_none() {
258-
guard = condvar.wait(guard).unwrap();
287+
let result = {
288+
let HandleData { mutex, condvar } = self.data.as_ref();
289+
let mut guard = mutex.lock().unwrap();
290+
loop {
291+
if let Some(result) = guard.take() {
292+
break result;
293+
}
294+
guard = condvar.wait(guard).unwrap();
295+
}
296+
};
297+
298+
if result.is_err() {
299+
self.scope_data.panic_joined();
259300
}
260-
guard.take().unwrap()
301+
302+
result
261303
}
262304

263305
/// Check if the task is finished.

0 commit comments

Comments
 (0)