diff --git a/src/vmm/src/devices/virtio/transport/pci/common_config.rs b/src/vmm/src/devices/virtio/transport/pci/common_config.rs index 00b61e67b67..d353b04c43e 100644 --- a/src/vmm/src/devices/virtio/transport/pci/common_config.rs +++ b/src/vmm/src/devices/virtio/transport/pci/common_config.rs @@ -16,6 +16,7 @@ use vm_memory::GuestAddress; use crate::devices::virtio::device::VirtioDevice; use crate::devices::virtio::queue::Queue; +use crate::devices::virtio::transport::pci::device::VIRTQ_MSI_NO_VECTOR; use crate::logger::{debug, error, info, trace, warn}; pub const VIRTIO_PCI_COMMON_CONFIG_ID: &str = "virtio_pci_common_config"; @@ -242,7 +243,7 @@ impl VirtioPciCommonConfig { .unwrap() .get(self.queue_select as usize) .copied() - .unwrap_or(0xffff), + .unwrap_or(VIRTQ_MSI_NO_VECTOR), 0x1c => u16::from(self.with_queue(queues, |q| q.ready).unwrap_or(false)), 0x1e => self.queue_select, // notify_off _ => { @@ -255,19 +256,36 @@ impl VirtioPciCommonConfig { fn write_common_config_word(&mut self, offset: u64, value: u16, queues: &mut [Queue]) { debug!("write_common_config_word: offset 0x{:x}", offset); match offset { - 0x10 => self.msix_config.store(value, Ordering::Release), + 0x10 => { + // Make sure that the guest doesn't select an invalid vector. We are offering + // `num_queues + 1` vectors (plus one for configuration updates). If an invalid + // vector has been selected, we just store the `NO_VECTOR` value. + let mut msix_queues = self.msix_queues.lock().expect("Poisoned lock"); + let nr_vectors = msix_queues.len() + 1; + + if (value as usize) < nr_vectors { + self.msix_config.store(value, Ordering::Release); + } else { + self.msix_config + .store(VIRTQ_MSI_NO_VECTOR, Ordering::Release); + } + } 0x16 => self.queue_select = value, 0x18 => self.with_queue_mut(queues, |q| q.size = value), 0x1a => { + let mut msix_queues = self.msix_queues.lock().expect("Poisoned lock"); + let nr_vectors = msix_queues.len() + 1; // Make sure that `queue_select` points to a valid queue. If not, we won't do // anything here and subsequent reads at 0x1a will return `NO_VECTOR`. - if let Some(msix_queue) = self - .msix_queues - .lock() - .unwrap() - .get_mut(self.queue_select as usize) - { - *msix_queue = value; + if let Some(queue) = msix_queues.get_mut(self.queue_select as usize) { + // Make sure that the guest doesn't select an invalid vector. We are offering + // `num_queues + 1` vectors (plus one for configuration updates). If an invalid + // vector has been selected, we just store the `NO_VECTOR` value. + if (value as usize) < nr_vectors { + *queue = value; + } else { + *queue = VIRTQ_MSI_NO_VECTOR; + } } } 0x1c => self.with_queue_mut(queues, |q| { @@ -446,8 +464,8 @@ mod tests { // Valid `queue_select` though should setup the corresponding MSI-X queue. regs.write(0x16, &[0x1, 0x0], dev.clone()); assert_eq!(regs.queue_select, 1); - regs.write(0x1a, &[0x12, 0x13], dev.clone()); + regs.write(0x1a, &[0x1, 0x0], dev.clone()); regs.read(0x1a, &mut read_back, dev); - assert_eq!(LittleEndian::read_u16(&read_back[..2]), 0x1312); + assert_eq!(LittleEndian::read_u16(&read_back[..2]), 0x1); } } diff --git a/src/vmm/src/devices/virtio/transport/pci/device.rs b/src/vmm/src/devices/virtio/transport/pci/device.rs index ba91163fe49..9daf88201ac 100644 --- a/src/vmm/src/devices/virtio/transport/pci/device.rs +++ b/src/vmm/src/devices/virtio/transport/pci/device.rs @@ -11,7 +11,7 @@ use std::any::Any; use std::cmp; use std::collections::HashMap; use std::fmt::{Debug, Formatter}; -use std::io::Write; +use std::io::{ErrorKind, Write}; use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicUsize, Ordering}; use std::sync::{Arc, Barrier, Mutex}; @@ -65,7 +65,7 @@ const VIRTIO_F_SR_IOV: u32 = 37; const VIRTIO_F_NOTIFICATION_DATA: u32 = 38; /// Vector value used to disable MSI for a queue. -const VIRTQ_MSI_NO_VECTOR: u16 = 0xffff; +pub const VIRTQ_MSI_NO_VECTOR: u16 = 0xffff; /// BAR index we are using for VirtIO configuration const VIRTIO_BAR_INDEX: u8 = 0; @@ -765,9 +765,12 @@ impl VirtioInterrupt for VirtioInterruptMsix { fn trigger(&self, int_type: VirtioInterruptType) -> std::result::Result<(), std::io::Error> { let vector = match int_type { VirtioInterruptType::Config => self.config_vector.load(Ordering::Acquire), - VirtioInterruptType::Queue(queue_index) => { - self.queues_vectors.lock().unwrap()[queue_index as usize] - } + VirtioInterruptType::Queue(queue_index) => *self + .queues_vectors + .lock() + .unwrap() + .get(queue_index as usize) + .ok_or(ErrorKind::InvalidInput)?, }; if vector == VIRTQ_MSI_NO_VECTOR { @@ -793,9 +796,11 @@ impl VirtioInterrupt for VirtioInterruptMsix { fn notifier(&self, int_type: VirtioInterruptType) -> Option<&EventFd> { let vector = match int_type { VirtioInterruptType::Config => self.config_vector.load(Ordering::Acquire), - VirtioInterruptType::Queue(queue_index) => { - self.queues_vectors.lock().unwrap()[queue_index as usize] - } + VirtioInterruptType::Queue(queue_index) => *self + .queues_vectors + .lock() + .unwrap() + .get(queue_index as usize)?, }; self.interrupt_source_group