Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compio-dispatcher/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ impl Dispatcher {
mut names,
mut proactor_builder,
} = builder;
proactor_builder.force_reuse_thread_pool();
let pool = proactor_builder.create_or_get_thread_pool();
proactor_builder.force_reuse_thread_pool()?;
let pool = proactor_builder.create_or_get_thread_pool()?;
let (sender, receiver) = unbounded::<Spawning>();

let threads = (0..nthreads)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
fmt,
io,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
Expand All @@ -9,52 +9,10 @@ use std::{

use crossbeam_channel::{Receiver, Sender, TrySendError, bounded};

/// An error that may be emitted when all worker threads are busy. It simply
/// returns the dispatchable value with a convenient [`fmt::Debug`] and
/// [`fmt::Display`] implementation.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct DispatchError<T>(pub T);

impl<T> DispatchError<T> {
/// Consume the error, yielding the dispatchable that failed to be sent.
pub fn into_inner(self) -> T {
self.0
}
}

impl<T> fmt::Debug for DispatchError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
"DispatchError(..)".fmt(f)
}
}

impl<T> fmt::Display for DispatchError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
"all threads are busy".fmt(f)
}
}

impl<T> std::error::Error for DispatchError<T> {}
use super::{DispatchError, Dispatchable};

type BoxedDispatchable = Box<dyn Dispatchable + Send>;

/// A trait for dispatching a closure. It's implemented for all `FnOnce() + Send
/// + 'static` but may also be implemented for any other types that are `Send`
/// and `'static`.
pub trait Dispatchable: Send + 'static {
/// Run the dispatchable
fn run(self: Box<Self>);
}

impl<F> Dispatchable for F
where
F: FnOnce() + Send + 'static,
{
fn run(self: Box<Self>) {
(*self)()
}
}

struct CounterGuard(Arc<AtomicUsize>);

impl Drop for CounterGuard {
Expand All @@ -77,7 +35,6 @@ fn worker(
}
}

/// A thread pool to perform blocking operations in other threads.
#[derive(Debug, Clone)]
pub struct AsyncifyPool {
sender: Sender<BoxedDispatchable>,
Expand All @@ -88,23 +45,17 @@ pub struct AsyncifyPool {
}

impl AsyncifyPool {
/// Create [`AsyncifyPool`] with thread number limit and channel receive
/// timeout.
pub fn new(thread_limit: usize, recv_timeout: Duration) -> Self {
pub fn new(thread_limit: usize, recv_timeout: Duration) -> io::Result<Self> {
let (sender, receiver) = bounded(0);
Self {
Ok(Self {
sender,
receiver,
counter: Arc::new(AtomicUsize::new(0)),
thread_limit,
recv_timeout,
}
})
}

/// Send a dispatchable, usually a closure, to another thread. Usually the
/// user should not use it. When all threads are busy and thread number
/// limit has been reached, it will return an error with the original
/// dispatchable.
pub fn dispatch<D: Dispatchable>(&self, f: D) -> Result<(), DispatchError<D>> {
match self.sender.try_send(Box::new(f) as BoxedDispatchable) {
Ok(_) => Ok(()),
Expand Down
82 changes: 82 additions & 0 deletions compio-driver/src/asyncify/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use std::{self, fmt, io, time::Duration};

cfg_if::cfg_if! {
if #[cfg(windows)] {
#[path = "windows.rs"]
mod sys;
} else {
#[path = "fallback.rs"]
mod sys;
}
}

/// An error that may be emitted when all worker threads are busy. It simply
/// returns the dispatchable value with a convenient [`fmt::Debug`] and
/// [`fmt::Display`] implementation.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct DispatchError<T>(pub T);

impl<T> DispatchError<T> {
/// Consume the error, yielding the dispatchable that failed to be sent.
pub fn into_inner(self) -> T {
self.0
}
}

impl<T> fmt::Debug for DispatchError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
"DispatchError(..)".fmt(f)
}
}

impl<T> fmt::Display for DispatchError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
"all threads are busy".fmt(f)
}
}

impl<T> std::error::Error for DispatchError<T> {}

/// A trait for dispatching a closure. It's implemented for all `FnOnce() + Send
/// + 'static` but may also be implemented for any other types that are `Send`
/// and `'static`.
pub trait Dispatchable: Send + 'static {
/// Run the dispatchable
fn run(self: Box<Self>);
}

