Skip to content

Commit 5d1ac9d

Browse files
committed
Add try_spawn and force_spawn scoped tasks methods
1 parent 7d4c109 commit 5d1ac9d

File tree

2 files changed

+65
-20
lines changed

2 files changed

+65
-20
lines changed

crates/utils/src/multithreading/impl_native.rs

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ pub fn get_thread_count() -> NonZeroUsize {
2222

2323
/// Set the number of worker threads to use.
2424
///
25-
/// This will affect any future call to [`get_thread_count`], unless scoped tasks are enabled and
26-
/// the thread pool has already been created.
25+
/// This will affect any future call to [`get_thread_count`].
2726
pub fn set_thread_count(count: NonZeroUsize) {
2827
NUM_THREADS.store(count.get(), Relaxed);
2928
}
@@ -40,31 +39,21 @@ pub fn worker_pool(worker: impl Fn() + Copy + Send) {
4039
} else {
4140
#[cfg(feature = "scoped-tasks")]
4241
{
43-
use super::scoped_tasks;
44-
45-
static ONCE: std::sync::Once = std::sync::Once::new();
46-
ONCE.call_once(|| {
47-
for i in 0..threads {
48-
std::thread::Builder::new()
49-
.name(format!("worker-{i}"))
50-
.spawn(scoped_tasks::worker)
51-
.expect("failed to spawn worker thread");
52-
}
53-
});
54-
55-
scoped_tasks::scope(|scope| {
56-
for _ in 0..threads {
57-
scope.spawn(worker);
42+
super::scoped_tasks::scope(|scope| {
43+
for _ in 1..threads {
44+
scope.force_spawn(worker);
5845
}
46+
worker();
5947
});
6048
}
6149

6250
#[cfg(not(feature = "scoped-tasks"))]
6351
{
64-
std::thread::scope(|scope| {
65-
for _ in 0..threads {
52+
std::thread::scope(move |scope| {
53+
for _ in 1..threads {
6654
scope.spawn(worker);
6755
}
56+
worker();
6857
});
6958
}
7059
}

crates/utils/src/multithreading/scoped_tasks.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ use std::any::Any;
6969
use std::collections::VecDeque;
7070
use std::marker::PhantomData;
7171
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
72+
use std::sync::atomic::{AtomicU32, Ordering};
7273
use std::sync::mpsc::{SyncSender, TrySendError};
7374
use std::sync::{Arc, Condvar, Mutex};
7475

@@ -111,7 +112,7 @@ where
111112

112113
/// Scope to spawn tasks in.
113114
///
114-
/// Designed to match the [`std::thread::Scope`] API.
115+
/// [`Scope::spawn()`] is designed to match the [`std::thread::Scope`] API.
115116
///
116117
/// # Lifetimes
117118
///
@@ -147,6 +148,55 @@ impl<'scope> Scope<'scope, '_> {
147148
handle
148149
}
149150

151+
/// Spawn a new task within the scope if there is a worker available.
152+
///
153+
/// If no workers within the thread pool are available, the task will not be executed and
154+
/// [`None`] will be returned.
155+
pub fn try_spawn<F, T>(&'scope self, f: F) -> Option<ScopedJoinHandle<'scope, T>>
156+
where
157+
F: FnOnce() -> T + Send + 'scope,
158+
T: Send + 'scope,
159+
{
160+
let (closure, handle) = self.create_closure(f);
161+
if let Ok(()) = try_queue_task(closure) {
162+
Some(handle)
163+
} else {
164+
// Closure will never be run
165+
self.data.task_end();
166+
167+
None
168+
}
169+
}
170+
171+
/// Spawn a new task within the scope, spawning a new worker if necessary.
172+
///
173+
/// This function is not available on WebAssembly, as new threads have to be created from the
174+
/// host JS environment.
175+
#[cfg(not(target_family = "wasm"))]
176+
pub fn force_spawn<F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
177+
where
178+
F: FnOnce() -> T + Send + 'scope,
179+
T: Send + 'scope,
180+
{
181+
let (closure, handle) = self.create_closure(f);
182+
if let Err(closure) = try_queue_task(closure) {
183+
// Start a worker to process this closure and then join the pool.
184+
static THREAD_NUM: AtomicU32 = AtomicU32::new(1);
185+
std::thread::Builder::new()
186+
.name(format!(
187+
"scoped-tasks-{}",
188+
THREAD_NUM.fetch_add(1, Ordering::Relaxed)
189+
))
190+
.spawn(move || {
191+
// Pass the closure directly to the new worker to avoid race conditions where
192+
// another scope queues a closure before this one.
193+
worker_impl(closure);
194+
})
195+
.expect("failed to spawn worker thread");
196+
}
197+
handle
198+
}
199+
150200
fn create_closure<F, T>(
151201
&'scope self,
152202
f: F,
@@ -345,6 +395,12 @@ fn try_queue_task(mut closure: Box<dyn FnOnce() + Send>) -> Result<(), Box<dyn F
345395
///
346396
/// This function never returns.
347397
pub fn worker() {
398+
worker_impl(|| {});
399+
}
400+
401+
fn worker_impl(initial: impl FnOnce() + Send) {
402+
initial();
403+
348404
let (tx, rx) = std::sync::mpsc::sync_channel(0);
349405

350406
{

0 commit comments

Comments
 (0)