Skip to content

Commit 9c86551

Browse files
committed
Refactor multithreading implementation to support using scoped tasks without WebAssembly
1 parent b6c1c53 commit 9c86551

File tree

10 files changed

+120
-69
lines changed

10 files changed

+120
-69
lines changed

crates/aoc/src/puzzles.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
use crate::all_puzzles;
22
use utils::date::Date;
3-
use utils::input::strip_final_newline;
43

54
// These imports are unused if none of the year features are enabled
65
#[allow(clippy::allow_attributes, unused_imports)]
76
use utils::{
87
PuzzleDate,
9-
input::{InputError, InputType},
8+
input::{InputError, InputType, strip_final_newline},
109
};
1110

1211
/// Represents a wrapper function around a puzzle solution.

crates/aoc_wasm/src/multithreading.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,5 @@ extern "C" fn allocate_stack(size: usize, align: usize) -> *mut u8 {
1212
/// Run worker thread.
1313
#[unsafe(no_mangle)]
1414
extern "C" fn worker_thread() {
15-
#[cfg(target_family = "wasm")]
16-
aoc::utils::wasm::scoped_tasks::worker();
17-
18-
#[cfg(not(target_family = "wasm"))]
19-
panic!("worker_thread is not supported on this target");
15+
aoc::utils::multithreading::scoped_tasks::worker();
2016
}

crates/utils/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ rust-version = { workspace = true }
1313
default = ["unsafe", "all-simd"]
1414
unsafe = []
1515
all-simd = []
16-
wasm-multithreading = ["unsafe"]
16+
scoped-tasks = ["unsafe"]
17+
wasm-multithreading = ["unsafe", "scoped-tasks"]
1718

1819
[lints]
1920
workspace = true

crates/utils/src/lib.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,15 @@ pub mod graph;
1212
pub mod grid;
1313
pub mod input;
1414
pub mod md5;
15-
#[cfg(not(target_family = "wasm"))]
1615
pub mod multithreading;
1716
pub mod multiversion;
1817
pub mod number;
1918
pub mod parser;
2019
pub mod queue;
2120
pub mod simd;
2221
pub mod slice;
23-
#[cfg(target_family = "wasm")]
24-
pub mod wasm;
2522

2623
pub use framework::{PuzzleDate, PuzzleExamples};
27-
#[cfg(target_family = "wasm")]
28-
pub use wasm::multithreading;
2924

3025
/// Standard imports for puzzle solutions.
3126
pub mod prelude {

crates/utils/src/multithreading.rs renamed to crates/utils/src/multithreading/impl_native.rs

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
//! Multithreading helpers.
2-
//!
3-
//! The main purpose of this module is to allow the number of worker threads used by each puzzle
4-
//! solution to be controlled by a CLI argument.
5-
61
use std::num::NonZeroUsize;
72
use std::sync::atomic::AtomicUsize;
83
use std::sync::atomic::Ordering::Relaxed;
@@ -27,7 +22,8 @@ pub fn get_thread_count() -> NonZeroUsize {
2722

2823
/// Set the number of worker threads to use.
2924
///
30-
/// This will affect any future call to [`get_thread_count`].
25+
/// This will affect any future call to [`get_thread_count`], unless scoped tasks are enabled and
26+
/// the thread pool has already been created.
3127
pub fn set_thread_count(count: NonZeroUsize) {
3228
NUM_THREADS.store(count.get(), Relaxed);
3329
}
@@ -42,10 +38,34 @@ pub fn worker_pool(worker: impl Fn() + Copy + Send) {
4238
if threads == 1 {
4339
worker();
4440
} else {
45-
std::thread::scope(|scope| {
46-
for _ in 0..threads {
47-
scope.spawn(worker);
48-
}
49-
});
41+
#[cfg(feature = "scoped-tasks")]
42+
{
43+
use super::scoped_tasks;
44+
45+
static ONCE: std::sync::Once = std::sync::Once::new();
46+
ONCE.call_once(|| {
47+
for i in 0..threads {
48+
std::thread::Builder::new()
49+
.name(format!("worker-{i}"))
50+
.spawn(scoped_tasks::worker)
51+
.expect("failed to spawn worker thread");
52+
}
53+
});
54+
55+
scoped_tasks::scope(|scope| {
56+
for _ in 0..threads {
57+
scope.spawn(worker);
58+
}
59+
});
60+
}
61+
62+
#[cfg(not(feature = "scoped-tasks"))]
63+
{
64+
std::thread::scope(|scope| {
65+
for _ in 0..threads {
66+
scope.spawn(worker);
67+
}
68+
});
69+
}
5070
}
5171
}

crates/utils/src/wasm/multithreading_wasm.rs renamed to crates/utils/src/multithreading/impl_wasm_scoped.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@
33
use super::scoped_tasks::{scope, worker_count};
44
use std::num::{NonZero, NonZeroUsize};
55

6+
#[cfg(not(all(
7+
target_feature = "atomics",
8+
target_feature = "bulk-memory",
9+
target_feature = "mutable-globals",
10+
)))]
11+
compile_error!("Required target features not enabled");
12+
613
#[must_use]
714
pub fn get_thread_count() -> NonZeroUsize {
815
// If there are no workers, `scoped_task` will fall back to running tasks on the current thread.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
//! Multithreading helpers.
2+
3+
#[cfg(feature = "scoped-tasks")]
4+
pub mod scoped_tasks;
5+
6+
#[cfg_attr(
7+
all(target_family = "wasm", feature = "wasm-multithreading"),
8+
path = "impl_wasm_scoped.rs"
9+
)]
10+
#[cfg_attr(
11+
all(target_family = "wasm", not(feature = "wasm-multithreading")),
12+
path = "impl_wasm_stub.rs"
13+
)]
14+
#[cfg_attr(not(target_family = "wasm"), path = "impl_native.rs")]
15+
mod multithreading_impl;
16+
pub use multithreading_impl::*;