impl<F> Dispatchable for F
where
F: FnOnce() + Send + 'static,
{
fn run(self: Box<Self>) {
(*self)()
}
}

/// A thread pool to perform blocking operations in other threads.
#[derive(Debug, Clone)]
pub struct AsyncifyPool(sys::AsyncifyPool);

impl AsyncifyPool {
/// Create [`AsyncifyPool`] with thread number limit and channel receive
/// timeout.
pub fn new(thread_limit: usize, recv_timeout: Duration) -> io::Result<Self> {
Ok(Self(sys::AsyncifyPool::new(thread_limit, recv_timeout)?))
}

/// Send a dispatchable, usually a closure, to another thread. Usually the
/// user should not use it. When all threads are busy and thread number
/// limit has been reached, it will return an error with the original
/// dispatchable.
pub fn dispatch<D: Dispatchable>(&self, f: D) -> Result<(), DispatchError<D>> {
self.0.dispatch(f)
}

#[cfg(windows)]
pub(crate) fn as_ptr(
&self,
) -> *const windows_sys::Win32::System::Threading::TP_CALLBACK_ENVIRON_V3 {
self.0.as_ptr()
}
}
113 changes: 113 additions & 0 deletions compio-driver/src/asyncify/windows.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use std::{
fmt::Debug,
io,
ptr::{null, null_mut},
sync::Arc,
time::Duration,
};

use windows_sys::Win32::System::Threading::{
CloseThreadpool, CloseThreadpoolCleanupGroup, CloseThreadpoolCleanupGroupMembers,
CreateThreadpool, CreateThreadpoolCleanupGroup, PTP_CALLBACK_INSTANCE,
SetThreadpoolThreadMaximum, TP_CALLBACK_ENVIRON_V3, TP_CALLBACK_PRIORITY_NORMAL,
TrySubmitThreadpoolCallback,
};

use super::{DispatchError, Dispatchable};
use crate::syscall;

struct PoolEnv(TP_CALLBACK_ENVIRON_V3);

impl PoolEnv {
fn as_ref(&self) -> &TP_CALLBACK_ENVIRON_V3 {
&self.0
}
}

unsafe impl Send for PoolEnv {}
unsafe impl Sync for PoolEnv {}

impl Drop for PoolEnv {
fn drop(&mut self) {
unsafe {
let pool = self.0.Pool;
let group = self.0.CleanupGroup;
CloseThreadpoolCleanupGroupMembers(group, 1, null_mut());
CloseThreadpoolCleanupGroup(group);
CloseThreadpool(pool);
}
}
}

#[derive(Clone)]
pub struct AsyncifyPool {
inner: Option<Arc<PoolEnv>>,
}

impl Debug for AsyncifyPool {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AsyncifyPool").finish_non_exhaustive()
}
}

impl AsyncifyPool {
pub fn new(thread_limit: usize, _recv_timeout: Duration) -> io::Result<Self> {
if thread_limit == 0 {
Ok(Self { inner: None })
} else {
let pool = syscall!(BOOL, CreateThreadpool(null()))?;
let group = syscall!(BOOL, CreateThreadpoolCleanupGroup())?;
let inner = TP_CALLBACK_ENVIRON_V3 {
Version: 3,
Pool: pool,
CleanupGroup: group,
CallbackPriority: TP_CALLBACK_PRIORITY_NORMAL,
Size: size_of::<TP_CALLBACK_ENVIRON_V3>() as u32,
..Default::default()
};
unsafe {
SetThreadpoolThreadMaximum(pool, thread_limit as _);
}
Ok(Self {
inner: Some(Arc::new(PoolEnv(inner))),
})
}
}

pub fn dispatch<D: Dispatchable>(&self, f: D) -> Result<(), DispatchError<D>> {
unsafe extern "system" fn callback<F: Dispatchable>(
_: PTP_CALLBACK_INSTANCE,
callback: *mut std::ffi::c_void,
) {
unsafe {
Box::from_raw(callback as *mut F).run();
}
}

if let Some(inner) = &self.inner {
let f = Box::new(f);
let ptr = Box::into_raw(f);
let res = syscall!(
BOOL,
TrySubmitThreadpoolCallback(
Some(callback::<D>),
ptr.cast(),
inner.as_ref().as_ref(),
)
);
match res {
Ok(_) => Ok(()),
Err(_) => Err(DispatchError(*unsafe { Box::from_raw(ptr) })),
}
} else {
panic!("the thread pool is needed but no worker thread is running");
}
}

pub fn as_ptr(&self) -> *const TP_CALLBACK_ENVIRON_V3 {
self.inner
.as_deref()
.map(|e| e.as_ref() as *const _)
.unwrap_or(null())
}
}
8 changes: 5 additions & 3 deletions compio-driver/src/iocp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ impl Driver {
Ok(Self {
port,
waits: HashMap::default(),
pool: builder.create_or_get_thread_pool(),
pool: builder.create_or_get_thread_pool()?,
notify_overlapped: Arc::new(Overlapped::new(driver)),
})
}
Expand Down Expand Up @@ -350,8 +350,10 @@ impl Driver {
}
},
OpType::Event(e) => {
self.waits
.insert(user_data, wait::Wait::new(&self.port, e, op)?);
self.waits.insert(
user_data,
wait::Wait::new(&self.port, e, op, self.pool.as_ptr())?,
);
Poll::Pending
}
}
Expand Down
16 changes: 12 additions & 4 deletions compio-driver/src/iocp/wait/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ use std::{
ptr::null_mut,
};

