diff --git a/Cargo.lock b/Cargo.lock index 68714b97269..c88ca2d2ce1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1265,12 +1265,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "smallvec" -version = "1.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" - [[package]] name = "snapshot-editor" version = "1.10.0-dev" @@ -1575,7 +1569,6 @@ dependencies = [ "serde", "serde_json", "slab", - "smallvec", "thiserror", "timerfd", "userfaultfd", diff --git a/src/vmm/Cargo.toml b/src/vmm/Cargo.toml index e2b7e2a6766..52927579756 100644 --- a/src/vmm/Cargo.toml +++ b/src/vmm/Cargo.toml @@ -34,7 +34,6 @@ semver = { version = "1.0.23", features = ["serde"] } serde = { version = "1.0.210", features = ["derive", "rc"] } serde_json = "1.0.128" slab = "0.4.7" -smallvec = "1.11.2" thiserror = "1.0.64" timerfd = "1.5.0" userfaultfd = "0.8.1" diff --git a/src/vmm/src/devices/virtio/iovec.rs b/src/vmm/src/devices/virtio/iovec.rs index 3acde02fc05..1f6680aa8e7 100644 --- a/src/vmm/src/devices/virtio/iovec.rs +++ b/src/vmm/src/devices/virtio/iovec.rs @@ -4,7 +4,6 @@ use std::io::ErrorKind; use libc::{c_void, iovec, size_t}; -use smallvec::SmallVec; use vm_memory::bitmap::Bitmap; use vm_memory::{ GuestMemory, GuestMemoryError, ReadVolatile, VolatileMemoryError, VolatileSlice, WriteVolatile, @@ -25,14 +24,6 @@ pub enum IoVecError { GuestMemory(#[from] GuestMemoryError), } -// Using SmallVec in the kani proofs causes kani to use unbounded amounts of memory -// during post-processing, and then crash. -// TODO: remove new-type once kani performance regression are resolved -#[cfg(kani)] -type IoVecVec = Vec; -#[cfg(not(kani))] -type IoVecVec = SmallVec<[iovec; 4]>; - /// This is essentially a wrapper of a `Vec` which can be passed to `libc::writev`. /// /// It describes a buffer passed to us by the guest that is scattered across multiple @@ -41,7 +32,7 @@ type IoVecVec = SmallVec<[iovec; 4]>; #[derive(Debug, Default)] pub struct IoVecBuffer { // container of the memory regions included in this IO vector - vecs: IoVecVec, + vecs: Vec, // Total length of the IoVecBuffer len: u32, } @@ -219,14 +210,18 @@ impl IoVecBuffer { /// It describes a write-only buffer passed to us by the guest that is scattered across multiple /// memory regions. Additionally, this wrapper provides methods that allow reading arbitrary ranges /// of data from that buffer. -#[derive(Debug, Default, Clone)] +#[derive(Debug, Default)] pub struct IoVecBufferMut { // container of the memory regions included in this IO vector - vecs: IoVecVec, + vecs: Vec, // Total length of the IoVecBufferMut len: u32, } +// SAFETY: `IoVecBufferMut` doesn't allow for interior mutability and no shared ownership is +// possible as it doesn't implement clone +unsafe impl Send for IoVecBufferMut {} + impl IoVecBufferMut { /// Create an `IoVecBuffer` from a `DescriptorChain` /// @@ -402,8 +397,7 @@ mod tests { vecs: vec![iovec { iov_base: buf.as_ptr() as *mut c_void, iov_len: buf.len(), - }] - .into(), + }], len: buf.len().try_into().unwrap(), } } @@ -433,8 +427,7 @@ mod tests { vecs: vec![iovec { iov_base: buf.as_mut_ptr().cast::(), iov_len: buf.len(), - }] - .into(), + }], len: buf.len().try_into().unwrap(), } } @@ -686,7 +679,7 @@ mod verification { use vm_memory::bitmap::BitmapSlice; use vm_memory::VolatileSlice; - use super::{IoVecBuffer, IoVecBufferMut, IoVecVec}; + use super::{IoVecBuffer, IoVecBufferMut}; // Maximum memory size to use for our buffers. 
For the time being 1KB. const GUEST_MEMORY_SIZE: usize = 1 << 10; @@ -698,7 +691,7 @@ mod verification { // >= 1. const MAX_DESC_LENGTH: usize = 4; - fn create_iovecs(mem: *mut u8, size: usize, nr_descs: usize) -> (IoVecVec, u32) { + fn create_iovecs(mem: *mut u8, size: usize, nr_descs: usize) -> (Vec, u32) { let mut vecs: Vec = Vec::with_capacity(nr_descs); let mut len = 0u32; for _ in 0..nr_descs { diff --git a/src/vmm/src/devices/virtio/vsock/csm/connection.rs b/src/vmm/src/devices/virtio/vsock/csm/connection.rs index 0307911cba9..e49ac56a39d 100644 --- a/src/vmm/src/devices/virtio/vsock/csm/connection.rs +++ b/src/vmm/src/devices/virtio/vsock/csm/connection.rs @@ -88,11 +88,11 @@ use vm_memory::GuestMemoryError; use vmm_sys_util::epoll::EventSet; use super::super::defs::uapi; -use super::super::packet::VsockPacket; use super::super::{VsockChannel, VsockEpollListener, VsockError}; use super::txbuf::TxBuf; use super::{defs, ConnState, PendingRx, PendingRxSet, VsockCsmError}; use crate::devices::virtio::vsock::metrics::METRICS; +use crate::devices::virtio::vsock::packet::{VsockPacketHeader, VsockPacketRx, VsockPacketTx}; use crate::logger::IncMetric; use crate::utils::wrap_usize_to_u32; @@ -160,16 +160,16 @@ where /// - `Err(VsockError::NoData)`: there was no data available with which to fill in the packet; /// - `Err(VsockError::PktBufMissing)`: the packet would've been filled in with data, but it is /// missing the data buffer. - fn recv_pkt(&mut self, pkt: &mut VsockPacket) -> Result<(), VsockError> { + fn recv_pkt(&mut self, pkt: &mut VsockPacketRx) -> Result<(), VsockError> { // Perform some generic initialization that is the same for any packet operation (e.g. // source, destination, credit, etc). - self.init_pkt(pkt); + self.init_pkt_hdr(&mut pkt.hdr); METRICS.rx_packets_count.inc(); // If forceful termination is pending, there's no point in checking for anything else. // It's dead, Jim. if self.pending_rx.remove(PendingRx::Rst) { - pkt.set_op(uapi::VSOCK_OP_RST); + pkt.hdr.set_op(uapi::VSOCK_OP_RST); return Ok(()); } @@ -177,7 +177,7 @@ where // in this packet. if self.pending_rx.remove(PendingRx::Response) { self.state = ConnState::Established; - pkt.set_op(uapi::VSOCK_OP_RESPONSE); + pkt.hdr.set_op(uapi::VSOCK_OP_RESPONSE); return Ok(()); } @@ -186,7 +186,7 @@ where if self.pending_rx.remove(PendingRx::Request) { self.expiry = Some(Instant::now() + Duration::from_millis(defs::CONN_REQUEST_TIMEOUT_MS)); - pkt.set_op(uapi::VSOCK_OP_REQUEST); + pkt.hdr.set_op(uapi::VSOCK_OP_REQUEST); return Ok(()); } @@ -201,7 +201,7 @@ where _ => { // Any other connection state is invalid at this point, and we need to kill it // with fire. - pkt.set_op(uapi::VSOCK_OP_RST); + pkt.hdr.set_op(uapi::VSOCK_OP_RST); return Ok(()); } } @@ -210,7 +210,7 @@ where // much bytey goodness? if self.need_credit_update_from_peer() { self.last_fwd_cnt_to_peer = self.fwd_cnt; - pkt.set_op(uapi::VSOCK_OP_CREDIT_REQUEST); + pkt.hdr.set_op(uapi::VSOCK_OP_CREDIT_REQUEST); return Ok(()); } @@ -229,7 +229,8 @@ where self.expiry = Some( Instant::now() + Duration::from_millis(defs::CONN_SHUTDOWN_TIMEOUT_MS), ); - pkt.set_op(uapi::VSOCK_OP_SHUTDOWN) + pkt.hdr + .set_op(uapi::VSOCK_OP_SHUTDOWN) .set_flag(uapi::VSOCK_FLAGS_SHUTDOWN_RCV) .set_flag(uapi::VSOCK_FLAGS_SHUTDOWN_SEND); } else { @@ -237,10 +238,10 @@ where // length of the read data. // Safe to unwrap because read_cnt is no more than max_len, which is bounded // by self.peer_avail_credit(), a u32 internally. 
- pkt.set_op(uapi::VSOCK_OP_RW).set_len(read_cnt); + pkt.hdr.set_op(uapi::VSOCK_OP_RW).set_len(read_cnt); METRICS.rx_bytes_count.add(read_cnt as u64); } - self.rx_cnt += Wrapping(pkt.len()); + self.rx_cnt += Wrapping(pkt.hdr.len()); self.last_fwd_cnt_to_peer = self.fwd_cnt; return Ok(()); } @@ -263,7 +264,7 @@ where "vsock: error reading from backing stream: lp={}, pp={}, err={:?}", self.local_port, self.peer_port, err ); - pkt.set_op(uapi::VSOCK_OP_RST); + pkt.hdr.set_op(uapi::VSOCK_OP_RST); self.last_fwd_cnt_to_peer = self.fwd_cnt; return Ok(()); } @@ -274,7 +275,7 @@ where // buffer on it if we really have nothing else to say, hence we check for this RX // indication last. if self.pending_rx.remove(PendingRx::CreditUpdate) && !self.has_pending_rx() { - pkt.set_op(uapi::VSOCK_OP_CREDIT_UPDATE); + pkt.hdr.set_op(uapi::VSOCK_OP_CREDIT_UPDATE); self.last_fwd_cnt_to_peer = self.fwd_cnt; return Ok(()); } @@ -291,10 +292,10 @@ where /// /// Returns: /// always `Ok(())`: the packet has been consumed; - fn send_pkt(&mut self, pkt: &VsockPacket) -> Result<(), VsockError> { + fn send_pkt(&mut self, pkt: &VsockPacketTx) -> Result<(), VsockError> { // Update the peer credit information. - self.peer_buf_alloc = pkt.buf_alloc(); - self.peer_fwd_cnt = Wrapping(pkt.fwd_cnt()); + self.peer_buf_alloc = pkt.hdr.buf_alloc(); + self.peer_fwd_cnt = Wrapping(pkt.hdr.fwd_cnt()); METRICS.tx_packets_count.inc(); match self.state { @@ -302,7 +303,7 @@ where // data to the host stream. Also works for a connection that has begun shutting // down, but the peer still has some data to send. ConnState::Established | ConnState::PeerClosed(_, false) - if pkt.op() == uapi::VSOCK_OP_RW => + if pkt.hdr.op() == uapi::VSOCK_OP_RW => { if pkt.buf_size() == 0 { info!( @@ -335,7 +336,7 @@ where // Next up: receiving a response / confirmation for a host-initiated connection. // We'll move to an Established state, and pass on the good news through the host // stream. - ConnState::LocalInit if pkt.op() == uapi::VSOCK_OP_RESPONSE => { + ConnState::LocalInit if pkt.hdr.op() == uapi::VSOCK_OP_RESPONSE => { self.expiry = None; self.state = ConnState::Established; } @@ -344,9 +345,9 @@ where // more to send nor receive, and we don't have to wait to drain our TX buffer, we // can schedule an RST packet (to terminate the connection on the next recv call). // Otherwise, we'll arm the kill timer. - ConnState::Established if pkt.op() == uapi::VSOCK_OP_SHUTDOWN => { - let recv_off = pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_RCV != 0; - let send_off = pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_SEND != 0; + ConnState::Established if pkt.hdr.op() == uapi::VSOCK_OP_SHUTDOWN => { + let recv_off = pkt.hdr.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_RCV != 0; + let send_off = pkt.hdr.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_SEND != 0; self.state = ConnState::PeerClosed(recv_off, send_off); if recv_off && send_off { if self.tx_buf.is_empty() { @@ -362,10 +363,10 @@ where // The peer wants to update a shutdown request, with more receive/send indications. // The same logic as above applies. 
ConnState::PeerClosed(ref mut recv_off, ref mut send_off) - if pkt.op() == uapi::VSOCK_OP_SHUTDOWN => + if pkt.hdr.op() == uapi::VSOCK_OP_SHUTDOWN => { - *recv_off = *recv_off || (pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_RCV != 0); - *send_off = *send_off || (pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_SEND != 0); + *recv_off = *recv_off || (pkt.hdr.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_RCV != 0); + *send_off = *send_off || (pkt.hdr.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_SEND != 0); if *recv_off && *send_off && self.tx_buf.is_empty() { self.pending_rx.insert(PendingRx::Rst); } @@ -374,7 +375,7 @@ where // A credit update from our peer is valid only in a state which allows data // transfer towards the peer. ConnState::Established | ConnState::PeerInit | ConnState::PeerClosed(false, _) - if pkt.op() == uapi::VSOCK_OP_CREDIT_UPDATE => + if pkt.hdr.op() == uapi::VSOCK_OP_CREDIT_UPDATE => { // Nothing to do here; we've already updated peer credit. } @@ -382,7 +383,7 @@ where // A credit request from our peer is valid only in a state which allows data // transfer from the peer. We'll respond with a credit update packet. ConnState::Established | ConnState::PeerInit | ConnState::PeerClosed(_, false) - if pkt.op() == uapi::VSOCK_OP_CREDIT_REQUEST => + if pkt.hdr.op() == uapi::VSOCK_OP_CREDIT_REQUEST => { self.pending_rx.insert(PendingRx::CreditUpdate); } @@ -390,8 +391,7 @@ where _ => { debug!( "vsock: dropping invalid TX pkt for connection: state={:?}, pkt.hdr={:?}", - self.state, - pkt.hdr() + self.state, pkt.hdr ); } }; @@ -603,8 +603,8 @@ where /// /// Raw data can either be sent straight to the host stream, or to our TX buffer, if the /// former fails. - fn send_bytes(&mut self, pkt: &VsockPacket) -> Result<(), VsockError> { - let len = pkt.len(); + fn send_bytes(&mut self, pkt: &VsockPacketTx) -> Result<(), VsockError> { + let len = pkt.hdr.len(); // If there is data in the TX buffer, that means we're already registered for EPOLLOUT // events on the underlying stream. Therefore, there's no point in attempting a write @@ -666,14 +666,14 @@ where } /// Prepare a packet header for transmission to our peer. - fn init_pkt<'a>(&self, pkt: &'a mut VsockPacket) -> &'a mut VsockPacket { - pkt.set_src_cid(self.local_cid) + fn init_pkt_hdr(&self, hdr: &mut VsockPacketHeader) { + hdr.set_src_cid(self.local_cid) .set_dst_cid(self.peer_cid) .set_src_port(self.local_port) .set_dst_port(self.peer_port) .set_type(uapi::VSOCK_TYPE_STREAM) .set_buf_alloc(defs::CONN_TX_BUF_SIZE) - .set_fwd_cnt(self.fwd_cnt.0) + .set_fwd_cnt(self.fwd_cnt.0); } } @@ -822,15 +822,15 @@ mod tests { } } - fn init_pkt(pkt: &mut VsockPacket, op: u16, len: u32) -> &mut VsockPacket { - pkt.set_src_cid(PEER_CID) + fn init_pkt_hdr(hdr: &mut VsockPacketHeader, op: u16, len: u32) { + hdr.set_src_cid(PEER_CID) .set_dst_cid(LOCAL_CID) .set_src_port(PEER_PORT) .set_dst_port(LOCAL_PORT) .set_type(uapi::VSOCK_TYPE_STREAM) .set_buf_alloc(PEER_BUF_ALLOC) .set_op(op) - .set_len(len) + .set_len(len); } // This is the connection state machine test context: a helper struct to provide CSM testing @@ -846,8 +846,8 @@ mod tests { struct CsmTestContext { _vsock_test_ctx: TestContext, // Two views of the same in-memory packet. 
rx-view for writing, tx-view for reading - rx_pkt: VsockPacket, - tx_pkt: VsockPacket, + rx_pkt: VsockPacketRx, + tx_pkt: VsockPacketTx, conn: VsockConnection, } @@ -860,16 +860,20 @@ mod tests { let vsock_test_ctx = TestContext::new(); let mut handler_ctx = vsock_test_ctx.create_event_handler_context(); let stream = TestStream::new(); - let mut rx_pkt = VsockPacket::from_rx_virtq_head( - &vsock_test_ctx.mem, - handler_ctx.device.queues[RXQ_INDEX].pop().unwrap(), - ) - .unwrap(); - let tx_pkt = VsockPacket::from_tx_virtq_head( - &vsock_test_ctx.mem, - handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), - ) - .unwrap(); + let mut rx_pkt = VsockPacketRx::default(); + rx_pkt + .parse( + &vsock_test_ctx.mem, + handler_ctx.device.queues[RXQ_INDEX].pop().unwrap(), + ) + .unwrap(); + let mut tx_pkt = VsockPacketTx::default(); + tx_pkt + .parse( + &vsock_test_ctx.mem, + handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), + ) + .unwrap(); let conn = match conn_state { ConnState::PeerInit => VsockConnection::::new_peer_init( stream, @@ -893,7 +897,7 @@ mod tests { ); assert!(conn.has_pending_rx()); conn.recv_pkt(&mut rx_pkt).unwrap(); - assert_eq!(rx_pkt.op(), uapi::VSOCK_OP_RESPONSE); + assert_eq!(rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE); conn } other => panic!("invalid ctx state: {:?}", other), @@ -935,11 +939,12 @@ mod tests { self.conn.notify(EventSet::OUT); } - fn init_tx_pkt(&mut self, op: u16, len: u32) -> &mut VsockPacket { - init_pkt(&mut self.tx_pkt, op, len) + fn init_tx_pkt(&mut self, op: u16, len: u32) -> &mut VsockPacketTx { + init_pkt_hdr(&mut self.tx_pkt.hdr, op, len); + &mut self.tx_pkt } - fn init_data_tx_pkt(&mut self, mut data: &[u8]) -> &VsockPacket { + fn init_data_tx_pkt(&mut self, mut data: &[u8]) -> &VsockPacketTx { assert!(data.len() <= self.tx_pkt.buf_size() as usize); self.init_tx_pkt(uapi::VSOCK_OP_RW, u32::try_from(data.len()).unwrap()); @@ -958,13 +963,13 @@ mod tests { ctx.recv(); // For peer-initiated requests, our connection should always yield a vsock reponse packet, // in order to establish the connection. - assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RESPONSE); - assert_eq!(ctx.rx_pkt.src_cid(), LOCAL_CID); - assert_eq!(ctx.rx_pkt.dst_cid(), PEER_CID); - assert_eq!(ctx.rx_pkt.src_port(), LOCAL_PORT); - assert_eq!(ctx.rx_pkt.dst_port(), PEER_PORT); - assert_eq!(ctx.rx_pkt.type_(), uapi::VSOCK_TYPE_STREAM); - assert_eq!(ctx.rx_pkt.len(), 0); + assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE); + assert_eq!(ctx.rx_pkt.hdr.src_cid(), LOCAL_CID); + assert_eq!(ctx.rx_pkt.hdr.dst_cid(), PEER_CID); + assert_eq!(ctx.rx_pkt.hdr.src_port(), LOCAL_PORT); + assert_eq!(ctx.rx_pkt.hdr.dst_port(), PEER_PORT); + assert_eq!(ctx.rx_pkt.hdr.type_(), uapi::VSOCK_TYPE_STREAM); + assert_eq!(ctx.rx_pkt.hdr.len(), 0); // After yielding the response packet, the connection should have transitioned to the // established state. assert_eq!(ctx.conn.state, ConnState::Established); @@ -979,7 +984,7 @@ mod tests { // armed. assert!(!ctx.conn.will_expire()); ctx.recv(); - assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_REQUEST); + assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_REQUEST); // Since the request might time-out, the kill timer should now be armed. 
assert!(ctx.conn.will_expire()); assert!(!ctx.conn.has_expired()); @@ -995,7 +1000,7 @@ mod tests { fn test_local_request_timeout() { let mut ctx = CsmTestContext::new(ConnState::LocalInit); ctx.recv(); - assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_REQUEST); + assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_REQUEST); assert!(ctx.conn.will_expire()); assert!(!ctx.conn.has_expired()); std::thread::sleep(std::time::Duration::from_millis( @@ -1012,14 +1017,15 @@ mod tests { assert_eq!(ctx.conn.as_raw_fd(), ctx.conn.stream.as_raw_fd()); ctx.notify_epollin(); ctx.recv(); - assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RW); - assert_eq!(ctx.rx_pkt.len() as usize, data.len()); + assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RW); + assert_eq!(ctx.rx_pkt.hdr.len() as usize, data.len()); let buf = test_utils::read_packet_data(&ctx.tx_pkt, 4); assert_eq!(&buf, data); // There's no more data in the stream, so `recv_pkt` should yield `VsockError::NoData`. - match ctx.conn.recv_pkt(&mut ctx.tx_pkt) { + // match ctx.conn.recv_pkt(&mut ctx.tx_pkt) { + match ctx.conn.recv_pkt(&mut ctx.rx_pkt) { Err(VsockError::NoData) => (), other => panic!("{:?}", other), } @@ -1028,7 +1034,7 @@ mod tests { ctx.conn.state = ConnState::LocalClosed; ctx.notify_epollin(); ctx.recv(); - assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RST); + assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST); } #[test] @@ -1042,9 +1048,9 @@ mod tests { // When the host-side stream is closed, we can neither send not receive any more data. // Therefore, the vsock shutdown packet that we'll deliver to the guest must contain both // the no-more-send and the no-more-recv indications. - assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_SHUTDOWN); - assert_ne!(ctx.rx_pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_SEND, 0); - assert_ne!(ctx.rx_pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_RCV, 0); + assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_SHUTDOWN); + assert_ne!(ctx.rx_pkt.hdr.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_SEND, 0); + assert_ne!(ctx.rx_pkt.hdr.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_RCV, 0); // The kill timer should now be armed. assert!(ctx.conn.will_expire()); @@ -1061,14 +1067,14 @@ mod tests { { let mut ctx = CsmTestContext::new_established(); - ctx.init_tx_pkt(uapi::VSOCK_OP_SHUTDOWN, 0) - .set_flags(uapi::VSOCK_FLAGS_SHUTDOWN_RCV); + let tx_pkt = ctx.init_tx_pkt(uapi::VSOCK_OP_SHUTDOWN, 0); + tx_pkt.hdr.set_flags(uapi::VSOCK_FLAGS_SHUTDOWN_RCV); ctx.send(); assert_eq!(ctx.conn.state, ConnState::PeerClosed(true, false)); // Attempting to reset the no-more-recv indication should not work // (we are only setting the no-more-send indication here). - ctx.tx_pkt.set_flags(uapi::VSOCK_FLAGS_SHUTDOWN_SEND); + ctx.tx_pkt.hdr.set_flags(uapi::VSOCK_FLAGS_SHUTDOWN_SEND); ctx.send(); assert_eq!(ctx.conn.state, ConnState::PeerClosed(true, true)); } @@ -1080,12 +1086,12 @@ mod tests { let data = &[1, 2, 3, 4]; let mut ctx = CsmTestContext::new_established(); ctx.set_stream(TestStream::new_with_read_buf(data)); - ctx.init_tx_pkt(uapi::VSOCK_OP_SHUTDOWN, 0) - .set_flags(uapi::VSOCK_FLAGS_SHUTDOWN_SEND); + let tx_pkt = ctx.init_tx_pkt(uapi::VSOCK_OP_SHUTDOWN, 0); + tx_pkt.hdr.set_flags(uapi::VSOCK_FLAGS_SHUTDOWN_SEND); ctx.send(); ctx.notify_epollin(); ctx.recv(); - assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RW); + assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RW); let buf = test_utils::read_packet_data(&ctx.tx_pkt, 4); assert_eq!(&buf, data); @@ -1101,8 +1107,8 @@ mod tests { // - attempting to read data from it should yield an RST packet. 
{ let mut ctx = CsmTestContext::new_established(); - ctx.init_tx_pkt(uapi::VSOCK_OP_SHUTDOWN, 0) - .set_flags(uapi::VSOCK_FLAGS_SHUTDOWN_RCV); + let tx_pkt = ctx.init_tx_pkt(uapi::VSOCK_OP_SHUTDOWN, 0); + tx_pkt.hdr.set_flags(uapi::VSOCK_FLAGS_SHUTDOWN_RCV); ctx.send(); let data = &[1, 2, 3, 4]; ctx.init_data_tx_pkt(data); @@ -1111,19 +1117,21 @@ mod tests { ctx.notify_epollin(); ctx.recv(); - assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RST); + assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST); } // Test case: setting both no-more-send and no-more-recv indications should have the // connection confirm termination (i.e. yield an RST). { let mut ctx = CsmTestContext::new_established(); - ctx.init_tx_pkt(uapi::VSOCK_OP_SHUTDOWN, 0) + let tx_pkt = ctx.init_tx_pkt(uapi::VSOCK_OP_SHUTDOWN, 0); + tx_pkt + .hdr .set_flags(uapi::VSOCK_FLAGS_SHUTDOWN_RCV | uapi::VSOCK_FLAGS_SHUTDOWN_SEND); ctx.send(); assert!(ctx.conn.has_pending_rx()); ctx.recv(); - assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RST); + assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST); } } @@ -1135,7 +1143,7 @@ mod tests { ctx.set_stream(stream); ctx.notify_epollin(); ctx.recv(); - assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RST); + assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST); } #[test] @@ -1144,7 +1152,7 @@ mod tests { ctx.set_peer_credit(0); ctx.notify_epollin(); ctx.recv(); - assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_CREDIT_REQUEST); + assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_CREDIT_REQUEST); } #[test] @@ -1154,9 +1162,9 @@ mod tests { ctx.send(); assert!(ctx.conn.has_pending_rx()); ctx.recv(); - assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_CREDIT_UPDATE); - assert_eq!(ctx.rx_pkt.buf_alloc(), csm_defs::CONN_TX_BUF_SIZE); - assert_eq!(ctx.rx_pkt.fwd_cnt(), ctx.conn.fwd_cnt.0); + assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_CREDIT_UPDATE); + assert_eq!(ctx.rx_pkt.hdr.buf_alloc(), csm_defs::CONN_TX_BUF_SIZE); + assert_eq!(ctx.rx_pkt.hdr.fwd_cnt(), ctx.conn.fwd_cnt.0); } #[test] @@ -1188,9 +1196,9 @@ mod tests { // The CSM should now have a credit update available for the peer. 
assert!(ctx.conn.has_pending_rx()); ctx.recv(); - assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_CREDIT_UPDATE); + assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_CREDIT_UPDATE); assert_eq!( - ctx.rx_pkt.fwd_cnt() as usize, + ctx.rx_pkt.hdr.fwd_cnt() as usize, initial_fwd_cnt as usize + data.len() * 2, ); assert_eq!(ctx.conn.fwd_cnt, ctx.conn.last_fwd_cnt_to_peer); @@ -1245,7 +1253,7 @@ mod tests { assert_eq!(ctx.conn.state, ConnState::Killed); assert!(ctx.conn.has_pending_rx()); ctx.recv(); - assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RST); + assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST); } // Test case: notifying a connection that it can flush its TX buffer to a broken stream @@ -1296,6 +1304,6 @@ mod tests { assert_eq!(ctx.conn.state, ConnState::Killed); assert!(ctx.conn.has_pending_rx()); ctx.recv(); - assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RST); + assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST); } } diff --git a/src/vmm/src/devices/virtio/vsock/device.rs b/src/vmm/src/devices/virtio/vsock/device.rs index afd1d8348d4..82b4495749c 100644 --- a/src/vmm/src/devices/virtio/vsock/device.rs +++ b/src/vmm/src/devices/virtio/vsock/device.rs @@ -27,7 +27,7 @@ use vmm_sys_util::eventfd::EventFd; use super::super::super::DeviceError; use super::defs::uapi; -use super::packet::{VsockPacket, VSOCK_PKT_HDR_SIZE}; +use super::packet::{VsockPacketRx, VsockPacketTx, VSOCK_PKT_HDR_SIZE}; use super::{defs, VsockBackend}; use crate::devices::virtio::device::{DeviceState, IrqTrigger, IrqType, VirtioDevice}; use crate::devices::virtio::queue::Queue as VirtQueue; @@ -68,6 +68,9 @@ pub struct Vsock { // continuous triggers from happening before the device gets activated. pub(crate) activate_evt: EventFd, pub(crate) device_state: DeviceState, + + pub rx_packet: VsockPacketRx, + pub tx_packet: VsockPacketTx, } // TODO: Detect / handle queue deadlock: @@ -101,6 +104,8 @@ where irq_trigger: IrqTrigger::new().map_err(VsockError::EventFd)?, activate_evt: EventFd::new(libc::EFD_NONBLOCK).map_err(VsockError::EventFd)?, device_state: DeviceState::Inactive, + rx_packet: VsockPacketRx::default(), + tx_packet: VsockPacketTx::default(), }) } @@ -147,14 +152,14 @@ where while let Some(head) = self.queues[RXQ_INDEX].pop() { let index = head.index; - let used_len = match VsockPacket::from_rx_virtq_head(mem, head) { - Ok(mut pkt) => { - if self.backend.recv_pkt(&mut pkt).is_ok() { - match pkt.commit_hdr() { + let used_len = match self.rx_packet.parse(mem, head) { + Ok(()) => { + if self.backend.recv_pkt(&mut self.rx_packet).is_ok() { + match self.rx_packet.commit_hdr() { // This addition cannot overflow, because packet length // is previously validated against `MAX_PKT_BUF_SIZE` // bound as part of `commit_hdr()`. 
- Ok(()) => VSOCK_PKT_HDR_SIZE + pkt.len(), + Ok(()) => VSOCK_PKT_HDR_SIZE + self.rx_packet.hdr.len(), Err(err) => { warn!( "vsock: Error writing packet header to guest memory: \ @@ -200,8 +205,9 @@ where while let Some(head) = self.queues[TXQ_INDEX].pop() { let index = head.index; - let pkt = match VsockPacket::from_tx_virtq_head(mem, head) { - Ok(pkt) => pkt, + // let pkt = match VsockPacket::from_tx_virtq_head(mem, head) { + match self.tx_packet.parse(mem, head) { + Ok(()) => (), Err(err) => { error!("vsock: error reading TX packet: {:?}", err); have_used = true; @@ -214,7 +220,7 @@ where } }; - if self.backend.send_pkt(&pkt).is_err() { + if self.backend.send_pkt(&self.tx_packet).is_err() { self.queues[TXQ_INDEX].undo_pop(); break; } diff --git a/src/vmm/src/devices/virtio/vsock/event_handler.rs b/src/vmm/src/devices/virtio/vsock/event_handler.rs index c35fa1ba77c..93263ecb970 100755 --- a/src/vmm/src/devices/virtio/vsock/event_handler.rs +++ b/src/vmm/src/devices/virtio/vsock/event_handler.rs @@ -439,7 +439,9 @@ mod tests { // If the descriptor chain is already declared invalid, there's no reason to assemble // a packet. if let Some(rx_desc) = ctx.device.queues[RXQ_INDEX].pop() { - VsockPacket::from_rx_virtq_head(&test_ctx.mem, rx_desc).unwrap_err(); + VsockPacketRx::default() + .parse(&test_ctx.mem, rx_desc) + .unwrap_err(); } } @@ -461,7 +463,9 @@ mod tests { ctx.guest_txvq.dtable[desc_idx].len.set(len); if let Some(tx_desc) = ctx.device.queues[TXQ_INDEX].pop() { - VsockPacket::from_tx_virtq_head(&test_ctx.mem, tx_desc).unwrap_err(); + VsockPacketTx::default() + .parse(&test_ctx.mem, tx_desc) + .unwrap_err(); } } } @@ -486,13 +490,17 @@ mod tests { { let mut ctx = test_ctx.create_event_handler_context(); let rx_desc = ctx.device.queues[RXQ_INDEX].pop().unwrap(); - VsockPacket::from_rx_virtq_head(&test_ctx.mem, rx_desc).unwrap(); + VsockPacketRx::default() + .parse(&test_ctx.mem, rx_desc) + .unwrap(); } { let mut ctx = test_ctx.create_event_handler_context(); let tx_desc = ctx.device.queues[TXQ_INDEX].pop().unwrap(); - VsockPacket::from_tx_virtq_head(&test_ctx.mem, tx_desc).unwrap(); + VsockPacketTx::default() + .parse(&test_ctx.mem, tx_desc) + .unwrap(); } // Let's check what happens when the header descriptor is right before the gap. diff --git a/src/vmm/src/devices/virtio/vsock/mod.rs b/src/vmm/src/devices/virtio/vsock/mod.rs index 7fdc86aed2e..364fe83124f 100644 --- a/src/vmm/src/devices/virtio/vsock/mod.rs +++ b/src/vmm/src/devices/virtio/vsock/mod.rs @@ -22,13 +22,13 @@ mod unix; use std::os::unix::io::AsRawFd; -use packet::VsockPacket; use vm_memory::GuestMemoryError; use vmm_sys_util::epoll::EventSet; pub use self::defs::uapi::VIRTIO_ID_VSOCK as TYPE_VSOCK; pub use self::defs::VSOCK_DEV_ID; pub use self::device::Vsock; +use self::packet::{VsockPacketRx, VsockPacketTx}; pub use self::unix::{VsockUnixBackend, VsockUnixBackendError}; use crate::devices::virtio::iovec::IoVecError; use crate::devices::virtio::persist::PersistError as VirtioStateError; @@ -174,10 +174,10 @@ pub trait VsockEpollListener: AsRawFd { /// - `send_pkt(&pkt)` will fetch data from `pkt`, and place it into the channel. pub trait VsockChannel { /// Read/receive an incoming packet from the channel. - fn recv_pkt(&mut self, pkt: &mut VsockPacket) -> Result<(), VsockError>; + fn recv_pkt(&mut self, pkt: &mut VsockPacketRx) -> Result<(), VsockError>; /// Write/send a packet through the channel. 
- fn send_pkt(&mut self, pkt: &VsockPacket) -> Result<(), VsockError>; + fn send_pkt(&mut self, pkt: &VsockPacketTx) -> Result<(), VsockError>; /// Checks whether there is pending incoming data inside the channel, meaning that a subsequent /// call to `recv_pkt()` won't fail. diff --git a/src/vmm/src/devices/virtio/vsock/packet.rs b/src/vmm/src/devices/virtio/vsock/packet.rs index d63dd41386e..4db46928121 100644 --- a/src/vmm/src/devices/virtio/vsock/packet.rs +++ b/src/vmm/src/devices/virtio/vsock/packet.rs @@ -77,40 +77,125 @@ pub struct VsockPacketHeader { fwd_cnt: u32, } +impl VsockPacketHeader { + pub fn src_cid(&self) -> u64 { + u64::from_le(self.src_cid) + } + + pub fn set_src_cid(&mut self, cid: u64) -> &mut Self { + self.src_cid = cid.to_le(); + self + } + + pub fn dst_cid(&self) -> u64 { + u64::from_le(self.dst_cid) + } + + pub fn set_dst_cid(&mut self, cid: u64) -> &mut Self { + self.dst_cid = cid.to_le(); + self + } + + pub fn src_port(&self) -> u32 { + u32::from_le(self.src_port) + } + + pub fn set_src_port(&mut self, port: u32) -> &mut Self { + self.src_port = port.to_le(); + self + } + + pub fn dst_port(&self) -> u32 { + u32::from_le(self.dst_port) + } + + pub fn set_dst_port(&mut self, port: u32) -> &mut Self { + self.dst_port = port.to_le(); + self + } + + pub fn len(&self) -> u32 { + u32::from_le(self.len) + } + + pub fn set_len(&mut self, len: u32) -> &mut Self { + self.len = len.to_le(); + self + } + + pub fn type_(&self) -> u16 { + u16::from_le(self.type_) + } + + pub fn set_type(&mut self, type_: u16) -> &mut Self { + self.type_ = type_.to_le(); + self + } + + pub fn op(&self) -> u16 { + u16::from_le(self.op) + } + + pub fn set_op(&mut self, op: u16) -> &mut Self { + self.op = op.to_le(); + self + } + + pub fn flags(&self) -> u32 { + u32::from_le(self.flags) + } + + pub fn set_flags(&mut self, flags: u32) -> &mut Self { + self.flags = flags.to_le(); + self + } + + pub fn set_flag(&mut self, flag: u32) -> &mut Self { + self.set_flags(self.flags() | flag); + self + } + + pub fn buf_alloc(&self) -> u32 { + u32::from_le(self.buf_alloc) + } + + pub fn set_buf_alloc(&mut self, buf_alloc: u32) -> &mut Self { + self.buf_alloc = buf_alloc.to_le(); + self + } + + pub fn fwd_cnt(&self) -> u32 { + u32::from_le(self.fwd_cnt) + } + + pub fn set_fwd_cnt(&mut self, fwd_cnt: u32) -> &mut Self { + self.fwd_cnt = fwd_cnt.to_le(); + self + } +} + /// The vsock packet header struct size (the struct is packed). pub const VSOCK_PKT_HDR_SIZE: u32 = 44; // SAFETY: `VsockPacketHeader` is a POD and contains no padding. unsafe impl ByteValued for VsockPacketHeader {} -/// Enum representing either a TX (e.g. read-only) or RX (e.g. write-only) buffer -/// -/// Read and write permissions are statically enforced by using the correct `IoVecBuffer[Mut]` -/// abstraction -#[derive(Debug)] -pub enum VsockPacketBuffer { - /// Buffer holds a read-only guest-to-host (TX) packet - Tx(IoVecBuffer), - /// Buffer holds a write-only host-to-guest (RX) packet - Rx(IoVecBufferMut), -} - -/// Struct describing a single vsock packet. -/// -/// Encapsulates the virtio descriptor chain containing the packet through the `IoVecBuffer[Mut]` -/// abstractions. -#[derive(Debug)] -pub struct VsockPacket { +// /// Struct describing a single vsock packet. +// /// +// /// Encapsulates the virtio descriptor chain containing the packet through the `IoVecBuffer[Mut]` +// /// abstractions. 
+#[derive(Debug, Default)] +pub struct VsockPacketTx { /// A copy of the vsock packet's 44-byte header, held in hypervisor memory /// to minimize the number of accesses to guest memory. Can be written back /// to geust memory using [`VsockPacket::commit_hdr`] (only for RX buffers). - hdr: VsockPacketHeader, + pub hdr: VsockPacketHeader, /// The raw buffer, as it is contained in guest memory (containing both /// header and payload) - buffer: VsockPacketBuffer, + buffer: IoVecBuffer, } -impl VsockPacket { +impl VsockPacketTx { /// Create the packet wrapper from a TX virtq chain head. /// /// ## Errors @@ -123,17 +208,18 @@ impl VsockPacket { /// length would exceed [`defs::MAX_PKT_BUR_SIZE`]. /// - [`VsockError::DescChainTooShortForPacket`] if the contained vsock header describes a vsock /// packet whose length exceeds the descriptor chain's actual total buffer length. - pub fn from_tx_virtq_head( + pub fn parse( + &mut self, mem: &GuestMemoryMmap, chain: DescriptorChain, - ) -> Result { + ) -> Result<(), VsockError> { // SAFETY: This descriptor chain is only loaded once // virtio requests are handled sequentially so no two IoVecBuffers // are live at the same time, meaning this has exclusive ownership over the memory - let buffer = unsafe { IoVecBuffer::from_descriptor_chain(mem, chain)? }; + unsafe { self.buffer.load_descriptor_chain(mem, chain)? }; let mut hdr = VsockPacketHeader::default(); - match buffer.read_exact_volatile_at(hdr.as_mut_slice(), 0) { + match self.buffer.read_exact_volatile_at(hdr.as_mut_slice(), 0) { Ok(()) => (), Err(Error::PartialBuffer { completed, .. }) => { return Err(VsockError::DescChainTooShortForHeader(completed)) @@ -145,50 +231,85 @@ impl VsockPacket { return Err(VsockError::InvalidPktLen(hdr.len)); } - if hdr.len > buffer.len() - VSOCK_PKT_HDR_SIZE { + if hdr.len > self.buffer.len() - VSOCK_PKT_HDR_SIZE { return Err(VsockError::DescChainTooShortForPacket( - buffer.len(), + self.buffer.len(), hdr.len, )); } + self.hdr = hdr; + Ok(()) + } - Ok(VsockPacket { - hdr, - buffer: VsockPacketBuffer::Tx(buffer), - }) + pub fn write_from_offset_to( + &self, + dst: &mut T, + offset: u32, + count: u32, + ) -> Result { + if count + > self + .buffer + .len() + .saturating_sub(VSOCK_PKT_HDR_SIZE) + .saturating_sub(offset) + { + return Err(VsockError::GuestMemoryBounds); + } + + self.buffer + .read_volatile_at(dst, (offset + VSOCK_PKT_HDR_SIZE) as usize, count as usize) + .map_err(|err| VsockError::GuestMemoryMmap(GuestMemoryError::from(err))) + .and_then(|read| read.try_into().map_err(|_| VsockError::DescChainOverflow)) + } + + /// Returns the total length of this [`VsockPacket`]'s buffer (e.g. the amount of data bytes + /// contained in this packet). + /// + /// Return value will equal the total length of the underlying descriptor chain's buffers, + /// minus the length of the vsock header. + pub fn buf_size(&self) -> u32 { + self.buffer.len() - VSOCK_PKT_HDR_SIZE } +} + +/// Struct describing a single vsock packet. +/// +/// Encapsulates the virtio descriptor chain containing the packet through the `IoVecBuffer[Mut]` +/// abstractions. +#[derive(Debug, Default)] +pub struct VsockPacketRx { + /// A copy of the vsock packet's 44-byte header, held in hypervisor memory + /// to minimize the number of accesses to guest memory. Can be written back + /// to geust memory using [`VsockPacket::commit_hdr`] (only for RX buffers). 
+ pub hdr: VsockPacketHeader, + /// The raw buffer, as it is contained in guest memory (containing both + /// header and payload) + buffer: IoVecBufferMut, +} +impl VsockPacketRx { /// Create the packet wrapper from an RX virtq chain head. /// /// ## Errors /// Returns [`VsockError::DescChainTooShortForHeader`] if the descriptor chain's total buffer /// length is insufficient to hold the 44 byte vsock header - pub fn from_rx_virtq_head( + pub fn parse( + &mut self, mem: &GuestMemoryMmap, chain: DescriptorChain, - ) -> Result { + ) -> Result<(), VsockError> { // SAFETY: This descriptor chain is only loaded once // virtio requests are handled sequentially so no two IoVecBuffers // are live at the same time, meaning this has exclusive ownership over the memory - let buffer = unsafe { IoVecBufferMut::from_descriptor_chain(mem, chain)? }; - - if buffer.len() < VSOCK_PKT_HDR_SIZE { - return Err(VsockError::DescChainTooShortForHeader(buffer.len() as usize)); + unsafe { self.buffer.load_descriptor_chain(mem, chain)? }; + if self.buffer.len() < VSOCK_PKT_HDR_SIZE { + return Err(VsockError::DescChainTooShortForHeader( + self.buffer.len() as usize + )); } - - Ok(Self { - // On the Rx path the header has to be filled by Firecracker. The guest only provides - // a write-only memory area that Firecracker can write the header into. So we initialize - // the local copy with zeros, we write to it whenever we need to, and we only commit it - // to the guest memory once, before marking the RX descriptor chain as used. - hdr: VsockPacketHeader::default(), - buffer: VsockPacketBuffer::Rx(buffer), - }) - } - - /// Provides in-place access to the local copy of the vsock packet header. - pub fn hdr(&self) -> &VsockPacketHeader { - &self.hdr + self.hdr = VsockPacketHeader::default(); + Ok(()) } /// Writes the local copy of the packet header to the guest memory. @@ -199,19 +320,13 @@ impl VsockPacket { /// packet's payload as described by this [`VsockPacket`] would exceed /// [`defs::MAX_PKT_BUF_SIZE`]. pub fn commit_hdr(&mut self) -> Result<(), VsockError> { - match self.buffer { - VsockPacketBuffer::Tx(_) => Err(VsockError::UnwritableDescriptor), - VsockPacketBuffer::Rx(ref mut buffer) => { - if self.hdr.len > defs::MAX_PKT_BUF_SIZE { - return Err(VsockError::InvalidPktLen(self.hdr.len)); - } - - buffer - .write_all_volatile_at(self.hdr.as_slice(), 0) - .map_err(GuestMemoryError::from) - .map_err(VsockError::GuestMemoryMmap) - } + if self.hdr.len > defs::MAX_PKT_BUF_SIZE { + return Err(VsockError::InvalidPktLen(self.hdr.len)); } + self.buffer + .write_all_volatile_at(self.hdr.as_slice(), 0) + .map_err(GuestMemoryError::from) + .map_err(VsockError::GuestMemoryMmap) } /// Returns the total length of this [`VsockPacket`]'s buffer (e.g. the amount of data bytes @@ -220,11 +335,7 @@ impl VsockPacket { /// Return value will equal the total length of the underlying descriptor chain's buffers, /// minus the length of the vsock header. 
pub fn buf_size(&self) -> u32 { - let chain_length = match self.buffer { - VsockPacketBuffer::Tx(ref iovec_buf) => iovec_buf.len(), - VsockPacketBuffer::Rx(ref iovec_buf) => iovec_buf.len(), - }; - chain_length - VSOCK_PKT_HDR_SIZE + self.buffer.len() - VSOCK_PKT_HDR_SIZE } pub fn read_at_offset_from( @@ -233,145 +344,20 @@ impl VsockPacket { offset: u32, count: u32, ) -> Result { - match self.buffer { - VsockPacketBuffer::Tx(_) => Err(VsockError::UnwritableDescriptor), - VsockPacketBuffer::Rx(ref mut buffer) => { - if count - > buffer - .len() - .saturating_sub(VSOCK_PKT_HDR_SIZE) - .saturating_sub(offset) - { - return Err(VsockError::GuestMemoryBounds); - } - - buffer - .write_volatile_at(src, (offset + VSOCK_PKT_HDR_SIZE) as usize, count as usize) - .map_err(|err| VsockError::GuestMemoryMmap(GuestMemoryError::from(err))) - .and_then(|read| read.try_into().map_err(|_| VsockError::DescChainOverflow)) - } - } - } - - pub fn write_from_offset_to( - &self, - dst: &mut T, - offset: u32, - count: u32, - ) -> Result { - match self.buffer { - VsockPacketBuffer::Tx(ref buffer) => { - if count - > buffer - .len() - .saturating_sub(VSOCK_PKT_HDR_SIZE) - .saturating_sub(offset) - { - return Err(VsockError::GuestMemoryBounds); - } - - buffer - .read_volatile_at(dst, (offset + VSOCK_PKT_HDR_SIZE) as usize, count as usize) - .map_err(|err| VsockError::GuestMemoryMmap(GuestMemoryError::from(err))) - .and_then(|read| read.try_into().map_err(|_| VsockError::DescChainOverflow)) - } - VsockPacketBuffer::Rx(_) => Err(VsockError::UnreadableDescriptor), + if count + > self + .buffer + .len() + .saturating_sub(VSOCK_PKT_HDR_SIZE) + .saturating_sub(offset) + { + return Err(VsockError::GuestMemoryBounds); } - } - - pub fn src_cid(&self) -> u64 { - u64::from_le(self.hdr.src_cid) - } - - pub fn set_src_cid(&mut self, cid: u64) -> &mut Self { - self.hdr.src_cid = cid.to_le(); - self - } - - pub fn dst_cid(&self) -> u64 { - u64::from_le(self.hdr.dst_cid) - } - - pub fn set_dst_cid(&mut self, cid: u64) -> &mut Self { - self.hdr.dst_cid = cid.to_le(); - self - } - - pub fn src_port(&self) -> u32 { - u32::from_le(self.hdr.src_port) - } - - pub fn set_src_port(&mut self, port: u32) -> &mut Self { - self.hdr.src_port = port.to_le(); - self - } - - pub fn dst_port(&self) -> u32 { - u32::from_le(self.hdr.dst_port) - } - - pub fn set_dst_port(&mut self, port: u32) -> &mut Self { - self.hdr.dst_port = port.to_le(); - self - } - - pub fn len(&self) -> u32 { - u32::from_le(self.hdr.len) - } - - pub fn set_len(&mut self, len: u32) -> &mut Self { - self.hdr.len = len.to_le(); - self - } - - pub fn type_(&self) -> u16 { - u16::from_le(self.hdr.type_) - } - - pub fn set_type(&mut self, type_: u16) -> &mut Self { - self.hdr.type_ = type_.to_le(); - self - } - - pub fn op(&self) -> u16 { - u16::from_le(self.hdr.op) - } - - pub fn set_op(&mut self, op: u16) -> &mut Self { - self.hdr.op = op.to_le(); - self - } - - pub fn flags(&self) -> u32 { - u32::from_le(self.hdr.flags) - } - pub fn set_flags(&mut self, flags: u32) -> &mut Self { - self.hdr.flags = flags.to_le(); - self - } - - pub fn set_flag(&mut self, flag: u32) -> &mut Self { - self.set_flags(self.flags() | flag); - self - } - - pub fn buf_alloc(&self) -> u32 { - u32::from_le(self.hdr.buf_alloc) - } - - pub fn set_buf_alloc(&mut self, buf_alloc: u32) -> &mut Self { - self.hdr.buf_alloc = buf_alloc.to_le(); - self - } - - pub fn fwd_cnt(&self) -> u32 { - u32::from_le(self.hdr.fwd_cnt) - } - - pub fn set_fwd_cnt(&mut self, fwd_cnt: u32) -> &mut Self { - self.hdr.fwd_cnt = 
fwd_cnt.to_le(); - self + self.buffer + .write_volatile_at(src, (offset + VSOCK_PKT_HDR_SIZE) as usize, count as usize) + .map_err(|err| VsockError::GuestMemoryMmap(GuestMemoryError::from(err))) + .and_then(|read| read.try_into().map_err(|_| VsockError::DescChainOverflow)) } } @@ -396,22 +382,6 @@ mod tests { }; } - macro_rules! expect_asm_error { - (tx, $test_ctx:expr, $handler_ctx:expr, $err:pat) => { - expect_asm_error!($test_ctx, $handler_ctx, $err, from_tx_virtq_head, TXQ_INDEX); - }; - (rx, $test_ctx:expr, $handler_ctx:expr, $err:pat) => { - expect_asm_error!($test_ctx, $handler_ctx, $err, from_rx_virtq_head, RXQ_INDEX); - }; - ($test_ctx:expr, $handler_ctx:expr, $err:pat, $ctor:ident, $vq_index:ident) => { - let result = VsockPacket::$ctor( - &$test_ctx.mem, - $handler_ctx.device.queues[$vq_index].pop().unwrap(), - ); - assert!(matches!(result, Err($err)), "{:?}", result) - }; - } - fn set_pkt_len(len: u32, guest_desc: &GuestQDesc, mem: &GuestMemoryMmap) { let hdr_addr = GuestAddress(guest_desc.addr.get()); let mut hdr: VsockPacketHeader = mem.read_obj(hdr_addr).unwrap(); @@ -434,7 +404,8 @@ mod tests { { create_context!(test_ctx, handler_ctx); - let pkt = VsockPacket::from_tx_virtq_head( + let mut pkt = VsockPacketTx::default(); + pkt.parse( &test_ctx.mem, handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), ) @@ -452,7 +423,13 @@ mod tests { handler_ctx.guest_txvq.dtable[0] .flags .set(VIRTQ_DESC_F_WRITE); - expect_asm_error!(tx, test_ctx, handler_ctx, VsockError::UnreadableDescriptor); + assert!(matches!( + VsockPacketTx::default().parse( + &test_ctx.mem, + handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), + ), + Err(VsockError::UnreadableDescriptor) + )) } // Test case: header descriptor has insufficient space to hold the packet header. @@ -462,23 +439,25 @@ mod tests { .len .set(VSOCK_PKT_HDR_SIZE - 1); handler_ctx.guest_txvq.dtable[1].len.set(0); - expect_asm_error!( - tx, - test_ctx, - handler_ctx, - VsockError::DescChainTooShortForHeader(_) - ); + assert!(matches!( + VsockPacketTx::default().parse( + &test_ctx.mem, + handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), + ), + Err(VsockError::DescChainTooShortForHeader(_)) + )) } // Test case: zero-length TX packet. { create_context!(test_ctx, handler_ctx); set_pkt_len(0, &handler_ctx.guest_txvq.dtable[0], &test_ctx.mem); - VsockPacket::from_tx_virtq_head( - &test_ctx.mem, - handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), - ) - .unwrap(); + VsockPacketTx::default() + .parse( + &test_ctx.mem, + handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), + ) + .unwrap(); } // Test case: TX packet has more data than we can handle. @@ -489,7 +468,13 @@ mod tests { &handler_ctx.guest_txvq.dtable[0], &test_ctx.mem, ); - expect_asm_error!(tx, test_ctx, handler_ctx, VsockError::InvalidPktLen(_)); + assert!(matches!( + VsockPacketTx::default().parse( + &test_ctx.mem, + handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), + ), + Err(VsockError::InvalidPktLen(_)) + )) } // Test case: @@ -499,12 +484,13 @@ mod tests { create_context!(test_ctx, handler_ctx); set_pkt_len(1024, &handler_ctx.guest_txvq.dtable[0], &test_ctx.mem); handler_ctx.guest_txvq.dtable[0].flags.set(0); - expect_asm_error!( - tx, - test_ctx, - handler_ctx, - VsockError::DescChainTooShortForPacket(44, 1024) - ); + assert!(matches!( + VsockPacketTx::default().parse( + &test_ctx.mem, + handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), + ), + Err(VsockError::DescChainTooShortForPacket(44, 1024)) + )) } // Test case: error on write-only buf descriptor. 
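
The test hunks above and below replace the removed `expect_asm_error!` macro with direct calls to the new API: a default-constructed `VsockPacketTx`/`VsockPacketRx` is loaded from a descriptor chain via `parse()`, which returns `Result<(), VsockError>` instead of yielding a fresh packet. Outside the tests the same objects are meant to be long-lived and re-parsed once per chain, as `Vsock::process_tx` does in device.rs. A minimal sketch of that reuse pattern follows; `drain_tx_queue` is a hypothetical helper, the `GuestMemoryMmap` import path is assumed (it is not shown in this diff), and the used-ring bookkeeping (`add_used`) is elided.

// Sketch only, not crate code: one pre-allocated VsockPacketTx is re-`parse()`d for
// every TX descriptor chain, mirroring the loop in `Vsock::process_tx`.
use crate::devices::virtio::queue::Queue as VirtQueue;
use crate::devices::virtio::vsock::packet::VsockPacketTx;
use crate::devices::virtio::vsock::VsockChannel;
use crate::vstate::memory::GuestMemoryMmap; // assumed path; not shown in this diff

fn drain_tx_queue<C: VsockChannel>(
    backend: &mut C,
    mem: &GuestMemoryMmap,
    txq: &mut VirtQueue,
    tx_packet: &mut VsockPacketTx,
) -> usize {
    let mut sent = 0;
    while let Some(head) = txq.pop() {
        // `parse` overwrites the packet's header copy and iovec state in place.
        if tx_packet.parse(mem, head).is_err() {
            // Malformed chain: skip it here (the real loop logs and marks it used).
            continue;
        }
        if backend.send_pkt(tx_packet).is_err() {
            // The backend cannot accept more data right now; put the chain back and stop.
            txq.undo_pop();
            break;
        }
        sent += 1;
    }
    sent
}
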
@@ -513,7 +499,13 @@ mod tests { handler_ctx.guest_txvq.dtable[1] .flags .set(VIRTQ_DESC_F_WRITE); - expect_asm_error!(tx, test_ctx, handler_ctx, VsockError::UnreadableDescriptor); + assert!(matches!( + VsockPacketTx::default().parse( + &test_ctx.mem, + handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), + ), + Err(VsockError::UnreadableDescriptor) + )) } // Test case: the buffer descriptor cannot fit all the data advertised by the the @@ -522,12 +514,13 @@ mod tests { create_context!(test_ctx, handler_ctx); set_pkt_len(8 * 1024, &handler_ctx.guest_txvq.dtable[0], &test_ctx.mem); handler_ctx.guest_txvq.dtable[1].len.set(4 * 1024); - expect_asm_error!( - tx, - test_ctx, - handler_ctx, - VsockError::DescChainTooShortForPacket(4140, 8192) - ); + assert!(matches!( + VsockPacketTx::default().parse( + &test_ctx.mem, + handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), + ), + Err(VsockError::DescChainTooShortForPacket(4140, 8192)) + )) } } @@ -536,7 +529,8 @@ mod tests { // Test case: successful RX packet assembly. { create_context!(test_ctx, handler_ctx); - let pkt = VsockPacket::from_rx_virtq_head( + let mut pkt = VsockPacketRx::default(); + pkt.parse( &test_ctx.mem, handler_ctx.device.queues[RXQ_INDEX].pop().unwrap(), ) @@ -548,7 +542,13 @@ mod tests { { create_context!(test_ctx, handler_ctx); handler_ctx.guest_rxvq.dtable[0].flags.set(0); - expect_asm_error!(rx, test_ctx, handler_ctx, VsockError::UnwritableDescriptor); + assert!(matches!( + VsockPacketRx::default().parse( + &test_ctx.mem, + handler_ctx.device.queues[RXQ_INDEX].pop().unwrap(), + ), + Err(VsockError::UnwritableDescriptor) + )) } // Test case: RX descriptor chain cannot fit packet header @@ -558,12 +558,13 @@ mod tests { .len .set(VSOCK_PKT_HDR_SIZE - 1); handler_ctx.guest_rxvq.dtable[1].len.set(0); - expect_asm_error!( - rx, - test_ctx, - handler_ctx, - VsockError::DescChainTooShortForHeader(_) - ); + assert!(matches!( + VsockPacketRx::default().parse( + &test_ctx.mem, + handler_ctx.device.queues[RXQ_INDEX].pop().unwrap(), + ), + Err(VsockError::DescChainTooShortForHeader(_)) + )) } } @@ -581,15 +582,20 @@ mod tests { const BUF_ALLOC: u32 = 9; const FWD_CNT: u32 = 10; - create_context!(test_ctx, handler_ctx); - let mut pkt = VsockPacket::from_rx_virtq_head( - &test_ctx.mem, - handler_ctx.device.queues[RXQ_INDEX].pop().unwrap(), - ) - .unwrap(); + let mut hdr = VsockPacketHeader::default(); + assert_eq!(hdr.src_cid(), 0); + assert_eq!(hdr.dst_cid(), 0); + assert_eq!(hdr.src_port(), 0); + assert_eq!(hdr.dst_port(), 0); + assert_eq!(hdr.len(), 0); + assert_eq!(hdr.type_(), 0); + assert_eq!(hdr.op(), 0); + assert_eq!(hdr.flags(), 0); + assert_eq!(hdr.buf_alloc(), 0); + assert_eq!(hdr.fwd_cnt(), 0); // Test field accessors. 
- pkt.set_src_cid(SRC_CID) + hdr.set_src_cid(SRC_CID) .set_dst_cid(DST_CID) .set_src_port(SRC_PORT) .set_dst_port(DST_PORT) @@ -600,33 +606,21 @@ mod tests { .set_buf_alloc(BUF_ALLOC) .set_fwd_cnt(FWD_CNT); - assert_eq!(pkt.src_cid(), SRC_CID); - assert_eq!(pkt.dst_cid(), DST_CID); - assert_eq!(pkt.src_port(), SRC_PORT); - assert_eq!(pkt.dst_port(), DST_PORT); - assert_eq!(pkt.len(), LEN); - assert_eq!(pkt.type_(), TYPE); - assert_eq!(pkt.op(), OP); - assert_eq!(pkt.flags(), FLAGS); - assert_eq!(pkt.buf_alloc(), BUF_ALLOC); - assert_eq!(pkt.fwd_cnt(), FWD_CNT); + assert_eq!(hdr.src_cid(), SRC_CID); + assert_eq!(hdr.dst_cid(), DST_CID); + assert_eq!(hdr.src_port(), SRC_PORT); + assert_eq!(hdr.dst_port(), DST_PORT); + assert_eq!(hdr.len(), LEN); + assert_eq!(hdr.type_(), TYPE); + assert_eq!(hdr.op(), OP); + assert_eq!(hdr.flags(), FLAGS); + assert_eq!(hdr.buf_alloc(), BUF_ALLOC); + assert_eq!(hdr.fwd_cnt(), FWD_CNT); // Test individual flag setting. - let flags = pkt.flags() | 0b1000; - pkt.set_flag(0b1000); - assert_eq!(pkt.flags(), flags); - - pkt.hdr = VsockPacketHeader::default(); - assert_eq!(pkt.src_cid(), 0); - assert_eq!(pkt.dst_cid(), 0); - assert_eq!(pkt.src_port(), 0); - assert_eq!(pkt.dst_port(), 0); - assert_eq!(pkt.len(), 0); - assert_eq!(pkt.type_(), 0); - assert_eq!(pkt.op(), 0); - assert_eq!(pkt.flags(), 0); - assert_eq!(pkt.buf_alloc(), 0); - assert_eq!(pkt.fwd_cnt(), 0); + let flags = hdr.flags() | 0b1000; + hdr.set_flag(0b1000); + assert_eq!(hdr.flags(), flags); } #[test] @@ -635,12 +629,14 @@ mod tests { // create_context gives us an rx descriptor chain and a tx descriptor chain pointing to the // same area of memory. We need both a rx-view and a tx-view into the packet, as tx-queue // buffers are read only, while rx queue buffers are write-only - let mut pkt = VsockPacket::from_rx_virtq_head( + let mut pkt = VsockPacketRx::default(); + pkt.parse( &test_ctx.mem, handler_ctx.device.queues[RXQ_INDEX].pop().unwrap(), ) .unwrap(); - let pkt2 = VsockPacket::from_tx_virtq_head( + let mut pkt2 = VsockPacketTx::default(); + pkt2.parse( &test_ctx.mem, handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), ) diff --git a/src/vmm/src/devices/virtio/vsock/test_utils.rs b/src/vmm/src/devices/virtio/vsock/test_utils.rs index b5806de08d6..4a5fdb2c941 100644 --- a/src/vmm/src/devices/virtio/vsock/test_utils.rs +++ b/src/vmm/src/devices/virtio/vsock/test_utils.rs @@ -9,11 +9,12 @@ use std::os::unix::io::{AsRawFd, RawFd}; use vmm_sys_util::epoll::EventSet; use vmm_sys_util::eventfd::EventFd; +use super::packet::{VsockPacketRx, VsockPacketTx}; use crate::devices::virtio::device::VirtioDevice; use crate::devices::virtio::queue::{VIRTQ_DESC_F_NEXT, VIRTQ_DESC_F_WRITE}; use crate::devices::virtio::test_utils::VirtQueue as GuestQ; use crate::devices::virtio::vsock::device::{RXQ_INDEX, TXQ_INDEX}; -use crate::devices::virtio::vsock::packet::{VsockPacket, VSOCK_PKT_HDR_SIZE}; +use crate::devices::virtio::vsock::packet::VSOCK_PKT_HDR_SIZE; use crate::devices::virtio::vsock::{ Vsock, VsockBackend, VsockChannel, VsockEpollListener, VsockError, }; @@ -62,7 +63,7 @@ impl Default for TestBackend { } impl VsockChannel for TestBackend { - fn recv_pkt(&mut self, pkt: &mut VsockPacket) -> Result<(), VsockError> { + fn recv_pkt(&mut self, pkt: &mut VsockPacketRx) -> Result<(), VsockError> { let cool_buf = [0xDu8, 0xE, 0xA, 0xD, 0xB, 0xE, 0xE, 0xF]; match self.rx_err.take() { None => { @@ -81,7 +82,7 @@ impl VsockChannel for TestBackend { } } - fn send_pkt(&mut self, _pkt: &VsockPacket) -> Result<(), 
VsockError> { + fn send_pkt(&mut self, _pkt: &VsockPacketTx) -> Result<(), VsockError> { match self.tx_err.take() { None => { self.tx_ok_cnt += 1; @@ -206,7 +207,7 @@ impl<'a> EventHandlerContext<'a> { } #[cfg(test)] -pub fn read_packet_data(pkt: &VsockPacket, how_much: u32) -> Vec { +pub fn read_packet_data(pkt: &VsockPacketTx, how_much: u32) -> Vec { let mut buf = vec![0; how_much as usize]; pkt.write_from_offset_to(&mut buf.as_mut_slice(), 0, how_much) .unwrap(); diff --git a/src/vmm/src/devices/virtio/vsock/unix/muxer.rs b/src/vmm/src/devices/virtio/vsock/unix/muxer.rs index 5dfdcc582e5..79b5d4c143f 100644 --- a/src/vmm/src/devices/virtio/vsock/unix/muxer.rs +++ b/src/vmm/src/devices/virtio/vsock/unix/muxer.rs @@ -40,12 +40,12 @@ use vmm_sys_util::epoll::{ControlOperation, Epoll, EpollEvent, EventSet}; use super::super::csm::ConnState; use super::super::defs::uapi; -use super::super::packet::VsockPacket; use super::super::{VsockBackend, VsockChannel, VsockEpollListener, VsockError}; use super::muxer_killq::MuxerKillQ; use super::muxer_rxq::MuxerRxQ; use super::{defs, MuxerConnection, VsockUnixBackendError}; use crate::devices::virtio::vsock::metrics::METRICS; +use crate::devices::virtio::vsock::packet::{VsockPacketRx, VsockPacketTx}; use crate::logger::IncMetric; /// A unique identifier of a `MuxerConnection` object. Connections are stored in a hash map, @@ -115,7 +115,7 @@ impl VsockChannel for VsockMuxer { /// Retuns: /// - `Ok(())`: `pkt` has been successfully filled in; or /// - `Err(VsockError::NoData)`: there was no available data with which to fill in the packet. - fn recv_pkt(&mut self, pkt: &mut VsockPacket) -> Result<(), VsockError> { + fn recv_pkt(&mut self, pkt: &mut VsockPacketRx) -> Result<(), VsockError> { // We'll look for instructions on how to build the RX packet in the RX queue. If the // queue is empty, that doesn't necessarily mean we don't have any pending RX, since // the queue might be out-of-sync. If that's the case, we'll attempt to sync it first, @@ -131,7 +131,8 @@ impl VsockChannel for VsockMuxer { local_port, peer_port, } => { - pkt.set_op(uapi::VSOCK_OP_RST) + pkt.hdr + .set_op(uapi::VSOCK_OP_RST) .set_src_cid(uapi::VSOCK_HOST_CID) .set_dst_cid(self.cid) .set_src_port(local_port) @@ -165,14 +166,14 @@ impl VsockChannel for VsockMuxer { // Inspect traffic, looking for RST packets, since that means we have to // terminate and remove this connection from the active connection pool. // - if pkt.op() == uapi::VSOCK_OP_RST { + if pkt.hdr.op() == uapi::VSOCK_OP_RST { self.remove_connection(ConnMapKey { - local_port: pkt.src_port(), - peer_port: pkt.dst_port(), + local_port: pkt.hdr.src_port(), + peer_port: pkt.hdr.dst_port(), }); } - debug!("vsock muxer: RX pkt: {:?}", pkt.hdr()); + debug!("vsock muxer: RX pkt: {:?}", pkt.hdr); return Ok(()); } } @@ -188,31 +189,31 @@ impl VsockChannel for VsockMuxer { /// Returns: /// always `Ok(())` - the packet has been consumed, and its virtio TX buffers can be /// returned to the guest vsock driver. - fn send_pkt(&mut self, pkt: &VsockPacket) -> Result<(), VsockError> { + fn send_pkt(&mut self, pkt: &VsockPacketTx) -> Result<(), VsockError> { let conn_key = ConnMapKey { - local_port: pkt.dst_port(), - peer_port: pkt.src_port(), + local_port: pkt.hdr.dst_port(), + peer_port: pkt.hdr.src_port(), }; debug!( "vsock: muxer.send[rxq.len={}]: {:?}", self.rxq.len(), - pkt.hdr() + pkt.hdr ); // If this packet has an unsupported type (!=stream), we must send back an RST. 
         //
-        if pkt.type_() != uapi::VSOCK_TYPE_STREAM {
-            self.enq_rst(pkt.dst_port(), pkt.src_port());
+        if pkt.hdr.type_() != uapi::VSOCK_TYPE_STREAM {
+            self.enq_rst(pkt.hdr.dst_port(), pkt.hdr.src_port());
             return Ok(());
         }
 
         // We don't know how to handle packets addressed to other CIDs. We only handle the host
         // part of the guest - host communication here.
-        if pkt.dst_cid() != uapi::VSOCK_HOST_CID {
+        if pkt.hdr.dst_cid() != uapi::VSOCK_HOST_CID {
             info!(
                 "vsock: dropping guest packet for unknown CID: {:?}",
-                pkt.hdr()
+                pkt.hdr
             );
             return Ok(());
         }
 
@@ -221,12 +222,12 @@ impl VsockChannel for VsockMuxer {
             // This packet can't be routed to any active connection (based on its src and dst
             // ports). The only orphan / unroutable packets we know how to handle are
             // connection requests.
-            if pkt.op() == uapi::VSOCK_OP_REQUEST {
+            if pkt.hdr.op() == uapi::VSOCK_OP_REQUEST {
                 // Oh, this is a connection request!
                 self.handle_peer_request_pkt(pkt);
             } else {
                 // Send back an RST, to let the driver know we weren't expecting this packet.
-                self.enq_rst(pkt.dst_port(), pkt.src_port());
+                self.enq_rst(pkt.hdr.dst_port(), pkt.hdr.src_port());
             }
             return Ok(());
         }
@@ -234,7 +235,7 @@ impl VsockChannel for VsockMuxer {
         // Right, we know where to send this packet, then (to `conn_key`).
         // However, if this is an RST, we have to forcefully terminate the connection, so
         // there's no point in forwarding the packet to it.
-        if pkt.op() == uapi::VSOCK_OP_RST {
+        if pkt.hdr.op() == uapi::VSOCK_OP_RST {
             self.remove_connection(conn_key);
             return Ok(());
         }
@@ -608,8 +609,8 @@ impl VsockMuxer {
     /// the file system path corresponding to the destination port. If successful, a new
     /// connection object will be created and added to the connection pool. On failure, a new
     /// RST packet will be scheduled for delivery to the guest.
-    fn handle_peer_request_pkt(&mut self, pkt: &VsockPacket) {
-        let port_path = format!("{}_{}", self.host_sock_path, pkt.dst_port());
+    fn handle_peer_request_pkt(&mut self, pkt: &VsockPacketTx) {
+        let port_path = format!("{}_{}", self.host_sock_path, pkt.hdr.dst_port());
 
         UnixStream::connect(port_path)
             .and_then(|stream| stream.set_nonblocking(true).map(|_| stream))
@@ -617,20 +618,20 @@ impl VsockMuxer {
             .and_then(|stream| {
                 self.add_connection(
                     ConnMapKey {
-                        local_port: pkt.dst_port(),
-                        peer_port: pkt.src_port(),
+                        local_port: pkt.hdr.dst_port(),
+                        peer_port: pkt.hdr.src_port(),
                     },
                     MuxerConnection::new_peer_init(
                         stream,
                         uapi::VSOCK_HOST_CID,
                         self.cid,
-                        pkt.dst_port(),
-                        pkt.src_port(),
-                        pkt.buf_alloc(),
+                        pkt.hdr.dst_port(),
+                        pkt.hdr.src_port(),
+                        pkt.hdr.buf_alloc(),
                     ),
                 )
             })
-            .unwrap_or_else(|_| self.enq_rst(pkt.dst_port(), pkt.src_port()));
+            .unwrap_or_else(|_| self.enq_rst(pkt.hdr.dst_port(), pkt.hdr.src_port()));
     }
 
     /// Perform an action that might mutate a connection's state.
@@ -806,8 +807,8 @@ mod tests {
 
     struct MuxerTestContext {
         _vsock_test_ctx: VsockTestContext,
         // Two views of the same in-memory packet. rx-view for writing, tx-view for reading
-        rx_pkt: VsockPacket,
-        tx_pkt: VsockPacket,
+        rx_pkt: VsockPacketRx,
+        tx_pkt: VsockPacketTx,
         muxer: VsockMuxer,
     }
 
@@ -832,16 +833,20 @@ mod tests {
         fn new(name: &str) -> Self {
             let vsock_test_ctx = VsockTestContext::new();
             let mut handler_ctx = vsock_test_ctx.create_event_handler_context();
-            let rx_pkt = VsockPacket::from_rx_virtq_head(
-                &vsock_test_ctx.mem,
-                handler_ctx.device.queues[RXQ_INDEX].pop().unwrap(),
-            )
-            .unwrap();
-            let tx_pkt = VsockPacket::from_tx_virtq_head(
-                &vsock_test_ctx.mem,
-                handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(),
-            )
-            .unwrap();
+            let mut rx_pkt = VsockPacketRx::default();
+            rx_pkt
+                .parse(
+                    &vsock_test_ctx.mem,
+                    handler_ctx.device.queues[RXQ_INDEX].pop().unwrap(),
+                )
+                .unwrap();
+            let mut tx_pkt = VsockPacketTx::default();
+            tx_pkt
+                .parse(
+                    &vsock_test_ctx.mem,
+                    handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(),
+                )
+                .unwrap();
             let muxer = VsockMuxer::new(PEER_CID, get_file(name)).unwrap();
 
             Self {
@@ -852,15 +857,17 @@ mod tests {
             }
         }
 
-        fn init_tx_pkt(&mut self, local_port: u32, peer_port: u32, op: u16) -> &mut VsockPacket {
+        fn init_tx_pkt(&mut self, local_port: u32, peer_port: u32, op: u16) -> &mut VsockPacketTx {
             self.tx_pkt
+                .hdr
                 .set_type(uapi::VSOCK_TYPE_STREAM)
                 .set_src_cid(PEER_CID)
                 .set_dst_cid(uapi::VSOCK_HOST_CID)
                 .set_src_port(peer_port)
                 .set_dst_port(local_port)
                 .set_op(op)
-                .set_buf_alloc(PEER_BUF_ALLOC)
+                .set_buf_alloc(PEER_BUF_ALLOC);
+            &mut self.tx_pkt
         }
 
         fn init_data_tx_pkt(
@@ -868,10 +875,10 @@ mod tests {
             local_port: u32,
             peer_port: u32,
             mut data: &[u8],
-        ) -> &mut VsockPacket {
+        ) -> &mut VsockPacketTx {
             assert!(data.len() <= self.tx_pkt.buf_size() as usize);
 
-            self.init_tx_pkt(local_port, peer_port, uapi::VSOCK_OP_RW)
-                .set_len(u32::try_from(data.len()).unwrap());
+            let tx_pkt = self.init_tx_pkt(local_port, peer_port, uapi::VSOCK_OP_RW);
+            tx_pkt.hdr.set_len(u32::try_from(data.len()).unwrap());
             let data_len = data.len().try_into().unwrap(); // store in tmp var to make borrow checker happy.
             self.rx_pkt
@@ -948,9 +955,9 @@ mod tests {
             // A connection request for the peer should now be available from the muxer.
             assert!(self.muxer.has_pending_rx());
             self.recv();
-            assert_eq!(self.rx_pkt.op(), uapi::VSOCK_OP_REQUEST);
-            assert_eq!(self.rx_pkt.dst_port(), peer_port);
-            assert_eq!(self.rx_pkt.src_port(), local_port);
+            assert_eq!(self.rx_pkt.hdr.op(), uapi::VSOCK_OP_REQUEST);
+            assert_eq!(self.rx_pkt.hdr.dst_port(), peer_port);
+            assert_eq!(self.rx_pkt.hdr.src_port(), local_port);
 
             self.init_tx_pkt(local_port, peer_port, uapi::VSOCK_OP_RESPONSE);
             self.send();
@@ -1022,19 +1029,19 @@ mod tests {
         const SOCK_DGRAM: u16 = 2;
 
         let mut ctx = MuxerTestContext::new("bad_peer_pkt");
-        ctx.init_tx_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST)
-            .set_type(SOCK_DGRAM);
+        let tx_pkt = ctx.init_tx_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST);
+        tx_pkt.hdr.set_type(SOCK_DGRAM);
         ctx.send();
 
         // The guest sent a SOCK_DGRAM packet. Per the vsock spec, we need to reply with an RST
         // packet, since vsock only supports stream sockets.
         assert!(ctx.muxer.has_pending_rx());
         ctx.recv();
-        assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RST);
-        assert_eq!(ctx.rx_pkt.src_cid(), uapi::VSOCK_HOST_CID);
-        assert_eq!(ctx.rx_pkt.dst_cid(), PEER_CID);
-        assert_eq!(ctx.rx_pkt.src_port(), LOCAL_PORT);
-        assert_eq!(ctx.rx_pkt.dst_port(), PEER_PORT);
+        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST);
+        assert_eq!(ctx.rx_pkt.hdr.src_cid(), uapi::VSOCK_HOST_CID);
+        assert_eq!(ctx.rx_pkt.hdr.dst_cid(), PEER_CID);
+        assert_eq!(ctx.rx_pkt.hdr.src_port(), LOCAL_PORT);
+        assert_eq!(ctx.rx_pkt.hdr.dst_port(), PEER_PORT);
 
         // Any orphan (i.e. without a connection), non-RST packet, should be replied to with an
         // RST.
@@ -1050,15 +1057,15 @@ mod tests {
             ctx.send();
             assert!(ctx.muxer.has_pending_rx());
             ctx.recv();
-            assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RST);
-            assert_eq!(ctx.rx_pkt.src_port(), LOCAL_PORT);
-            assert_eq!(ctx.rx_pkt.dst_port(), PEER_PORT);
+            assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST);
+            assert_eq!(ctx.rx_pkt.hdr.src_port(), LOCAL_PORT);
+            assert_eq!(ctx.rx_pkt.hdr.dst_port(), PEER_PORT);
         }
 
         // Any packet addressed to anything other than VSOCK_VHOST_CID should get dropped.
         assert!(!ctx.muxer.has_pending_rx());
-        ctx.init_tx_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST)
-            .set_dst_cid(uapi::VSOCK_HOST_CID + 1);
+        let tx_pkt = ctx.init_tx_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST);
+        tx_pkt.hdr.set_dst_cid(uapi::VSOCK_HOST_CID + 1);
         ctx.send();
         assert!(!ctx.muxer.has_pending_rx());
     }
@@ -1074,12 +1081,12 @@ mod tests {
         ctx.init_tx_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST);
         ctx.send();
         ctx.recv();
-        assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RST);
-        assert_eq!(ctx.rx_pkt.len(), 0);
-        assert_eq!(ctx.rx_pkt.src_cid(), uapi::VSOCK_HOST_CID);
-        assert_eq!(ctx.rx_pkt.dst_cid(), PEER_CID);
-        assert_eq!(ctx.rx_pkt.src_port(), LOCAL_PORT);
-        assert_eq!(ctx.rx_pkt.dst_port(), PEER_PORT);
+        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST);
+        assert_eq!(ctx.rx_pkt.hdr.len(), 0);
+        assert_eq!(ctx.rx_pkt.hdr.src_cid(), uapi::VSOCK_HOST_CID);
+        assert_eq!(ctx.rx_pkt.hdr.dst_cid(), PEER_CID);
+        assert_eq!(ctx.rx_pkt.hdr.src_port(), LOCAL_PORT);
+        assert_eq!(ctx.rx_pkt.hdr.dst_port(), PEER_PORT);
 
         // Test peer connection accepted.
         let mut listener = ctx.create_local_listener(LOCAL_PORT);
@@ -1088,12 +1095,12 @@ mod tests {
         assert_eq!(ctx.muxer.conn_map.len(), 1);
         let mut stream = listener.accept();
         ctx.recv();
-        assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RESPONSE);
-        assert_eq!(ctx.rx_pkt.len(), 0);
-        assert_eq!(ctx.rx_pkt.src_cid(), uapi::VSOCK_HOST_CID);
-        assert_eq!(ctx.rx_pkt.dst_cid(), PEER_CID);
-        assert_eq!(ctx.rx_pkt.src_port(), LOCAL_PORT);
-        assert_eq!(ctx.rx_pkt.dst_port(), PEER_PORT);
+        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE);
+        assert_eq!(ctx.rx_pkt.hdr.len(), 0);
+        assert_eq!(ctx.rx_pkt.hdr.src_cid(), uapi::VSOCK_HOST_CID);
+        assert_eq!(ctx.rx_pkt.hdr.dst_cid(), PEER_CID);
+        assert_eq!(ctx.rx_pkt.hdr.src_port(), LOCAL_PORT);
+        assert_eq!(ctx.rx_pkt.hdr.dst_port(), PEER_PORT);
         let key = ConnMapKey {
             local_port: LOCAL_PORT,
             peer_port: PEER_PORT,
@@ -1120,9 +1127,9 @@ mod tests {
         // of its connections, so it should now be reporting that it can fill in an RX packet.
         assert!(ctx.muxer.has_pending_rx());
         ctx.recv();
-        assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RW);
-        assert_eq!(ctx.rx_pkt.src_port(), LOCAL_PORT);
-        assert_eq!(ctx.rx_pkt.dst_port(), PEER_PORT);
+        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RW);
+        assert_eq!(ctx.rx_pkt.hdr.src_port(), LOCAL_PORT);
+        assert_eq!(ctx.rx_pkt.hdr.dst_port(), PEER_PORT);
 
         let buf = test_utils::read_packet_data(&ctx.tx_pkt, 4);
         assert_eq!(&buf, &data);
@@ -1156,9 +1163,9 @@ mod tests {
 
         assert!(ctx.muxer.has_pending_rx());
         ctx.recv();
-        assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RW);
-        assert_eq!(ctx.rx_pkt.src_port(), local_port);
-        assert_eq!(ctx.rx_pkt.dst_port(), peer_port);
+        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RW);
+        assert_eq!(ctx.rx_pkt.hdr.src_port(), local_port);
+        assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port);
 
         let buf = test_utils::read_packet_data(&ctx.tx_pkt, 4);
         assert_eq!(&buf, &data);
@@ -1179,11 +1186,11 @@ mod tests {
         ctx.notify_muxer();
         assert!(ctx.muxer.has_pending_rx());
         ctx.recv();
-        assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_SHUTDOWN);
-        assert_ne!(ctx.rx_pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_SEND, 0);
-        assert_ne!(ctx.rx_pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_RCV, 0);
-        assert_eq!(ctx.rx_pkt.src_port(), local_port);
-        assert_eq!(ctx.rx_pkt.dst_port(), peer_port);
+        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_SHUTDOWN);
+        assert_ne!(ctx.rx_pkt.hdr.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_SEND, 0);
+        assert_ne!(ctx.rx_pkt.hdr.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_RCV, 0);
+        assert_eq!(ctx.rx_pkt.hdr.src_port(), local_port);
+        assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port);
 
         // The connection should get removed (and its local port freed), after the peer replies
         // with an RST.
@@ -1210,9 +1217,9 @@ mod tests {
         assert!(ctx.muxer.has_pending_rx());
         ctx.recv();
-        assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RESPONSE);
-        assert_eq!(ctx.rx_pkt.src_port(), local_port);
-        assert_eq!(ctx.rx_pkt.dst_port(), peer_port);
+        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE);
+        assert_eq!(ctx.rx_pkt.hdr.src_port(), local_port);
+        assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port);
 
         let key = ConnMapKey {
             local_port,
             peer_port,
@@ -1220,17 +1227,17 @@ mod tests {
         assert!(ctx.muxer.conn_map.contains_key(&key));
 
         // Emulate a full shutdown from the peer (no-more-send + no-more-recv).
-        ctx.init_tx_pkt(local_port, peer_port, uapi::VSOCK_OP_SHUTDOWN)
-            .set_flag(uapi::VSOCK_FLAGS_SHUTDOWN_SEND)
-            .set_flag(uapi::VSOCK_FLAGS_SHUTDOWN_RCV);
+        let tx_pkt = ctx.init_tx_pkt(local_port, peer_port, uapi::VSOCK_OP_SHUTDOWN);
+        tx_pkt.hdr.set_flag(uapi::VSOCK_FLAGS_SHUTDOWN_SEND);
+        tx_pkt.hdr.set_flag(uapi::VSOCK_FLAGS_SHUTDOWN_RCV);
         ctx.send();
 
         // Now, the muxer should remove the connection from its map, and reply with an RST.
         assert!(ctx.muxer.has_pending_rx());
         ctx.recv();
-        assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RST);
-        assert_eq!(ctx.rx_pkt.src_port(), local_port);
-        assert_eq!(ctx.rx_pkt.dst_port(), peer_port);
+        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST);
+        assert_eq!(ctx.rx_pkt.hdr.src_port(), local_port);
+        assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port);
         let key = ConnMapKey {
             local_port,
             peer_port,
@@ -1277,15 +1284,15 @@ mod tests {
         for peer_port in peer_port_first..peer_port_first + defs::MUXER_RXQ_SIZE - 1 {
             ctx.recv();
-            assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RESPONSE);
+            assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE);
             // The response order should hold. The evicted response should have been the last
             // enqueued.
-            assert_eq!(ctx.rx_pkt.dst_port(), peer_port);
+            assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port);
         }
 
         // There should be one more packet in the queue: the RST.
         assert_eq!(ctx.muxer.rxq.len(), 1);
         ctx.recv();
-        assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RST);
+        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST);
 
         // The queue should now be empty, but out-of-sync, so the muxer should report it has some
         // pending RX.
@@ -1299,11 +1306,11 @@ mod tests {
         // - the one that got evicted by the RST.
         ctx.recv();
         assert!(ctx.muxer.rxq.is_synced());
-        assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RESPONSE);
+        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE);
 
         assert!(ctx.muxer.has_pending_rx());
         ctx.recv();
-        assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RESPONSE);
+        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE);
     }
 
     #[test]
@@ -1325,17 +1332,17 @@ mod tests {
         ctx.send();
         ctx.notify_muxer();
         ctx.recv();
-        assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RESPONSE);
-        assert_eq!(ctx.rx_pkt.src_port(), local_port);
-        assert_eq!(ctx.rx_pkt.dst_port(), peer_port);
+        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE);
+        assert_eq!(ctx.rx_pkt.hdr.src_port(), local_port);
+        assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port);
         {
             let _stream = listener.accept();
         }
         ctx.notify_muxer();
         ctx.recv();
-        assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_SHUTDOWN);
-        assert_eq!(ctx.rx_pkt.src_port(), local_port);
-        assert_eq!(ctx.rx_pkt.dst_port(), peer_port);
+        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_SHUTDOWN);
+        assert_eq!(ctx.rx_pkt.hdr.src_port(), local_port);
+        assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port);
         // The kill queue should be synchronized, up until the `defs::MUXER_KILLQ_SIZE`th
         // connection we schedule for termination.
         assert_eq!(
@@ -1380,8 +1387,8 @@ mod tests {
         // dying connections in the recent killq sweep.
         for _p in peer_port_first..peer_port_last {
             ctx.recv();
-            assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RST);
-            assert_eq!(ctx.rx_pkt.src_port(), local_port);
+            assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST);
+            assert_eq!(ctx.rx_pkt.hdr.src_port(), local_port);
         }
 
         // The connections should have been removed here.
@@ -1393,8 +1400,8 @@ mod tests {
         // There should be one more packet in the RX queue: the connection response to our request
        // that triggered the kill queue sweep.
         ctx.recv();
-        assert_eq!(ctx.rx_pkt.op(), uapi::VSOCK_OP_RESPONSE);
-        assert_eq!(ctx.rx_pkt.dst_port(), peer_port_last + 1);
+        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE);
+        assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port_last + 1);
 
         assert!(!ctx.muxer.has_pending_rx());
     }
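
Reading aid, not part of the patch: a minimal sketch of how the split packet types touched by this diff are driven, assuming only the `VsockPacketRx`/`VsockPacketTx` types, their public `hdr` field, and the `parse()` constructor that appear in the hunks above. The names `mem`, `rx_head`, `tx_head`, `local_port` and `peer_port` are hypothetical stand-ins for a guest memory mapping, descriptor chain heads popped from the RX/TX virtqueues, and port numbers; they are not defined by this change.

    // Hypothetical usage sketch; `mem`, `rx_head` and `tx_head` stand in for
    // guest memory and descriptor chain heads popped from the virtqueues.
    let mut rx_pkt = VsockPacketRx::default();
    rx_pkt.parse(&mem, rx_head).unwrap();
    // Header fields are now reached through the public `hdr` member; the
    // setters chain, as in the muxer's recv_pkt() above.
    rx_pkt.hdr
        .set_op(uapi::VSOCK_OP_RST)
        .set_src_cid(uapi::VSOCK_HOST_CID)
        .set_src_port(local_port)
        .set_dst_port(peer_port);

    // TX packets are parsed the same way and read through `hdr` getters,
    // which is how send_pkt() derives its routing key.
    let mut tx_pkt = VsockPacketTx::default();
    tx_pkt.parse(&mem, tx_head).unwrap();
    if tx_pkt.hdr.op() == uapi::VSOCK_OP_REQUEST {
        let _key = (tx_pkt.hdr.dst_port(), tx_pkt.hdr.src_port());
    }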