crates/utils/src/wasm/scoped_tasks.rs renamed to crates/utils/src/multithreading/scoped_tasks.rs

Lines changed: 62 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,38 @@
1-
//! Experimental drop-in replacement for [`std::thread::scope`] for WebAssembly.
1+
//! Experimental replacement for [`std::thread::scope`] using a fixed worker pool.
22
//!
3-
//! Uses a pool of web worker threads spawned by the host JS environment to run scoped tasks.
3+
//! *Scoped tasks* are similar to *scoped threads* but run on an existing thread pool instead of
4+
//! spawning dedicated threads.
5+
//!
6+
//! # WebAssembly support
7+
//!
8+
//! This module was originally designed for WebAssembly, where it can use a pool of
9+
//! [web worker](https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API) threads spawned
10+
//! by the host JS environment to run scoped tasks.
411
//!
512
//! Requires the `atomics`, `bulk-memory` and `mutable-globals` target features to be enabled, and
613
//! for all threads to be using web workers as `memory.atomic.wait` doesn't work on the main thread.
714
//!
815
//! Catching unwinding panics should be supported, but at the time of writing, the Rust standard
9-
//! library doesn't support panic=unwind on WebAssembly.
16+
//! library doesn't support `panic=unwind` on WebAssembly.
1017
//!
1118
//! # Examples
1219
//!
1320
//! ```
14-
//! // Setup pool of workers. In WebAssembly, this would be done by spawning more web workers which
15-
//! // then call the exported worker function.
21+
//! # use std::num::NonZero;
22+
//! # use std::sync::atomic::AtomicU32;
23+
//! # use std::sync::atomic::Ordering;
24+
//! # use utils::multithreading::scoped_tasks;
25+
//! // Setup pool of workers. In WebAssembly, where std::thread::spawn is not available, this would
26+
//! // be implemented by spawning more web workers which then call the exported worker function.
1627
//! for _ in 0..std::thread::available_parallelism().map_or(4, NonZero::get) {
17-
//! std::thread::spawn(scoped::worker);
28+
//! std::thread::spawn(scoped_tasks::worker);
1829
//! }
1930
//!
2031
//! let data = vec![1, 2, 3];
2132
//! let mut data2 = vec![10, 100, 1000];
2233
//!
2334
//! // Start scoped tasks which may run on other threads
24-
//! scoped::scope(|s| {
35+
//! scoped_tasks::scope(|s| {
2536
//! s.spawn(|| {
2637
//! println!("[task 1] data={:?}", data);
2738
//! });
@@ -42,7 +53,7 @@
4253
//!
4354
//! // Start another set of scoped tasks
4455
//! let counter = AtomicU32::new(0);
45-
//! scoped::scope(|s| {
56+
//! scoped_tasks::scope(|s| {
4657
//! let counter = &counter;
4758
//! for t in 0..4 {
4859
//! s.spawn(move || {
@@ -61,17 +72,14 @@ use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
6172
use std::sync::mpsc::{SyncSender, TrySendError};
6273
use std::sync::{Arc, Condvar, Mutex};
6374

64-
#[cfg(not(all(
65-
target_feature = "atomics",
66-
target_feature = "bulk-memory",
67-
target_feature = "mutable-globals",
68-
)))]
69-
compile_error!("Required target features not enabled");
70-
7175
/// Create a scope for spawning scoped tasks.
7276
///
7377
/// Scoped tasks may borrow non-`static` data, and may run in parallel depending on thread pool
7478
/// worker availability.
79+
///
80+
/// All scoped tasks are automatically joined before this function returns.
81+
///
82+
/// Designed to match the [`std::thread::scope`] API.
7583
#[inline(never)]
7684
pub fn scope<'env, F, T>(f: F) -> T
7785
where
@@ -103,6 +111,8 @@ where
103111

104112
/// Scope to spawn tasks in.
105113
///
114+
/// Designed to match the [`std::thread::Scope`] API.
115+
///
106116
/// # Lifetimes
107117
///
108118
/// The `'scope` lifetime represents the lifetime of the scope itself, starting when the closure
@@ -111,19 +121,36 @@ where
111121
/// The `'env` lifetime represents the lifetime of the data borrowed by the scoped tasks, and must
112122
/// outlive `'scope`.
113123
#[derive(Debug)]
124+
#[expect(clippy::struct_field_names)]
114125
pub struct Scope<'scope, 'env: 'scope> {
115126
data: Arc<ScopeData>,
116127
// &'scope mut &'scope is needed to prevent lifetimes from shrinking
117128
_scope: PhantomData<&'scope mut &'scope ()>,
118129
_env: PhantomData<&'env mut &'env ()>,
119130
}
120131