use windows_sys::Win32::Foundation::{
GENERIC_READ, GENERIC_WRITE, HANDLE, NTSTATUS, RtlNtStatusToDosError, STATUS_PENDING,
STATUS_SUCCESS,
use windows_sys::Win32::{
Foundation::{
GENERIC_READ, GENERIC_WRITE, HANDLE, NTSTATUS, RtlNtStatusToDosError, STATUS_PENDING,
STATUS_SUCCESS,
},
System::Threading::TP_CALLBACK_ENVIRON_V3,
};

use crate::{Key, OpCode, RawFd, sys::cp};
Expand Down Expand Up @@ -52,7 +55,12 @@ fn check_status(status: NTSTATUS) -> io::Result<()> {
}

impl Wait {
pub fn new(port: &cp::Port, event: RawFd, op: &mut Key<dyn OpCode>) -> io::Result<Self> {
pub fn new(
port: &cp::Port,
event: RawFd,
op: &mut Key<dyn OpCode>,
_env: *const TP_CALLBACK_ENVIRON_V3,
) -> io::Result<Self> {
let mut handle = null_mut();
check_status(unsafe {
NtCreateWaitCompletionPacket(&mut handle, GENERIC_READ | GENERIC_WRITE, null_mut())
Expand Down
11 changes: 8 additions & 3 deletions compio-driver/src/iocp/wait/thread_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use windows_sys::Win32::{
Foundation::{ERROR_IO_PENDING, ERROR_TIMEOUT, WAIT_OBJECT_0, WAIT_TIMEOUT},
System::Threading::{
CloseThreadpoolWait, CreateThreadpoolWait, PTP_CALLBACK_INSTANCE, PTP_WAIT,
SetThreadpoolWait, WaitForThreadpoolWaitCallbacks,
SetThreadpoolWait, TP_CALLBACK_ENVIRON_V3, WaitForThreadpoolWaitCallbacks,
},
};

Expand All @@ -22,7 +22,12 @@ pub struct Wait {
}

impl Wait {
pub fn new(port: &cp::Port, event: RawFd, op: &mut Key<dyn OpCode>) -> io::Result<Self> {
pub fn new(
port: &cp::Port,
event: RawFd,
op: &mut Key<dyn OpCode>,
env: *const TP_CALLBACK_ENVIRON_V3,
) -> io::Result<Self> {
let port = port.handle();
let mut context = Box::new(WinThreadpoolWaitContext {
port,
Expand All @@ -33,7 +38,7 @@ impl Wait {
CreateThreadpoolWait(
Some(Self::wait_callback),
(&mut *context) as *mut WinThreadpoolWaitContext as _,
null()
env
)
)?;
unsafe {
Expand Down
2 changes: 1 addition & 1 deletion compio-driver/src/iour/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ impl Driver {
Ok(Self {
inner,
notifier,
pool: builder.create_or_get_thread_pool(),
pool: builder.create_or_get_thread_pool()?,
pool_completed: Arc::new(SegQueue::new()),
buffer_group_ids: Slab::new(),
need_push_notifier: true,
Expand Down
Loading
Loading