From a9f1828089ef40092aa0b2b32f28fefe6d9babb2 Mon Sep 17 00:00:00 2001 From: brandonpike Date: Wed, 12 Jun 2024 20:12:59 +0000 Subject: [PATCH] Use u32 to describe vsock related buffer sizes Move to u32 for vsock module. We can upsize from u32 to usize as needed. Signed-off-by: brandonpike --- src/vmm/src/devices/virtio/iovec.rs | 90 +++++++++++-------- src/vmm/src/devices/virtio/net/device.rs | 6 +- .../devices/virtio/vsock/csm/connection.rs | 27 +++--- src/vmm/src/devices/virtio/vsock/packet.rs | 60 ++++++------- .../src/devices/virtio/vsock/test_utils.rs | 6 +- .../src/devices/virtio/vsock/unix/muxer.rs | 6 +- 6 files changed, 103 insertions(+), 92 deletions(-) diff --git a/src/vmm/src/devices/virtio/iovec.rs b/src/vmm/src/devices/virtio/iovec.rs index 78e4e26aeb5..78aac21153d 100644 --- a/src/vmm/src/devices/virtio/iovec.rs +++ b/src/vmm/src/devices/virtio/iovec.rs @@ -108,23 +108,25 @@ impl IoVecBuffer { pub fn read_exact_volatile_at( &self, mut buf: &mut [u8], - offset: usize, + offset: u32, ) -> Result<(), VolatileMemoryError> { - if offset < self.len() as usize { - let expected = buf.len(); + if offset < self.len() { + let expected = u32::try_from(buf.len()).unwrap(); let bytes_read = self.read_volatile_at(&mut buf, offset, expected)?; if bytes_read != expected { return Err(VolatileMemoryError::PartialBuffer { - expected, - completed: bytes_read, + expected: expected as usize, + completed: bytes_read as usize, }); } Ok(()) } else { // If `offset` is past size, there's nothing to read. - Err(VolatileMemoryError::OutOfBounds { addr: offset }) + Err(VolatileMemoryError::OutOfBounds { + addr: offset as usize, + }) } } @@ -134,9 +136,9 @@ impl IoVecBuffer { pub fn read_volatile_at( &self, dst: &mut W, - mut offset: usize, - mut len: usize, - ) -> Result { + mut offset: u32, + mut len: u32, + ) -> Result { let mut total_bytes_read = 0; for iov in &self.vecs { @@ -144,19 +146,20 @@ impl IoVecBuffer { break; } - if offset >= iov.iov_len { - offset -= iov.iov_len; + let iov_len = u32::try_from(iov.iov_len).unwrap(); + if offset >= iov_len { + offset -= iov_len; continue; } let mut slice = // SAFETY: the constructor IoVecBufferMut::from_descriptor_chain ensures that // all iovecs contained point towards valid ranges of guest memory - unsafe { VolatileSlice::new(iov.iov_base.cast(), iov.iov_len).offset(offset)? }; + unsafe { VolatileSlice::new(iov.iov_base.cast(), iov.iov_len).offset(offset as usize)? }; offset = 0; - if slice.len() > len { - slice = slice.subslice(0, len)?; + if u32::try_from(slice.len()).unwrap() > len { + slice = slice.subslice(0, len as usize)?; } let bytes_read = loop { @@ -166,13 +169,13 @@ impl IoVecBuffer { { continue } - Ok(bytes_read) => break bytes_read, + Ok(bytes_read) => break u32::try_from(bytes_read).unwrap(), Err(volatile_memory_error) => return Err(volatile_memory_error), } }; total_bytes_read += bytes_read; - if bytes_read < slice.len() { + if slice.len() > bytes_read as usize { break; } len -= bytes_read; @@ -248,23 +251,25 @@ impl IoVecBufferMut { pub fn write_all_volatile_at( &mut self, mut buf: &[u8], - offset: usize, + offset: u32, ) -> Result<(), VolatileMemoryError> { - if offset < self.len() as usize { - let expected = buf.len(); + if offset < self.len() { + let expected = u32::try_from(buf.len()).unwrap(); let bytes_written = self.write_volatile_at(&mut buf, offset, expected)?; if bytes_written != expected { return Err(VolatileMemoryError::PartialBuffer { - expected, - completed: bytes_written, + expected: expected as usize, + completed: bytes_written as usize, }); } Ok(()) } else { // We cannot write past the end of the `IoVecBufferMut`. - Err(VolatileMemoryError::OutOfBounds { addr: offset }) + Err(VolatileMemoryError::OutOfBounds { + addr: offset as usize, + }) } } @@ -274,9 +279,9 @@ impl IoVecBufferMut { pub fn write_volatile_at( &mut self, src: &mut W, - mut offset: usize, - mut len: usize, - ) -> Result { + mut offset: u32, + mut len: u32, + ) -> Result { let mut total_bytes_read = 0; for iov in &self.vecs { @@ -284,19 +289,20 @@ impl IoVecBufferMut { break; } - if offset >= iov.iov_len { - offset -= iov.iov_len; + let iov_len = u32::try_from(iov.iov_len).unwrap(); + if offset >= iov_len { + offset -= iov_len; continue; } let mut slice = // SAFETY: the constructor IoVecBufferMut::from_descriptor_chain ensures that // all iovecs contained point towards valid ranges of guest memory - unsafe { VolatileSlice::new(iov.iov_base.cast(), iov.iov_len).offset(offset)? }; + unsafe { VolatileSlice::new(iov.iov_base.cast(), iov.iov_len).offset(offset as usize)? }; offset = 0; - if slice.len() > len { - slice = slice.subslice(0, len)?; + if u32::try_from(slice.len()).unwrap() > len { + slice = slice.subslice(0, len as usize)?; } let bytes_read = loop { @@ -306,13 +312,13 @@ impl IoVecBufferMut { { continue } - Ok(bytes_read) => break bytes_read, + Ok(bytes_read) => break u32::try_from(bytes_read).unwrap(), Err(volatile_memory_error) => return Err(volatile_memory_error), } }; total_bytes_read += bytes_read; - if bytes_read < slice.len() { + if slice.len() > bytes_read as usize { break; } len -= bytes_read; @@ -587,7 +593,9 @@ mod tests { // 5 bytes at offset 252 (only 4 bytes left). test_vec4[60..64].copy_from_slice(&buf[0..4]); assert_eq!( - iovec.write_volatile_at(&mut &*buf, 252, buf.len()).unwrap(), + iovec + .write_volatile_at(&mut &*buf, 252, buf.len().try_into().unwrap()) + .unwrap(), 4 ); vq.dtable[0].check_data(&test_vec1); @@ -731,11 +739,13 @@ mod verification { assert_eq!( iov.read_volatile_at( &mut KaniBuffer(&mut buf), - offset as usize, - GUEST_MEMORY_SIZE + offset, + GUEST_MEMORY_SIZE.try_into().unwrap() ) .unwrap(), - buf.len().min(iov.len().saturating_sub(offset) as usize) + u32::try_from(buf.len()) + .unwrap() + .min(iov.len().saturating_sub(offset)) ); } @@ -761,11 +771,13 @@ mod verification { iov_mut .write_volatile_at( &mut KaniBuffer(&mut buf), - offset as usize, - GUEST_MEMORY_SIZE + offset, + GUEST_MEMORY_SIZE.try_into().unwrap() ) .unwrap(), - buf.len().min(iov_mut.len().saturating_sub(offset) as usize) + u32::try_from(buf.len()) + .unwrap() + .min(iov_mut.len().saturating_sub(offset)) ); } } diff --git a/src/vmm/src/devices/virtio/net/device.rs b/src/vmm/src/devices/virtio/net/device.rs index f79692012e5..4fc29678ab5 100755 --- a/src/vmm/src/devices/virtio/net/device.rs +++ b/src/vmm/src/devices/virtio/net/device.rs @@ -445,7 +445,7 @@ impl Net { net_metrics: &NetDeviceMetrics, ) -> Result { // Read the frame headers from the IoVecBuffer - let max_header_len = headers.len(); + let max_header_len = u32::try_from(headers.len()).unwrap(); let header_len = frame_iovec .read_volatile_at(&mut &mut *headers, 0, max_header_len) .map_err(|err| { @@ -454,7 +454,7 @@ impl Net { NetError::VnetHeaderMissing })?; - let headers = frame_bytes_from_buf(&headers[..header_len]).map_err(|e| { + let headers = frame_bytes_from_buf(&headers[..header_len as usize]).map_err(|e| { error!("VNET headers missing in TX frame"); net_metrics.tx_malformed_frames.inc(); e @@ -466,7 +466,7 @@ impl Net { // Ok to unwrap here, because we are passing a buffer that has the exact size // of the `IoVecBuffer` minus the VNET headers. frame_iovec - .read_exact_volatile_at(&mut frame, vnet_hdr_len()) + .read_exact_volatile_at(&mut frame, vnet_hdr_len().try_into().unwrap()) .unwrap(); let _ = ns.detour_frame(&frame); METRICS.mmds.rx_accepted.inc(); diff --git a/src/vmm/src/devices/virtio/vsock/csm/connection.rs b/src/vmm/src/devices/virtio/vsock/csm/connection.rs index e38d9bab974..9941793e9a5 100644 --- a/src/vmm/src/devices/virtio/vsock/csm/connection.rs +++ b/src/vmm/src/devices/virtio/vsock/csm/connection.rs @@ -235,11 +235,8 @@ where } else { // On a successful data read, we fill in the packet with the RW op, and // length of the read data. - // Safe to unwrap because read_cnt is no more than max_len, which is bounded - // by self.peer_avail_credit(), a u32 internally. - pkt.set_op(uapi::VSOCK_OP_RW) - .set_len(u32::try_from(read_cnt).unwrap()); - METRICS.rx_bytes_count.add(read_cnt as u64); + pkt.set_op(uapi::VSOCK_OP_RW).set_len(read_cnt); + METRICS.rx_bytes_count.add(read_cnt.into()); } self.rx_cnt += Wrapping(pkt.len()); self.last_fwd_cnt_to_peer = self.fwd_cnt; @@ -605,7 +602,7 @@ where /// Raw data can either be sent straight to the host stream, or to our TX buffer, if the /// former fails. fn send_bytes(&mut self, pkt: &VsockPacket) -> Result<(), VsockError> { - let len = pkt.len() as usize; + let len = pkt.len(); // If there is data in the TX buffer, that means we're already registered for EPOLLOUT // events on the underlying stream. Therefore, there's no point in attempting a write @@ -635,8 +632,8 @@ where }; // Move the "forwarded bytes" counter ahead by how much we were able to send out. // Safe to unwrap because the maximum value is pkt.len(), which is a u32. - self.fwd_cnt += wrap_usize_to_u32(written); - METRICS.tx_bytes_count.add(written as u64); + self.fwd_cnt += written; + METRICS.tx_bytes_count.add(written.into()); // If we couldn't write the whole slice, we'll need to push the remaining data to our // buffer. @@ -662,8 +659,8 @@ where /// Get the maximum number of bytes that we can send to our peer, without overflowing its /// buffer. - fn peer_avail_credit(&self) -> usize { - (Wrapping(self.peer_buf_alloc) - (self.rx_cnt - self.peer_fwd_cnt)).0 as usize + fn peer_avail_credit(&self) -> u32 { + (Wrapping(self.peer_buf_alloc) - (self.rx_cnt - self.peer_fwd_cnt)).0 } /// Prepare a packet header for transmission to our peer. @@ -918,7 +915,7 @@ mod tests { assert!(credit < self.conn.peer_buf_alloc); self.conn.peer_fwd_cnt = Wrapping(0); self.conn.rx_cnt = Wrapping(self.conn.peer_buf_alloc - credit); - assert_eq!(self.conn.peer_avail_credit(), credit as usize); + assert_eq!(self.conn.peer_avail_credit(), credit); } fn send(&mut self) { @@ -943,11 +940,13 @@ mod tests { } fn init_data_tx_pkt(&mut self, mut data: &[u8]) -> &VsockPacket { - assert!(data.len() <= self.tx_pkt.buf_size()); + assert!(data.len() <= self.tx_pkt.buf_size() as usize); self.init_tx_pkt(uapi::VSOCK_OP_RW, u32::try_from(data.len()).unwrap()); let len = data.len(); - self.rx_pkt.read_at_offset_from(&mut data, 0, len).unwrap(); + self.rx_pkt + .read_at_offset_from(&mut data, 0, len.try_into().unwrap()) + .unwrap(); &self.tx_pkt } } @@ -1284,7 +1283,7 @@ mod tests { ctx.set_stream(stream); // Fill up the TX buffer. - let data = vec![0u8; ctx.tx_pkt.buf_size()]; + let data = vec![0u8; ctx.tx_pkt.buf_size() as usize]; ctx.init_data_tx_pkt(data.as_slice()); for _i in 0..(csm_defs::CONN_TX_BUF_SIZE as usize / data.len()) { ctx.send(); diff --git a/src/vmm/src/devices/virtio/vsock/packet.rs b/src/vmm/src/devices/virtio/vsock/packet.rs index ba82b169b3c..1f02218792a 100644 --- a/src/vmm/src/devices/virtio/vsock/packet.rs +++ b/src/vmm/src/devices/virtio/vsock/packet.rs @@ -207,33 +207,34 @@ impl VsockPacket { /// /// Return value will equal the total length of the underlying descriptor chain's buffers, /// minus the length of the vsock header. - pub fn buf_size(&self) -> usize { + pub fn buf_size(&self) -> u32 { let chain_length = match self.buffer { VsockPacketBuffer::Tx(ref iovec_buf) => iovec_buf.len(), VsockPacketBuffer::Rx(ref iovec_buf) => iovec_buf.len(), }; - (chain_length - VSOCK_PKT_HDR_SIZE) as usize + chain_length - VSOCK_PKT_HDR_SIZE } pub fn read_at_offset_from( &mut self, src: &mut T, - offset: usize, - count: usize, - ) -> Result { + offset: u32, + count: u32, + ) -> Result { match self.buffer { VsockPacketBuffer::Tx(_) => Err(VsockError::UnwritableDescriptor), VsockPacketBuffer::Rx(ref mut buffer) => { if count - > (buffer.len() as usize) - .saturating_sub(VSOCK_PKT_HDR_SIZE as usize) + > buffer + .len() + .saturating_sub(VSOCK_PKT_HDR_SIZE) .saturating_sub(offset) { return Err(VsockError::GuestMemoryBounds); } buffer - .write_volatile_at(src, offset + VSOCK_PKT_HDR_SIZE as usize, count) + .write_volatile_at(src, offset + VSOCK_PKT_HDR_SIZE, count) .map_err(|err| VsockError::GuestMemoryMmap(GuestMemoryError::from(err))) } } @@ -242,21 +243,22 @@ impl VsockPacket { pub fn write_from_offset_to( &self, dst: &mut T, - offset: usize, - count: usize, - ) -> Result { + offset: u32, + count: u32, + ) -> Result { match self.buffer { VsockPacketBuffer::Tx(ref buffer) => { if count - > (buffer.len() as usize) - .saturating_sub(VSOCK_PKT_HDR_SIZE as usize) + > buffer + .len() + .saturating_sub(VSOCK_PKT_HDR_SIZE) .saturating_sub(offset) { return Err(VsockError::GuestMemoryBounds); } buffer - .read_volatile_at(dst, offset + VSOCK_PKT_HDR_SIZE as usize, count) + .read_volatile_at(dst, offset + VSOCK_PKT_HDR_SIZE, count) .map_err(|err| VsockError::GuestMemoryMmap(GuestMemoryError::from(err))) } VsockPacketBuffer::Rx(_) => Err(VsockError::UnreadableDescriptor), @@ -529,10 +531,7 @@ mod tests { .unwrap(), ) .unwrap(); - assert_eq!( - pkt.buf_size(), - handler_ctx.guest_rxvq.dtable[1].len.get() as usize - ); + assert_eq!(pkt.buf_size(), handler_ctx.guest_rxvq.dtable[1].len.get()); } // Test case: read-only RX packet header. @@ -641,35 +640,36 @@ mod tests { .unwrap(); let buf_desc = &mut handler_ctx.guest_rxvq.dtable[1]; - assert_eq!(pkt.buf_size(), buf_desc.len.get() as usize); - let zeros = vec![0_u8; pkt.buf_size()]; + assert_eq!(pkt.buf_size(), buf_desc.len.get()); + let zeros = vec![0_u8; pkt.buf_size() as usize]; let data: Vec = (0..pkt.buf_size()) - .map(|i| ((i as u64) & 0xff) as u8) + .map(|i| ((u64::from(i)) & 0xff) as u8) .collect(); for offset in 0..pkt.buf_size() { + let count = pkt.buf_size() - offset; buf_desc.set_data(&zeros); - let mut expected_data = zeros[..offset].to_vec(); - expected_data.extend_from_slice(&data[..pkt.buf_size() - offset]); + let mut expected_data = zeros[..offset as usize].to_vec(); + expected_data.extend_from_slice(&data[..count as usize]); - pkt.read_at_offset_from(&mut data.as_slice(), offset, pkt.buf_size() - offset) + pkt.read_at_offset_from(&mut data.as_slice(), offset, count) .unwrap(); buf_desc.check_data(&expected_data); - let mut buf = vec![0; pkt.buf_size()]; - pkt2.write_from_offset_to(&mut buf.as_mut_slice(), offset, pkt.buf_size() - offset) + let mut buf = vec![0; pkt.buf_size() as usize]; + pkt2.write_from_offset_to(&mut buf.as_mut_slice(), offset, count) .unwrap(); - assert_eq!(&buf[..pkt.buf_size() - offset], &expected_data[offset..]); + assert_eq!(&buf[..count as usize], &expected_data[offset as usize..]); } let oob_cases = vec![ (1, pkt.buf_size()), (pkt.buf_size(), 1), - (usize::MAX, 1), - (1, usize::MAX), + (u32::MAX, 1), + (1, u32::MAX), ]; - let mut buf = vec![0; pkt.buf_size()]; + let mut buf = vec![0; pkt.buf_size() as usize]; for (offset, count) in oob_cases { let res = pkt.read_at_offset_from(&mut data.as_slice(), offset, count); assert!(matches!(res, Err(VsockError::GuestMemoryBounds))); diff --git a/src/vmm/src/devices/virtio/vsock/test_utils.rs b/src/vmm/src/devices/virtio/vsock/test_utils.rs index 4360e2f2a48..3b30c564e38 100644 --- a/src/vmm/src/devices/virtio/vsock/test_utils.rs +++ b/src/vmm/src/devices/virtio/vsock/test_utils.rs @@ -69,7 +69,7 @@ impl VsockChannel for TestBackend { let buf_size = pkt.buf_size(); if buf_size > 0 { let buf: Vec = (0..buf_size) - .map(|i| cool_buf[i % cool_buf.len()]) + .map(|i| cool_buf[i as usize % cool_buf.len()]) .collect(); pkt.read_at_offset_from(&mut buf.as_slice(), 0, buf_size) .unwrap(); @@ -206,8 +206,8 @@ impl<'a> EventHandlerContext<'a> { } #[cfg(test)] -pub fn read_packet_data(pkt: &VsockPacket, how_much: usize) -> Vec { - let mut buf = vec![0; how_much]; +pub fn read_packet_data(pkt: &VsockPacket, how_much: u32) -> Vec { + let mut buf = vec![0; how_much as usize]; pkt.write_from_offset_to(&mut buf.as_mut_slice(), 0, how_much) .unwrap(); buf diff --git a/src/vmm/src/devices/virtio/vsock/unix/muxer.rs b/src/vmm/src/devices/virtio/vsock/unix/muxer.rs index 00bf511a209..3f591bda34e 100644 --- a/src/vmm/src/devices/virtio/vsock/unix/muxer.rs +++ b/src/vmm/src/devices/virtio/vsock/unix/muxer.rs @@ -871,13 +871,13 @@ mod tests { peer_port: u32, mut data: &[u8], ) -> &mut VsockPacket { - assert!(data.len() <= self.tx_pkt.buf_size()); + assert!(data.len() <= self.tx_pkt.buf_size() as usize); self.init_tx_pkt(local_port, peer_port, uapi::VSOCK_OP_RW) - .set_len(u32::try_from(data.len()).unwrap()); + .set_len(data.len().try_into().unwrap()); let data_len = data.len(); // store in tmp var to make borrow checker happy. self.rx_pkt - .read_at_offset_from(&mut data, 0, data_len) + .read_at_offset_from(&mut data, 0, data_len.try_into().unwrap()) .unwrap(); &mut self.tx_pkt }