diff --git a/Cargo.lock b/Cargo.lock index 6c02d89..430aa4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -825,6 +825,7 @@ version = "0.2.0" dependencies = [ "libc", "mctp", + "smol", ] [[package]] diff --git a/mctp-linux/Cargo.toml b/mctp-linux/Cargo.toml index 57b8bb7..1813f69 100644 --- a/mctp-linux/Cargo.toml +++ b/mctp-linux/Cargo.toml @@ -10,6 +10,7 @@ categories = ["network-programming", "embedded", "hardware-support", "os::linux- [dependencies] libc = "0.2" mctp = { workspace = true, features = ["std"] } +smol = { version = "2.0" } [[example]] name = "mctp-req" diff --git a/mctp-linux/examples/async-req.rs b/mctp-linux/examples/async-req.rs new file mode 100644 index 0000000..d66efe1 --- /dev/null +++ b/mctp-linux/examples/async-req.rs @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +/* + * Simple MCTP example using Linux sockets in async mode. + * + * Copyright (c) 2025 Code Construct + */ + +use mctp::{AsyncReqChannel, Eid, MCTP_TYPE_CONTROL}; +use mctp_linux::MctpLinuxAsyncReq; + +fn main() -> std::io::Result<()> { + const EID: Eid = Eid(8); + + let mut ep = MctpLinuxAsyncReq::new(EID, None)?; + + let tx_buf = vec![0x02u8]; + let mut rx_buf = vec![0u8; 16]; + + let (typ, ic, rx_buf) = smol::block_on(async { + ep.send(MCTP_TYPE_CONTROL, &tx_buf).await?; + ep.recv(&mut rx_buf).await + })?; + + println!("response type {typ}, ic {ic:?}: {rx_buf:x?}"); + + Ok(()) +} diff --git a/mctp-linux/src/lib.rs b/mctp-linux/src/lib.rs index c642975..3a64275 100644 --- a/mctp-linux/src/lib.rs +++ b/mctp-linux/src/lib.rs @@ -40,9 +40,10 @@ //! socket model. use core::mem; +use smol::Async; use std::fmt; use std::io::Error; -use std::os::unix::io::RawFd; +use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd}; use std::time::Duration; use mctp::{ @@ -139,13 +140,7 @@ fn last_os_error() -> mctp::Error { } /// MCTP socket object. -pub struct MctpSocket(RawFd); - -impl Drop for MctpSocket { - fn drop(&mut self) { - unsafe { libc::close(self.0) }; - } -} +pub struct MctpSocket(OwnedFd); impl MctpSocket { /// Create a new MCTP socket. This can then be used for send/receive @@ -161,56 +156,78 @@ impl MctpSocket { if rc < 0 { return Err(last_os_error()); } - Ok(MctpSocket(rc)) + // safety: the fd is valid, and we have exclusive ownership + let fd = unsafe { OwnedFd::from_raw_fd(rc) }; + Ok(MctpSocket(fd)) } - /// Blocking receive from a socket, into `buf`, returning a length - /// and source address - /// - /// Essentially a wrapper around [libc::recvfrom], using MCTP-specific - /// addressing. - pub fn recvfrom(&self, buf: &mut [u8]) -> Result<(usize, MctpSockAddr)> { + // Inner recvfrom, returning an io::Error on failure. This can be + // used with async wrappers. + fn io_recvfrom( + &self, + buf: &mut [u8], + ) -> std::io::Result<(usize, MctpSockAddr)> { let mut addr = MctpSockAddr::zero(); let (addr_ptr, mut addr_len) = addr.as_raw_mut(); let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void; let buf_len = buf.len() as libc::size_t; + let fd = self.as_raw_fd(); let rc = unsafe { - libc::recvfrom(self.0, buf_ptr, buf_len, 0, addr_ptr, &mut addr_len) + libc::recvfrom(fd, buf_ptr, buf_len, 0, addr_ptr, &mut addr_len) }; if rc < 0 { - Err(last_os_error()) + Err(Error::last_os_error()) } else { Ok((rc as usize, addr)) } } - /// Blocking send to a socket, given a buffer and address, returning - /// the number of bytes sent. + /// Blocking receive from a socket, into `buf`, returning a length + /// and source address /// - /// Essentially a wrapper around [libc::sendto]. - pub fn sendto(&self, buf: &[u8], addr: &MctpSockAddr) -> Result { + /// Essentially a wrapper around [libc::recvfrom], using MCTP-specific + /// addressing. + pub fn recvfrom(&self, buf: &mut [u8]) -> Result<(usize, MctpSockAddr)> { + self.io_recvfrom(buf).map_err(mctp::Error::Io) + } + + fn io_sendto( + &self, + buf: &[u8], + addr: &MctpSockAddr, + ) -> std::io::Result { let (addr_ptr, addr_len) = addr.as_raw(); let buf_ptr = buf.as_ptr() as *const libc::c_void; let buf_len = buf.len() as libc::size_t; + let fd = self.as_raw_fd(); let rc = unsafe { - libc::sendto(self.0, buf_ptr, buf_len, 0, addr_ptr, addr_len) + libc::sendto(fd, buf_ptr, buf_len, 0, addr_ptr, addr_len) }; if rc < 0 { - Err(last_os_error()) + Err(Error::last_os_error()) } else { Ok(rc as usize) } } + /// Blocking send to a socket, given a buffer and address, returning + /// the number of bytes sent. + /// + /// Essentially a wrapper around [libc::sendto]. + pub fn sendto(&self, buf: &[u8], addr: &MctpSockAddr) -> Result { + self.io_sendto(buf, addr).map_err(mctp::Error::Io) + } + /// Bind the socket to a local address. pub fn bind(&self, addr: &MctpSockAddr) -> Result<()> { let (addr_ptr, addr_len) = addr.as_raw(); + let fd = self.as_raw_fd(); - let rc = unsafe { libc::bind(self.0, addr_ptr, addr_len) }; + let rc = unsafe { libc::bind(fd, addr_ptr, addr_len) }; if rc < 0 { Err(last_os_error()) @@ -231,9 +248,10 @@ impl MctpSocket { tv_sec: dur.as_secs() as libc::time_t, tv_usec: dur.subsec_micros() as libc::suseconds_t, }; + let fd = self.as_raw_fd(); let rc = unsafe { libc::setsockopt( - self.0, + fd, libc::SOL_SOCKET, libc::SO_RCVTIMEO, (&tv as *const libc::timeval) as *const libc::c_void, @@ -260,9 +278,10 @@ impl MctpSocket { let mut tv = std::mem::MaybeUninit::::uninit(); let mut tv_len = std::mem::size_of::() as libc::socklen_t; + let fd = self.as_raw_fd(); let rc = unsafe { libc::getsockopt( - self.0, + fd, libc::SOL_SOCKET, libc::SO_RCVTIMEO, tv.as_mut_ptr() as *mut libc::c_void, @@ -293,7 +312,58 @@ impl MctpSocket { impl std::os::fd::AsRawFd for MctpSocket { fn as_raw_fd(&self) -> RawFd { + self.0.as_raw_fd() + } +} + +impl AsFd for MctpSocket { + fn as_fd(&self) -> BorrowedFd<'_> { + self.0.as_fd() + } +} + +/// MCTP socket for async use +pub struct MctpSocketAsync(Async); + +impl MctpSocketAsync { + /// Create a new async MCTP socket + pub fn new() -> Result { + let sock = MctpSocket::new()?; + let sock = Async::new(sock).map_err(mctp::Error::Io)?; + + Ok(Self(sock)) + } + + /// Bind the socket to a local address. + pub fn bind(&self, addr: &MctpSockAddr) -> Result<()> { + self.0.as_ref().bind(addr) + } + + /// Receive a message from this socket + /// + /// Returns the length of buffer read, and the peer address. + pub async fn recvfrom( + &self, + buf: &mut [u8], + ) -> Result<(usize, MctpSockAddr)> { + self.0 + .read_with(|io| io.io_recvfrom(buf)) + .await + .map_err(mctp::Error::Io) + } + + /// Send a message to a given address + /// + /// Returns the number of bytes sent + pub async fn sendto( + &self, + buf: &[u8], + addr: &MctpSockAddr, + ) -> Result { self.0 + .write_with(|io| io.io_sendto(buf, addr)) + .await + .map_err(mctp::Error::Io) } } @@ -382,6 +452,71 @@ impl mctp::ReqChannel for MctpLinuxReq { } } +/// Encapsulation of a remote endpoint: a socket and an Endpoint ID. +pub struct MctpLinuxAsyncReq { + eid: Eid, + net: u32, + sock: MctpSocketAsync, + sent: bool, +} + +impl MctpLinuxAsyncReq { + /// Create a new asynchronous request channel. + pub fn new(eid: Eid, net: Option) -> Result { + let net = net.unwrap_or(MCTP_NET_ANY); + Ok(Self { + eid, + net, + sock: MctpSocketAsync::new()?, + sent: false, + }) + } +} + +impl mctp::AsyncReqChannel for MctpLinuxAsyncReq { + fn remote_eid(&self) -> Eid { + self.eid + } + + async fn send_vectored( + &mut self, + typ: MsgType, + ic: MsgIC, + bufs: &[&[u8]], + ) -> Result<()> { + let typ_ic = mctp::encode_type_ic(typ, ic); + let addr = MctpSockAddr::new( + self.eid.0, + self.net, + typ_ic, + mctp::MCTP_TAG_OWNER, + ); + let concat = bufs + .iter() + .flat_map(|b| b.iter().cloned()) + .collect::>(); + self.sock.sendto(&concat, &addr).await?; + self.sent = true; + Ok(()) + } + + async fn recv<'f>( + &mut self, + buf: &'f mut [u8], + ) -> Result<(MsgType, MsgIC, &'f mut [u8])> { + if !self.sent { + return Err(mctp::Error::BadArgument); + } + let (sz, addr) = self.sock.recvfrom(buf).await?; + let src = Eid(addr.0.smctp_addr); + let (typ, ic) = mctp::decode_type_ic(addr.0.smctp_type); + if src != self.eid { + return Err(mctp::Error::Other); + } + Ok((typ, ic, &mut buf[..sz])) + } +} + /// A Listener for Linux MCTP messages pub struct MctpLinuxListener { sock: MctpSocket, @@ -452,6 +587,62 @@ impl mctp::Listener for MctpLinuxListener { } } +/// An MCTP Listener for asynchronous IO +pub struct MctpLinuxAsyncListener { + sock: MctpSocketAsync, + net: u32, + typ: MsgType, +} + +impl MctpLinuxAsyncListener { + /// Create a new `MctpLinuxAsyncListener`. + /// + /// This will listen for MCTP message type `typ`, on an optional + /// Linux network `net`. `None` network defaults to `MCTP_NET_ANY`. + pub fn new(typ: MsgType, net: Option) -> Result { + let sock = MctpSocketAsync::new()?; + // Linux requires MCTP_ADDR_ANY for binds. + let net = net.unwrap_or(MCTP_NET_ANY); + let addr = MctpSockAddr::new( + MCTP_ADDR_ANY.0, + net, + typ.0, + mctp::MCTP_TAG_OWNER, + ); + sock.bind(&addr)?; + Ok(Self { sock, net, typ }) + } +} + +impl mctp::AsyncListener for MctpLinuxAsyncListener { + type RespChannel<'a> = MctpLinuxAsyncResp<'a>; + + async fn recv<'f>( + &mut self, + buf: &'f mut [u8], + ) -> Result<(MsgType, MsgIC, &'f mut [u8], Self::RespChannel<'_>)> { + let (sz, addr) = self.sock.recvfrom(buf).await?; + let src = Eid(addr.0.smctp_addr); + let (typ, ic) = mctp::decode_type_ic(addr.0.smctp_type); + let tag = tag_from_smctp(addr.0.smctp_tag); + if let Tag::Unowned(_) = tag { + // bind() shouldn't give non-owned packets. + return Err(mctp::Error::InternalError); + } + if typ != self.typ { + // bind() should return the requested type + return Err(mctp::Error::InternalError); + } + let ep = MctpLinuxAsyncResp { + eid: src, + tv: tag.tag(), + listener: self, + typ, + }; + Ok((typ, ic, &mut buf[..sz], ep)) + } +} + /// A Linux MCTP Listener response channel pub struct MctpLinuxResp<'a> { eid: Eid, @@ -491,6 +682,42 @@ impl mctp::RespChannel for MctpLinuxResp<'_> { } } +/// A Linux MCTP Async Listener response channel +pub struct MctpLinuxAsyncResp<'l> { + eid: Eid, + tv: TagValue, + listener: &'l MctpLinuxAsyncListener, + typ: MsgType, +} + +impl<'l> mctp::AsyncRespChannel for MctpLinuxAsyncResp<'l> { + type ReqChannel<'a> + = MctpLinuxAsyncReq + where + Self: 'a; + + async fn send_vectored(&mut self, ic: MsgIC, bufs: &[&[u8]]) -> Result<()> { + let typ_ic = mctp::encode_type_ic(self.typ, ic); + let tag = tag_to_smctp(&Tag::Unowned(self.tv)); + let addr = + MctpSockAddr::new(self.eid.0, self.listener.net, typ_ic, tag); + let concat = bufs + .iter() + .flat_map(|b| b.iter().cloned()) + .collect::>(); + self.listener.sock.sendto(&concat, &addr).await?; + Ok(()) + } + + fn remote_eid(&self) -> Eid { + self.eid + } + + fn req_channel(&self) -> Result> { + MctpLinuxAsyncReq::new(self.eid, Some(self.listener.net)) + } +} + /// Helper for applications taking an MCTP address as an argument, /// configuration, etc. /// @@ -570,4 +797,17 @@ impl MctpAddr { pub fn create_listener(&self, typ: MsgType) -> Result { MctpLinuxListener::new(typ, self.net) } + + /// Create an `MctpLinuxAsyncReq` using the net & eid values in this address. + pub fn create_req_async(&self) -> Result { + MctpLinuxAsyncReq::new(self.eid, self.net) + } + + /// Create an `MctpLinuxAsyncListener`. + pub fn create_listener_async( + &self, + typ: MsgType, + ) -> Result { + MctpLinuxAsyncListener::new(typ, self.net) + } }