121-
impl<'scope, 'env> Scope<'scope, 'env> {
132+
impl<'scope> Scope<'scope, '_> {
122133
/// Spawn a new task within the scope.
123134
///
124135
/// If no workers within the thread pool are available, the task will be executed on the current
125136
/// thread.
126137
pub fn spawn<F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
138+
where
139+
F: FnOnce() -> T + Send + 'scope,
140+
T: Send + 'scope,
141+
{
142+
let (closure, handle) = self.create_closure(f);
143+
if let Err(closure) = try_queue_task(closure) {
144+
// Fall back to running the closure on this thread
145+
closure();
146+
}
147+
handle
148+
}
149+
150+
fn create_closure<F, T>(
151+
&'scope self,
152+
f: F,
153+
) -> (Box<dyn FnOnce() + Send>, ScopedJoinHandle<'scope, T>)
127154
where
128155
F: FnOnce() -> T + Send + 'scope,
129156
T: Send + 'scope,
@@ -162,19 +189,15 @@ impl<'scope, 'env> Scope<'scope, 'env> {
162189
};
163190

164191
let scope_data = self.data.clone();
165-
scoped_task(Box::new(
166-
#[inline(never)]
167-
move || {
168-
// Use a second closure to ensure that the closure which borrows from 'scope is
169-
// dropped before `ScopeData::task_end` is called. This prevents `scope` from
170-
// returning too soon, while the closures still exist, which causes UB as detected
171-
// by Miri.
172-
let panicked = closure();
173-
scope_data.task_end(panicked);
174-
},
175-
));
176-
177-
handle
192+
let task_closure = Box::new(move || {
193+
// Use a second closure to ensure that the closure which borrows from 'scope is dropped
194+
// before `ScopeData::task_end` is called. This prevents `scope()` from returning while
195+
// the inner closure still exists, which causes UB as detected by Miri.
196+
let panicked = closure();
197+
scope_data.task_end(panicked);
198+
});
199+
200+
(task_closure, handle)
178201
}
179202
}
180203

@@ -210,6 +233,10 @@ impl ScopeData {
210233
}
211234

212235
/// Handle to block on a task's termination.
236+
///
237+
/// Designed to match the [`std::thread::ScopedJoinHandle`] API, except
238+
/// [`std::thread::ScopedJoinHandle::thread`] is not supported as tasks are not run on dedicated
239+
/// threads.
213240
#[derive(Debug)]
214241
pub struct ScopedJoinHandle<'scope, T> {
215242
data: Arc<HandleData<T>>,
@@ -222,10 +249,7 @@ struct HandleData<T> {
222249
condvar: Condvar,
223250
}
224251

225-
impl<'scope, T> ScopedJoinHandle<'scope, T> {
226-
// Unsupported
227-
// pub fn thread(&self) -> &Thread {}
228-
252+
impl<T> ScopedJoinHandle<'_, T> {
229253
/// Wait for the task to finish.
230254
pub fn join(self) -> Result<T, Box<dyn Any + Send + 'static>> {
231255
let HandleData { mutex, condvar } = self.data.as_ref();
@@ -246,7 +270,7 @@ impl<'scope, T> ScopedJoinHandle<'scope, T> {
246270
#[expect(clippy::type_complexity)]
247271
static WORKERS: Mutex<VecDeque<SyncSender<Box<dyn FnOnce() + Send>>>> = Mutex::new(VecDeque::new());
248272

249-
fn scoped_task(mut closure: Box<dyn FnOnce() + Send>) {
273+
fn try_queue_task(mut closure: Box<dyn FnOnce() + Send>) -> Result<(), Box<dyn FnOnce() + Send>> {
250274
let mut guard = WORKERS.lock().unwrap();
251275
let queue = &mut *guard;
252276

@@ -258,7 +282,7 @@ fn scoped_task(mut closure: Box<dyn FnOnce() + Send>) {
258282
match sender.try_send(closure) {
259283
Ok(()) => {
260284
queue.push_back(sender);
261-
return;
285+
return Ok(());
262286
}
263287
Err(TrySendError::Full(v)) => {
264288
closure = v;
@@ -272,11 +296,12 @@ fn scoped_task(mut closure: Box<dyn FnOnce() + Send>) {
272296
}
273297
drop(guard);
274298

275-
// Fall back to run the closure on this thread
276-
closure();
299+
Err(closure)
277300
}
278301

279302
/// Use this thread as a worker in the thread pool for scoped tasks.
303+
///
304+
/// This function never returns.
280305
pub fn worker() {
281306
let (tx, rx) = std::sync::mpsc::sync_channel(0);
282307

crates/utils/src/wasm/mod.rs

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)