|
| 1 | +//! Linux-specific types for signal handling. |
| 2 | +
|
| 3 | +use std::{ |
| 4 | + cell::RefCell, collections::HashMap, io, mem::MaybeUninit, os::fd::FromRawFd, ptr::null_mut, |
| 5 | + thread_local, |
| 6 | +}; |
| 7 | + |
| 8 | +use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, SetBufInit}; |
| 9 | +use compio_driver::{op::Recv, syscall, OwnedFd, SharedFd}; |
| 10 | + |
| 11 | +thread_local! { |
| 12 | + static REG_MAP: RefCell<HashMap<i32, usize>> = RefCell::new(HashMap::new()); |
| 13 | +} |
| 14 | + |
| 15 | +fn sigset(sig: i32) -> io::Result<libc::sigset_t> { |
| 16 | + let mut set: MaybeUninit<libc::sigset_t> = MaybeUninit::uninit(); |
| 17 | + syscall!(libc::sigemptyset(set.as_mut_ptr()))?; |
| 18 | + syscall!(libc::sigaddset(set.as_mut_ptr(), sig))?; |
| 19 | + // SAFETY: sigemptyset initializes the set. |
| 20 | + Ok(unsafe { set.assume_init() }) |
| 21 | +} |
| 22 | + |
| 23 | +fn register_signal(sig: i32) -> io::Result<libc::sigset_t> { |
| 24 | + REG_MAP.with_borrow_mut(|map| { |
| 25 | + let count = map.entry(sig).or_default(); |
| 26 | + let set = sigset(sig)?; |
| 27 | + if *count == 0 { |
| 28 | + syscall!(libc::pthread_sigmask(libc::SIG_BLOCK, &set, null_mut()))?; |
| 29 | + } |
| 30 | + *count += 1; |
| 31 | + Ok(set) |
| 32 | + }) |
| 33 | +} |
| 34 | + |
| 35 | +fn unregister_signal(sig: i32) -> io::Result<libc::sigset_t> { |
| 36 | + REG_MAP.with_borrow_mut(|map| { |
| 37 | + let count = map.entry(sig).or_default(); |
| 38 | + if *count > 0 { |
| 39 | + *count -= 1; |
| 40 | + } |
| 41 | + let set = sigset(sig)?; |
| 42 | + if *count == 0 { |
| 43 | + syscall!(libc::pthread_sigmask(libc::SIG_UNBLOCK, &set, null_mut()))?; |
| 44 | + } |
| 45 | + Ok(set) |
| 46 | + }) |
| 47 | +} |
| 48 | + |
| 49 | +/// Represents a listener to unix signal event. |
| 50 | +#[derive(Debug)] |
| 51 | +struct SignalFd { |
| 52 | + fd: SharedFd<OwnedFd>, |
| 53 | + sig: i32, |
| 54 | +} |
| 55 | + |
| 56 | +impl SignalFd { |
| 57 | + fn new(sig: i32) -> io::Result<Self> { |
| 58 | + let set = register_signal(sig)?; |
| 59 | + let mut flag = libc::SFD_CLOEXEC; |
| 60 | + if cfg!(not(feature = "io-uring")) { |
| 61 | + flag |= libc::SFD_NONBLOCK; |
| 62 | + } |
| 63 | + let fd = syscall!(libc::signalfd(-1, &set, flag))?; |
| 64 | + let fd = unsafe { OwnedFd::from_raw_fd(fd) }; |
| 65 | + Ok(Self { |
| 66 | + fd: SharedFd::new(fd), |
| 67 | + sig, |
| 68 | + }) |
| 69 | + } |
| 70 | + |
| 71 | + async fn wait(self) -> io::Result<()> { |
| 72 | + const INFO_SIZE: usize = std::mem::size_of::<libc::signalfd_siginfo>(); |
| 73 | + |
| 74 | + struct SignalInfo(MaybeUninit<libc::signalfd_siginfo>); |
| 75 | + |
| 76 | + unsafe impl IoBuf for SignalInfo { |
| 77 | + fn as_buf_ptr(&self) -> *const u8 { |
| 78 | + self.0.as_ptr().cast() |
| 79 | + } |
| 80 | + |
| 81 | + fn buf_len(&self) -> usize { |
| 82 | + 0 |
| 83 | + } |
| 84 | + |
| 85 | + fn buf_capacity(&self) -> usize { |
| 86 | + INFO_SIZE |
| 87 | + } |
| 88 | + } |
| 89 | + |
| 90 | + unsafe impl IoBufMut for SignalInfo { |
| 91 | + fn as_buf_mut_ptr(&mut self) -> *mut u8 { |
| 92 | + self.0.as_mut_ptr().cast() |
| 93 | + } |
| 94 | + } |
| 95 | + |
| 96 | + impl SetBufInit for SignalInfo { |
| 97 | + unsafe fn set_buf_init(&mut self, len: usize) { |
| 98 | + debug_assert!(len <= INFO_SIZE) |
| 99 | + } |
| 100 | + } |
| 101 | + |
| 102 | + let info = SignalInfo(MaybeUninit::<libc::signalfd_siginfo>::uninit()); |
| 103 | + let op = Recv::new(self.fd.clone(), info); |
| 104 | + let BufResult(res, op) = compio_runtime::submit(op).await; |
| 105 | + let len = res?; |
| 106 | + debug_assert_eq!(len, INFO_SIZE); |
| 107 | + let info = op.into_inner(); |
| 108 | + let info = unsafe { info.0.assume_init() }; |
| 109 | + debug_assert_eq!(info.ssi_signo, self.sig as u32); |
| 110 | + Ok(()) |
| 111 | + } |
| 112 | +} |
| 113 | + |
| 114 | +impl Drop for SignalFd { |
| 115 | + fn drop(&mut self) { |
| 116 | + unregister_signal(self.sig).ok(); |
| 117 | + } |
| 118 | +} |
| 119 | + |
| 120 | +/// Creates a new listener which will receive notifications when the current |
| 121 | +/// process receives the specified signal. |
| 122 | +/// |
| 123 | +/// It sets the signal mask of the current thread. |
| 124 | +pub async fn signal(sig: i32) -> io::Result<()> { |
| 125 | + let fd = SignalFd::new(sig)?; |
| 126 | + fd.wait().await?; |
| 127 | + Ok(()) |
| 128 | +} |
0 commit comments