diff --git a/src/vmm/src/devices/virtio/iovec.rs b/src/vmm/src/devices/virtio/iovec.rs index fd48a94ca2c..c4525b961e7 100644 --- a/src/vmm/src/devices/virtio/iovec.rs +++ b/src/vmm/src/devices/virtio/iovec.rs @@ -97,10 +97,8 @@ impl IoVecBuffer { /// /// The descriptor chain cannot be referencing the same memory location as another chain pub unsafe fn from_descriptor_chain(head: DescriptorChain) -> Result { - let mut new_buffer: Self = Default::default(); - + let mut new_buffer = Self::default(); new_buffer.load_descriptor_chain(head)?; - Ok(new_buffer) } @@ -217,7 +215,7 @@ impl IoVecBuffer { /// It describes a write-only buffer passed to us by the guest that is scattered across multiple /// memory regions. Additionally, this wrapper provides methods that allow reading arbitrary ranges /// of data from that buffer. -#[derive(Debug)] +#[derive(Debug, Default, Clone)] pub struct IoVecBufferMut { // container of the memory regions included in this IO vector vecs: IoVecVec, @@ -226,12 +224,19 @@ pub struct IoVecBufferMut { } impl IoVecBufferMut { - /// Create an `IoVecBufferMut` from a `DescriptorChain` - pub fn from_descriptor_chain(head: DescriptorChain) -> Result { - let mut vecs = IoVecVec::new(); - let mut len = 0u32; + /// Create an `IoVecBuffer` from a `DescriptorChain` + /// + /// # Safety + /// + /// The descriptor chain cannot be referencing the same memory location as another chain + pub unsafe fn load_descriptor_chain( + &mut self, + head: DescriptorChain, + ) -> Result<(), IoVecError> { + self.clear(); - for desc in head { + let mut next_descriptor = Some(head); + while let Some(desc) = next_descriptor { if !desc.is_write_only() { return Err(IoVecError::ReadOnlyDescriptor); } @@ -247,16 +252,30 @@ impl IoVecBufferMut { slice.bitmap().mark_dirty(0, desc.len as usize); let iov_base = slice.ptr_guard_mut().as_ptr().cast::(); - vecs.push(iovec { + self.vecs.push(iovec { iov_base, iov_len: desc.len as size_t, }); - len = len + self.len = self + .len .checked_add(desc.len) .ok_or(IoVecError::OverflowedDescriptor)?; + + next_descriptor = desc.next_descriptor(); } - Ok(Self { vecs, len }) + Ok(()) + } + + /// Create an `IoVecBuffer` from a `DescriptorChain` + /// + /// # Safety + /// + /// The descriptor chain cannot be referencing the same memory location as another chain + pub unsafe fn from_descriptor_chain(head: DescriptorChain) -> Result { + let mut new_buffer = Self::default(); + new_buffer.load_descriptor_chain(head)?; + Ok(new_buffer) } /// Get the total length of the memory regions covered by this `IoVecBuffer` @@ -264,6 +283,12 @@ impl IoVecBufferMut { self.len } + /// Clears the `iovec` array + pub fn clear(&mut self) { + self.vecs.clear(); + self.len = 0u32; + } + /// Writes a number of bytes into the `IoVecBufferMut` starting at a given offset. /// /// This will try to fill `IoVecBufferMut` writing bytes from the `buf` starting from @@ -468,11 +493,13 @@ mod tests { let (mut q, _) = read_only_chain(&mem); let head = q.pop(&mem).unwrap(); - IoVecBufferMut::from_descriptor_chain(head).unwrap_err(); + // SAFETY: This descriptor chain is only loaded into one buffer + unsafe { IoVecBufferMut::from_descriptor_chain(head).unwrap_err() }; let (mut q, _) = write_only_chain(&mem); let head = q.pop(&mem).unwrap(); - IoVecBufferMut::from_descriptor_chain(head).unwrap(); + // SAFETY: This descriptor chain is only loaded into one buffer + unsafe { IoVecBufferMut::from_descriptor_chain(head).unwrap() }; } #[test] @@ -493,7 +520,7 @@ mod tests { let head = q.pop(&mem).unwrap(); // SAFETY: This descriptor chain is only loaded once in this test - let iovec = IoVecBufferMut::from_descriptor_chain(head).unwrap(); + let iovec = unsafe { IoVecBufferMut::from_descriptor_chain(head).unwrap() }; assert_eq!(iovec.len(), 4 * 64); } @@ -558,7 +585,8 @@ mod tests { // This is a descriptor chain with 4 elements 64 bytes long each. let head = q.pop(&mem).unwrap(); - let mut iovec = IoVecBufferMut::from_descriptor_chain(head).unwrap(); + // SAFETY: This descriptor chain is only loaded into one buffer + let mut iovec = unsafe { IoVecBufferMut::from_descriptor_chain(head).unwrap() }; let buf = vec![0u8, 1, 2, 3, 4]; // One test vector for each part of the chain diff --git a/src/vmm/src/devices/virtio/queue.rs b/src/vmm/src/devices/virtio/queue.rs index 0fd6882d201..8bb27b2c892 100644 --- a/src/vmm/src/devices/virtio/queue.rs +++ b/src/vmm/src/devices/virtio/queue.rs @@ -402,33 +402,28 @@ impl Queue { // In a naive notation, that would be: // `descriptor_table[avail_ring[next_avail]]`. // - // First, we compute the byte-offset (into `self.avail_ring`) of the index of the next - // available descriptor. `self.avail_ring` stores the address of a `struct - // virtq_avail`, as defined by the VirtIO spec: - // - // ```C - // struct virtq_avail { - // le16 flags; - // le16 idx; - // le16 ring[QUEUE_SIZE]; - // le16 used_event + // Avail ring has layout: + // struct AvailRing { + // flags: u16, + // idx: u16, + // ring: [u16; ], + // used_event: u16, // } - // ``` - // - // We use `self.next_avail` to store the position, in `ring`, of the next available - // descriptor index, with a twist: we always only increment `self.next_avail`, so the - // actual position will be `self.next_avail % self.actual_size()`. - // We are now looking for the offset of `ring[self.next_avail % self.actual_size()]`. - // `ring` starts after `flags` and `idx` (4 bytes into `struct virtq_avail`), and holds - // 2-byte items, so the offset will be: - let index_offset = 4 + 2 * (self.next_avail.0 % self.actual_size()); + // We calculate offset into `ring` field. + // We use `self.next_avail` to store the position, of the next available descriptor + // index in the `ring` field. Because `self.next_avail` is only incremented, the actual + // index into `AvailRing` is `self.next_avail % self.actual_size()`. + let desc_index_offset = std::mem::size_of::() + + std::mem::size_of::() + + std::mem::size_of::() * usize::from(self.next_avail.0 % self.actual_size()); + let desc_index_address = self + .avail_ring + .unchecked_add(usize_to_u64(desc_index_offset)); // `self.is_valid()` already performed all the bound checks on the descriptor table // and virtq rings, so it's safe to unwrap guest memory reads and to use unchecked // offsets. - let desc_index: u16 = mem - .read_obj(self.avail_ring.unchecked_add(u64::from(index_offset))) - .unwrap(); + let desc_index: u16 = mem.read_obj(desc_index_address).unwrap(); DescriptorChain::checked_new(mem, self.desc_table, self.actual_size(), desc_index).map( |dc| { diff --git a/src/vmm/src/devices/virtio/rng/device.rs b/src/vmm/src/devices/virtio/rng/device.rs index bb01ce5e44e..f671f00e554 100644 --- a/src/vmm/src/devices/virtio/rng/device.rs +++ b/src/vmm/src/devices/virtio/rng/device.rs @@ -132,7 +132,10 @@ impl Entropy { let index = desc.index; METRICS.entropy_event_count.inc(); - let bytes = match IoVecBufferMut::from_descriptor_chain(desc) { + // SAFETY: This descriptor chain is only loaded once + // virtio requests are handled sequentially so no two IoVecBuffers + // are live at the same time, meaning this has exclusive ownership over the memory + let bytes = match unsafe { IoVecBufferMut::from_descriptor_chain(desc) } { Ok(mut iovec) => { debug!( "entropy: guest request for {} bytes of entropy", @@ -428,13 +431,15 @@ mod tests { // This should succeed, we just added two descriptors let desc = entropy_dev.queues_mut()[RNG_QUEUE].pop(&mem).unwrap(); assert!(matches!( - IoVecBufferMut::from_descriptor_chain(desc), + // SAFETY: This descriptor chain is only loaded into one buffer + unsafe { IoVecBufferMut::from_descriptor_chain(desc) }, Err(crate::devices::virtio::iovec::IoVecError::ReadOnlyDescriptor) )); // This should succeed, we should have one more descriptor let desc = entropy_dev.queues_mut()[RNG_QUEUE].pop(&mem).unwrap(); - let mut iovec = IoVecBufferMut::from_descriptor_chain(desc).unwrap(); + // SAFETY: This descriptor chain is only loaded into one buffer + let mut iovec = unsafe { IoVecBufferMut::from_descriptor_chain(desc).unwrap() }; entropy_dev.handle_one(&mut iovec).unwrap(); } diff --git a/src/vmm/src/devices/virtio/vsock/packet.rs b/src/vmm/src/devices/virtio/vsock/packet.rs index 952f8b1511e..c18b45b9a94 100644 --- a/src/vmm/src/devices/virtio/vsock/packet.rs +++ b/src/vmm/src/devices/virtio/vsock/packet.rs @@ -161,7 +161,10 @@ impl VsockPacket { /// Returns [`VsockError::DescChainTooShortForHeader`] if the descriptor chain's total buffer /// length is insufficient to hold the 44 byte vsock header pub fn from_rx_virtq_head(chain: DescriptorChain) -> Result { - let buffer = IoVecBufferMut::from_descriptor_chain(chain)?; + // SAFETY: This descriptor chain is only loaded once + // virtio requests are handled sequentially so no two IoVecBuffers + // are live at the same time, meaning this has exclusive ownership over the memory + let buffer = unsafe { IoVecBufferMut::from_descriptor_chain(chain)? }; if buffer.len() < VSOCK_PKT_HDR_SIZE { return Err(VsockError::DescChainTooShortForHeader(buffer.len() as usize));