diff --git a/compio-dispatcher/src/lib.rs b/compio-dispatcher/src/lib.rs index cfbca861..9c30cb56 100644 --- a/compio-dispatcher/src/lib.rs +++ b/compio-dispatcher/src/lib.rs @@ -81,8 +81,8 @@ impl Dispatcher { mut names, mut proactor_builder, } = builder; - proactor_builder.force_reuse_thread_pool(); - let pool = proactor_builder.create_or_get_thread_pool(); + proactor_builder.force_reuse_thread_pool()?; + let pool = proactor_builder.create_or_get_thread_pool()?; let (sender, receiver) = unbounded::(); let threads = (0..nthreads) diff --git a/compio-driver/src/asyncify.rs b/compio-driver/src/asyncify/fallback.rs similarity index 58% rename from compio-driver/src/asyncify.rs rename to compio-driver/src/asyncify/fallback.rs index 275e26c1..675bc189 100644 --- a/compio-driver/src/asyncify.rs +++ b/compio-driver/src/asyncify/fallback.rs @@ -1,5 +1,5 @@ use std::{ - fmt, + io, sync::{ Arc, atomic::{AtomicUsize, Ordering}, @@ -9,52 +9,10 @@ use std::{ use crossbeam_channel::{Receiver, Sender, TrySendError, bounded}; -/// An error that may be emitted when all worker threads are busy. It simply -/// returns the dispatchable value with a convenient [`fmt::Debug`] and -/// [`fmt::Display`] implementation. -#[derive(Copy, Clone, PartialEq, Eq)] -pub struct DispatchError(pub T); - -impl DispatchError { - /// Consume the error, yielding the dispatchable that failed to be sent. - pub fn into_inner(self) -> T { - self.0 - } -} - -impl fmt::Debug for DispatchError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - "DispatchError(..)".fmt(f) - } -} - -impl fmt::Display for DispatchError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - "all threads are busy".fmt(f) - } -} - -impl std::error::Error for DispatchError {} +use super::{DispatchError, Dispatchable}; type BoxedDispatchable = Box; -/// A trait for dispatching a closure. It's implemented for all `FnOnce() + Send -/// + 'static` but may also be implemented for any other types that are `Send` -/// and `'static`. -pub trait Dispatchable: Send + 'static { - /// Run the dispatchable - fn run(self: Box); -} - -impl Dispatchable for F -where - F: FnOnce() + Send + 'static, -{ - fn run(self: Box) { - (*self)() - } -} - struct CounterGuard(Arc); impl Drop for CounterGuard { @@ -77,7 +35,6 @@ fn worker( } } -/// A thread pool to perform blocking operations in other threads. #[derive(Debug, Clone)] pub struct AsyncifyPool { sender: Sender, @@ -88,23 +45,17 @@ pub struct AsyncifyPool { } impl AsyncifyPool { - /// Create [`AsyncifyPool`] with thread number limit and channel receive - /// timeout. - pub fn new(thread_limit: usize, recv_timeout: Duration) -> Self { + pub fn new(thread_limit: usize, recv_timeout: Duration) -> io::Result { let (sender, receiver) = bounded(0); - Self { + Ok(Self { sender, receiver, counter: Arc::new(AtomicUsize::new(0)), thread_limit, recv_timeout, - } + }) } - /// Send a dispatchable, usually a closure, to another thread. Usually the - /// user should not use it. When all threads are busy and thread number - /// limit has been reached, it will return an error with the original - /// dispatchable. pub fn dispatch(&self, f: D) -> Result<(), DispatchError> { match self.sender.try_send(Box::new(f) as BoxedDispatchable) { Ok(_) => Ok(()), diff --git a/compio-driver/src/asyncify/mod.rs b/compio-driver/src/asyncify/mod.rs new file mode 100644 index 00000000..d2afa587 --- /dev/null +++ b/compio-driver/src/asyncify/mod.rs @@ -0,0 +1,82 @@ +use std::{self, fmt, io, time::Duration}; + +cfg_if::cfg_if! { + if #[cfg(windows)] { + #[path = "windows.rs"] + mod sys; + } else { + #[path = "fallback.rs"] + mod sys; + } +} + +/// An error that may be emitted when all worker threads are busy. It simply +/// returns the dispatchable value with a convenient [`fmt::Debug`] and +/// [`fmt::Display`] implementation. +#[derive(Copy, Clone, PartialEq, Eq)] +pub struct DispatchError(pub T); + +impl DispatchError { + /// Consume the error, yielding the dispatchable that failed to be sent. + pub fn into_inner(self) -> T { + self.0 + } +} + +impl fmt::Debug for DispatchError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + "DispatchError(..)".fmt(f) + } +} + +impl fmt::Display for DispatchError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + "all threads are busy".fmt(f) + } +} + +impl std::error::Error for DispatchError {} + +/// A trait for dispatching a closure. It's implemented for all `FnOnce() + Send +/// + 'static` but may also be implemented for any other types that are `Send` +/// and `'static`. +pub trait Dispatchable: Send + 'static { + /// Run the dispatchable + fn run(self: Box); +} + +impl Dispatchable for F +where + F: FnOnce() + Send + 'static, +{ + fn run(self: Box) { + (*self)() + } +} + +/// A thread pool to perform blocking operations in other threads. +#[derive(Debug, Clone)] +pub struct AsyncifyPool(sys::AsyncifyPool); + +impl AsyncifyPool { + /// Create [`AsyncifyPool`] with thread number limit and channel receive + /// timeout. + pub fn new(thread_limit: usize, recv_timeout: Duration) -> io::Result { + Ok(Self(sys::AsyncifyPool::new(thread_limit, recv_timeout)?)) + } + + /// Send a dispatchable, usually a closure, to another thread. Usually the + /// user should not use it. When all threads are busy and thread number + /// limit has been reached, it will return an error with the original + /// dispatchable. + pub fn dispatch(&self, f: D) -> Result<(), DispatchError> { + self.0.dispatch(f) + } + + #[cfg(windows)] + pub(crate) fn as_ptr( + &self, + ) -> *const windows_sys::Win32::System::Threading::TP_CALLBACK_ENVIRON_V3 { + self.0.as_ptr() + } +} diff --git a/compio-driver/src/asyncify/windows.rs b/compio-driver/src/asyncify/windows.rs new file mode 100644 index 00000000..98e6fe51 --- /dev/null +++ b/compio-driver/src/asyncify/windows.rs @@ -0,0 +1,113 @@ +use std::{ + fmt::Debug, + io, + ptr::{null, null_mut}, + sync::Arc, + time::Duration, +}; + +use windows_sys::Win32::System::Threading::{ + CloseThreadpool, CloseThreadpoolCleanupGroup, CloseThreadpoolCleanupGroupMembers, + CreateThreadpool, CreateThreadpoolCleanupGroup, PTP_CALLBACK_INSTANCE, + SetThreadpoolThreadMaximum, TP_CALLBACK_ENVIRON_V3, TP_CALLBACK_PRIORITY_NORMAL, + TrySubmitThreadpoolCallback, +}; + +use super::{DispatchError, Dispatchable}; +use crate::syscall; + +struct PoolEnv(TP_CALLBACK_ENVIRON_V3); + +impl PoolEnv { + fn as_ref(&self) -> &TP_CALLBACK_ENVIRON_V3 { + &self.0 + } +} + +unsafe impl Send for PoolEnv {} +unsafe impl Sync for PoolEnv {} + +impl Drop for PoolEnv { + fn drop(&mut self) { + unsafe { + let pool = self.0.Pool; + let group = self.0.CleanupGroup; + CloseThreadpoolCleanupGroupMembers(group, 1, null_mut()); + CloseThreadpoolCleanupGroup(group); + CloseThreadpool(pool); + } + } +} + +#[derive(Clone)] +pub struct AsyncifyPool { + inner: Option>, +} + +impl Debug for AsyncifyPool { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AsyncifyPool").finish_non_exhaustive() + } +} + +impl AsyncifyPool { + pub fn new(thread_limit: usize, _recv_timeout: Duration) -> io::Result { + if thread_limit == 0 { + Ok(Self { inner: None }) + } else { + let pool = syscall!(BOOL, CreateThreadpool(null()))?; + let group = syscall!(BOOL, CreateThreadpoolCleanupGroup())?; + let inner = TP_CALLBACK_ENVIRON_V3 { + Version: 3, + Pool: pool, + CleanupGroup: group, + CallbackPriority: TP_CALLBACK_PRIORITY_NORMAL, + Size: size_of::() as u32, + ..Default::default() + }; + unsafe { + SetThreadpoolThreadMaximum(pool, thread_limit as _); + } + Ok(Self { + inner: Some(Arc::new(PoolEnv(inner))), + }) + } + } + + pub fn dispatch(&self, f: D) -> Result<(), DispatchError> { + unsafe extern "system" fn callback( + _: PTP_CALLBACK_INSTANCE, + callback: *mut std::ffi::c_void, + ) { + unsafe { + Box::from_raw(callback as *mut F).run(); + } + } + + if let Some(inner) = &self.inner { + let f = Box::new(f); + let ptr = Box::into_raw(f); + let res = syscall!( + BOOL, + TrySubmitThreadpoolCallback( + Some(callback::), + ptr.cast(), + inner.as_ref().as_ref(), + ) + ); + match res { + Ok(_) => Ok(()), + Err(_) => Err(DispatchError(*unsafe { Box::from_raw(ptr) })), + } + } else { + panic!("the thread pool is needed but no worker thread is running"); + } + } + + pub fn as_ptr(&self) -> *const TP_CALLBACK_ENVIRON_V3 { + self.inner + .as_deref() + .map(|e| e.as_ref() as *const _) + .unwrap_or(null()) + } +} diff --git a/compio-driver/src/iocp/mod.rs b/compio-driver/src/iocp/mod.rs index 62126b8f..b47687b7 100644 --- a/compio-driver/src/iocp/mod.rs +++ b/compio-driver/src/iocp/mod.rs @@ -296,7 +296,7 @@ impl Driver { Ok(Self { port, waits: HashMap::default(), - pool: builder.create_or_get_thread_pool(), + pool: builder.create_or_get_thread_pool()?, notify_overlapped: Arc::new(Overlapped::new(driver)), }) } @@ -350,8 +350,10 @@ impl Driver { } }, OpType::Event(e) => { - self.waits - .insert(user_data, wait::Wait::new(&self.port, e, op)?); + self.waits.insert( + user_data, + wait::Wait::new(&self.port, e, op, self.pool.as_ptr())?, + ); Poll::Pending } } diff --git a/compio-driver/src/iocp/wait/packet.rs b/compio-driver/src/iocp/wait/packet.rs index 7aca774f..8b528d43 100644 --- a/compio-driver/src/iocp/wait/packet.rs +++ b/compio-driver/src/iocp/wait/packet.rs @@ -5,9 +5,12 @@ use std::{ ptr::null_mut, }; -use windows_sys::Win32::Foundation::{ - GENERIC_READ, GENERIC_WRITE, HANDLE, NTSTATUS, RtlNtStatusToDosError, STATUS_PENDING, - STATUS_SUCCESS, +use windows_sys::Win32::{ + Foundation::{ + GENERIC_READ, GENERIC_WRITE, HANDLE, NTSTATUS, RtlNtStatusToDosError, STATUS_PENDING, + STATUS_SUCCESS, + }, + System::Threading::TP_CALLBACK_ENVIRON_V3, }; use crate::{Key, OpCode, RawFd, sys::cp}; @@ -52,7 +55,12 @@ fn check_status(status: NTSTATUS) -> io::Result<()> { } impl Wait { - pub fn new(port: &cp::Port, event: RawFd, op: &mut Key) -> io::Result { + pub fn new( + port: &cp::Port, + event: RawFd, + op: &mut Key, + _env: *const TP_CALLBACK_ENVIRON_V3, + ) -> io::Result { let mut handle = null_mut(); check_status(unsafe { NtCreateWaitCompletionPacket(&mut handle, GENERIC_READ | GENERIC_WRITE, null_mut()) diff --git a/compio-driver/src/iocp/wait/thread_pool.rs b/compio-driver/src/iocp/wait/thread_pool.rs index b19ddd4b..8adb85aa 100644 --- a/compio-driver/src/iocp/wait/thread_pool.rs +++ b/compio-driver/src/iocp/wait/thread_pool.rs @@ -8,7 +8,7 @@ use windows_sys::Win32::{ Foundation::{ERROR_IO_PENDING, ERROR_TIMEOUT, WAIT_OBJECT_0, WAIT_TIMEOUT}, System::Threading::{ CloseThreadpoolWait, CreateThreadpoolWait, PTP_CALLBACK_INSTANCE, PTP_WAIT, - SetThreadpoolWait, WaitForThreadpoolWaitCallbacks, + SetThreadpoolWait, TP_CALLBACK_ENVIRON_V3, WaitForThreadpoolWaitCallbacks, }, }; @@ -22,7 +22,12 @@ pub struct Wait { } impl Wait { - pub fn new(port: &cp::Port, event: RawFd, op: &mut Key) -> io::Result { + pub fn new( + port: &cp::Port, + event: RawFd, + op: &mut Key, + env: *const TP_CALLBACK_ENVIRON_V3, + ) -> io::Result { let port = port.handle(); let mut context = Box::new(WinThreadpoolWaitContext { port, @@ -33,7 +38,7 @@ impl Wait { CreateThreadpoolWait( Some(Self::wait_callback), (&mut *context) as *mut WinThreadpoolWaitContext as _, - null() + env ) )?; unsafe { diff --git a/compio-driver/src/iour/mod.rs b/compio-driver/src/iour/mod.rs index a26be648..3fe4dac4 100644 --- a/compio-driver/src/iour/mod.rs +++ b/compio-driver/src/iour/mod.rs @@ -114,7 +114,7 @@ impl Driver { Ok(Self { inner, notifier, - pool: builder.create_or_get_thread_pool(), + pool: builder.create_or_get_thread_pool()?, pool_completed: Arc::new(SegQueue::new()), buffer_group_ids: Slab::new(), need_push_notifier: true, diff --git a/compio-driver/src/lib.rs b/compio-driver/src/lib.rs index 9840d402..4a481b65 100644 --- a/compio-driver/src/lib.rs +++ b/compio-driver/src/lib.rs @@ -441,10 +441,10 @@ impl ThreadPoolBuilder { } } - pub fn create_or_reuse(&self) -> AsyncifyPool { + pub fn create_or_reuse(&self) -> io::Result { match self { Self::Create { limit, recv_limit } => AsyncifyPool::new(*limit, *recv_limit), - Self::Reuse(pool) => pool.clone(), + Self::Reuse(pool) => Ok(pool.clone()), } } } @@ -527,13 +527,13 @@ impl ProactorBuilder { /// Force reuse the thread pool for each proactor created by this builder, /// even `reuse_thread_pool` is not set. - pub fn force_reuse_thread_pool(&mut self) -> &mut Self { - self.reuse_thread_pool(self.create_or_get_thread_pool()); - self + pub fn force_reuse_thread_pool(&mut self) -> io::Result<&mut Self> { + self.reuse_thread_pool(self.create_or_get_thread_pool()?); + Ok(self) } /// Create or reuse the thread pool from the config. - pub fn create_or_get_thread_pool(&self) -> AsyncifyPool { + pub fn create_or_get_thread_pool(&self) -> io::Result { self.pool_builder.create_or_reuse() } diff --git a/compio-driver/src/poll/mod.rs b/compio-driver/src/poll/mod.rs index 147891a2..16c05356 100644 --- a/compio-driver/src/poll/mod.rs +++ b/compio-driver/src/poll/mod.rs @@ -193,7 +193,7 @@ impl Driver { events, poll, registry: HashMap::new(), - pool: builder.create_or_get_thread_pool(), + pool: builder.create_or_get_thread_pool()?, pool_completed: Arc::new(SegQueue::new()), }) }