diff --git a/src/vmm/src/builder.rs b/src/vmm/src/builder.rs index 398c25ba056..efc15b1e3c9 100644 --- a/src/vmm/src/builder.rs +++ b/src/vmm/src/builder.rs @@ -42,9 +42,9 @@ use crate::devices::legacy::{EventFdTrigger, SerialEventsWrapper, SerialWrapper} use crate::devices::virtio::balloon::Balloon; use crate::devices::virtio::block::device::Block; use crate::devices::virtio::device::VirtioDevice; -use crate::devices::virtio::mmio::MmioTransport; use crate::devices::virtio::net::Net; use crate::devices::virtio::rng::Entropy; +use crate::devices::virtio::transport::mmio::{IrqTrigger, MmioTransport}; use crate::devices::virtio::vsock::{Vsock, VsockUnixBackend}; #[cfg(feature = "gdb")] use crate::gdb; @@ -597,8 +597,14 @@ fn attach_virtio_device( ) -> Result<(), MmioError> { event_manager.add_subscriber(device.clone()); + let interrupt = Arc::new(IrqTrigger::new()); // The device mutex mustn't be locked here otherwise it will deadlock. - let device = MmioTransport::new(vmm.vm.guest_memory().clone(), device, is_vhost_user); + let device = MmioTransport::new( + vmm.vm.guest_memory().clone(), + interrupt, + device, + is_vhost_user, + ); vmm.mmio_device_manager .register_mmio_virtio_for_boot( vmm.vm.fd(), diff --git a/src/vmm/src/device_manager/mmio.rs b/src/vmm/src/device_manager/mmio.rs index 99bde6e2e78..333da93fa8a 100644 --- a/src/vmm/src/device_manager/mmio.rs +++ b/src/vmm/src/device_manager/mmio.rs @@ -30,9 +30,9 @@ use crate::devices::pseudo::BootTimer; use crate::devices::virtio::balloon::Balloon; use crate::devices::virtio::block::device::Block; use crate::devices::virtio::device::VirtioDevice; -use crate::devices::virtio::mmio::MmioTransport; use crate::devices::virtio::net::Net; use crate::devices::virtio::rng::Entropy; +use crate::devices::virtio::transport::mmio::MmioTransport; use crate::devices::virtio::vsock::{TYPE_VSOCK, Vsock, VsockUnixBackend}; use crate::devices::virtio::{TYPE_BALLOON, TYPE_BLOCK, TYPE_NET, TYPE_RNG}; #[cfg(target_arch = "x86_64")] @@ -53,6 +53,8 @@ pub enum MmioError { InvalidDeviceType, /// {0} InternalDeviceError(String), + /// Could not create IRQ for MMIO device: {0} + CreateIrq(#[from] std::io::Error), /// Invalid MMIO IRQ configuration. InvalidIrqConfig, /// Failed to register IO event: {0} @@ -205,7 +207,7 @@ impl MMIODeviceManager { vm.register_ioevent(queue_evt, &io_addr, u32::try_from(i).unwrap()) .map_err(MmioError::RegisterIoEvent)?; } - vm.register_irqfd(&locked_device.interrupt_trigger().irq_evt, irq.get()) + vm.register_irqfd(&mmio_device.interrupt.irq_evt, irq.get()) .map_err(MmioError::RegisterIrqFd)?; } @@ -223,7 +225,7 @@ impl MMIODeviceManager { device_info: &MMIODeviceInfo, ) -> Result<(), MmioError> { // as per doc, [virtio_mmio.]device=@: needs to be appended - // to kernel command line for virtio mmio devices to get recongnized + // to kernel command line for virtio mmio devices to get recognized // the size parameter has to be transformed to KiB, so dividing hexadecimal value in // bytes to 1024; further, the '{}' formatting rust construct will automatically // transform it to decimal @@ -503,7 +505,7 @@ impl MMIODeviceManager { .unwrap(); if vsock.is_activated() { info!("kick vsock {id}."); - vsock.signal_used_queue().unwrap(); + vsock.signal_used_queue(0).unwrap(); } } TYPE_RNG => { @@ -523,6 +525,7 @@ impl MMIODeviceManager { #[cfg(test)] mod tests { + use std::ops::Deref; use std::sync::Arc; use vmm_sys_util::eventfd::EventFd; @@ -530,8 +533,10 @@ mod tests { use super::*; use crate::Vm; use crate::devices::virtio::ActivateError; - use crate::devices::virtio::device::{IrqTrigger, VirtioDevice}; + use crate::devices::virtio::device::VirtioDevice; use crate::devices::virtio::queue::Queue; + use crate::devices::virtio::transport::VirtioInterrupt; + use crate::devices::virtio::transport::mmio::IrqTrigger; use crate::test_utils::multi_region_mem_raw; use crate::vstate::kvm::Kvm; use crate::vstate::memory::{GuestAddress, GuestMemoryMmap}; @@ -548,7 +553,8 @@ mod tests { cmdline: &mut kernel_cmdline::Cmdline, dev_id: &str, ) -> Result { - let mmio_device = MmioTransport::new(guest_mem, device, false); + let interrupt = Arc::new(IrqTrigger::new()); + let mmio_device = MmioTransport::new(guest_mem, interrupt, device, false); let device_info = self.register_mmio_virtio_for_boot( vm, resource_allocator, @@ -575,7 +581,7 @@ mod tests { dummy: u32, queues: Vec, queue_evts: [EventFd; 1], - interrupt_trigger: IrqTrigger, + interrupt_trigger: Option>, } impl DummyDevice { @@ -584,7 +590,7 @@ mod tests { dummy: 0, queues: QUEUE_SIZES.iter().map(|&s| Queue::new(s)).collect(), queue_evts: [EventFd::new(libc::EFD_NONBLOCK).expect("cannot create eventFD")], - interrupt_trigger: IrqTrigger::new().expect("cannot create eventFD"), + interrupt_trigger: None, } } } @@ -616,8 +622,8 @@ mod tests { &self.queue_evts } - fn interrupt_trigger(&self) -> &IrqTrigger { - &self.interrupt_trigger + fn interrupt_trigger(&self) -> &dyn VirtioInterrupt { + self.interrupt_trigger.as_ref().unwrap().deref() } fn ack_features_by_page(&mut self, page: u32, value: u32) { @@ -635,7 +641,11 @@ mod tests { let _ = data; } - fn activate(&mut self, _: GuestMemoryMmap) -> Result<(), ActivateError> { + fn activate( + &mut self, + _: GuestMemoryMmap, + _: Arc, + ) -> Result<(), ActivateError> { Ok(()) } diff --git a/src/vmm/src/device_manager/persist.rs b/src/vmm/src/device_manager/persist.rs index 30a6387bc82..2f331e644ad 100644 --- a/src/vmm/src/device_manager/persist.rs +++ b/src/vmm/src/device_manager/persist.rs @@ -25,7 +25,6 @@ use crate::devices::virtio::block::BlockError; use crate::devices::virtio::block::device::Block; use crate::devices::virtio::block::persist::{BlockConstructorArgs, BlockState}; use crate::devices::virtio::device::VirtioDevice; -use crate::devices::virtio::mmio::MmioTransport; use crate::devices::virtio::net::Net; use crate::devices::virtio::net::persist::{ NetConstructorArgs, NetPersistError as NetError, NetState, @@ -35,6 +34,7 @@ use crate::devices::virtio::rng::Entropy; use crate::devices::virtio::rng::persist::{ EntropyConstructorArgs, EntropyPersistError as EntropyError, EntropyState, }; +use crate::devices::virtio::transport::mmio::{IrqTrigger, MmioTransport}; use crate::devices::virtio::vsock::persist::{ VsockConstructorArgs, VsockState, VsockUdsConstructorArgs, }; @@ -473,11 +473,13 @@ impl<'a> Persist<'a> for MMIODeviceManager { as_subscriber: Arc>, id: &String, state: &MmioTransportState, + interrupt: Arc, device_info: &MMIODeviceInfo, event_manager: &mut EventManager| -> Result<(), Self::Error> { let restore_args = MmioTransportConstructorArgs { mem: mem.clone(), + interrupt, device, is_vhost_user, }; @@ -512,9 +514,11 @@ impl<'a> Persist<'a> for MMIODeviceManager { }; if let Some(balloon_state) = &state.balloon_device { + let interrupt = Arc::new(IrqTrigger::new()); let device = Arc::new(Mutex::new(Balloon::restore( BalloonConstructorArgs { mem: mem.clone(), + interrupt: interrupt.clone(), restored_from_file: constructor_args.restored_from_file, }, &balloon_state.device_state, @@ -530,14 +534,19 @@ impl<'a> Persist<'a> for MMIODeviceManager { device, &balloon_state.device_id, &balloon_state.transport_state, + interrupt, &balloon_state.device_info, constructor_args.event_manager, )?; } for block_state in &state.block_devices { + let interrupt = Arc::new(IrqTrigger::new()); let device = Arc::new(Mutex::new(Block::restore( - BlockConstructorArgs { mem: mem.clone() }, + BlockConstructorArgs { + mem: mem.clone(), + interrupt: interrupt.clone(), + }, &block_state.device_state, )?)); @@ -551,6 +560,7 @@ impl<'a> Persist<'a> for MMIODeviceManager { device, &block_state.device_id, &block_state.transport_state, + interrupt, &block_state.device_info, constructor_args.event_manager, )?; @@ -573,9 +583,11 @@ impl<'a> Persist<'a> for MMIODeviceManager { } for net_state in &state.net_devices { + let interrupt = Arc::new(IrqTrigger::new()); let device = Arc::new(Mutex::new(Net::restore( NetConstructorArgs { mem: mem.clone(), + interrupt: interrupt.clone(), mmds: constructor_args .vm_resources .mmds @@ -596,6 +608,7 @@ impl<'a> Persist<'a> for MMIODeviceManager { device, &net_state.device_id, &net_state.transport_state, + interrupt, &net_state.device_info, constructor_args.event_manager, )?; @@ -606,9 +619,11 @@ impl<'a> Persist<'a> for MMIODeviceManager { cid: vsock_state.device_state.frontend.cid, }; let backend = VsockUnixBackend::restore(ctor_args, &vsock_state.device_state.backend)?; + let interrupt = Arc::new(IrqTrigger::new()); let device = Arc::new(Mutex::new(Vsock::restore( VsockConstructorArgs { mem: mem.clone(), + interrupt: interrupt.clone(), backend, }, &vsock_state.device_state.frontend, @@ -624,13 +639,15 @@ impl<'a> Persist<'a> for MMIODeviceManager { device, &vsock_state.device_id, &vsock_state.transport_state, + interrupt, &vsock_state.device_info, constructor_args.event_manager, )?; } if let Some(entropy_state) = &state.entropy_device { - let ctor_args = EntropyConstructorArgs::new(mem.clone()); + let interrupt = Arc::new(IrqTrigger::new()); + let ctor_args = EntropyConstructorArgs::new(mem.clone(), interrupt.clone()); let device = Arc::new(Mutex::new(Entropy::restore( ctor_args, @@ -647,6 +664,7 @@ impl<'a> Persist<'a> for MMIODeviceManager { device, &entropy_state.device_id, &entropy_state.transport_state, + interrupt, &entropy_state.device_info, constructor_args.event_manager, )?; diff --git a/src/vmm/src/devices/bus.rs b/src/vmm/src/devices/bus.rs index 2b016d73083..d0e1b296998 100644 --- a/src/vmm/src/devices/bus.rs +++ b/src/vmm/src/devices/bus.rs @@ -56,7 +56,7 @@ use event_manager::{EventOps, Events, MutEventSubscriber}; use super::legacy::RTCDevice; use super::legacy::{I8042Device, SerialDevice}; use super::pseudo::BootTimer; -use super::virtio::mmio::MmioTransport; +use super::virtio::transport::mmio::MmioTransport; #[derive(Debug)] pub enum BusDevice { diff --git a/src/vmm/src/devices/virtio/balloon/device.rs b/src/vmm/src/devices/virtio/balloon/device.rs index 186f09275bc..3927b7e0aef 100644 --- a/src/vmm/src/devices/virtio/balloon/device.rs +++ b/src/vmm/src/devices/virtio/balloon/device.rs @@ -1,7 +1,8 @@ // Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use std::fmt; +use std::ops::Deref; +use std::sync::Arc; use std::time::Duration; use log::error; @@ -24,8 +25,9 @@ use super::{ VIRTIO_BALLOON_S_SWAP_OUT, }; use crate::devices::virtio::balloon::BalloonError; -use crate::devices::virtio::device::{IrqTrigger, IrqType}; +use crate::devices::virtio::device::ActiveState; use crate::devices::virtio::generated::virtio_config::VIRTIO_F_VERSION_1; +use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType}; use crate::logger::IncMetric; use crate::utils::u64_to_usize; use crate::vstate::memory::{Address, ByteValued, Bytes, GuestAddress, GuestMemoryMmap}; @@ -149,6 +151,7 @@ impl BalloonStats { } } +#[derive(Debug)] /// Virtio balloon device. pub struct Balloon { // Virtio fields. @@ -161,7 +164,6 @@ pub struct Balloon { pub(crate) queues: Vec, pub(crate) queue_evts: [EventFd; BALLOON_NUM_QUEUES], pub(crate) device_state: DeviceState, - pub(crate) irq_trigger: IrqTrigger, // Implementation specific fields. pub(crate) restored_from_file: bool, @@ -175,29 +177,6 @@ pub struct Balloon { pub(crate) pfn_buffer: [u32; MAX_PAGE_COMPACT_BUFFER], } -// TODO Use `#[derive(Debug)]` when a new release of -// [rust-timerfd](https://github.com/main--/rust-timerfd) is published that includes -// https://github.com/main--/rust-timerfd/pull/12. -impl fmt::Debug for Balloon { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Balloon") - .field("avail_features", &self.avail_features) - .field("acked_features", &self.acked_features) - .field("config_space", &self.config_space) - .field("activate_evt", &self.activate_evt) - .field("queues", &self.queues) - .field("queue_evts", &self.queue_evts) - .field("device_state", &self.device_state) - .field("irq_trigger", &self.irq_trigger) - .field("restored_from_file", &self.restored_from_file) - .field("stats_polling_interval_s", &self.stats_polling_interval_s) - .field("stats_desc_index", &self.stats_desc_index) - .field("latest_stats", &self.latest_stats) - .field("pfn_buffer", &self.pfn_buffer) - .finish() - } -} - impl Balloon { /// Instantiate a new balloon device. pub fn new( @@ -242,7 +221,6 @@ impl Balloon { }, queue_evts, queues, - irq_trigger: IrqTrigger::new().map_err(BalloonError::EventFd)?, device_state: DeviceState::Inactive, activate_evt: EventFd::new(libc::EFD_NONBLOCK).map_err(BalloonError::EventFd)?, restored_from_file, @@ -282,7 +260,7 @@ impl Balloon { pub(crate) fn process_inflate(&mut self) -> Result<(), BalloonError> { // This is safe since we checked in the event handler that the device is activated. - let mem = self.device_state.mem().unwrap(); + let mem = &self.device_state.active_state().unwrap().mem; METRICS.inflate_count.inc(); let queue = &mut self.queues[INFLATE_INDEX]; @@ -363,7 +341,7 @@ impl Balloon { } if needs_interrupt { - self.signal_used_queue()?; + self.signal_used_queue(INFLATE_INDEX)?; } Ok(()) @@ -381,7 +359,7 @@ impl Balloon { } if needs_interrupt { - self.signal_used_queue() + self.signal_used_queue(DEFLATE_INDEX) } else { Ok(()) } @@ -389,7 +367,7 @@ impl Balloon { pub(crate) fn process_stats_queue(&mut self) -> Result<(), BalloonError> { // This is safe since we checked in the event handler that the device is activated. - let mem = self.device_state.mem().unwrap(); + let mem = &self.device_state.active_state().unwrap().mem; METRICS.stats_updates_count.inc(); while let Some(head) = self.queues[STATS_INDEX].pop() { @@ -425,11 +403,16 @@ impl Balloon { Ok(()) } - pub(crate) fn signal_used_queue(&self) -> Result<(), BalloonError> { - self.irq_trigger.trigger_irq(IrqType::Vring).map_err(|err| { - METRICS.event_fails.inc(); - BalloonError::InterruptError(err) - }) + pub(crate) fn signal_used_queue(&self, qidx: usize) -> Result<(), BalloonError> { + self.interrupt_trigger() + .trigger(VirtioInterruptType::Queue( + qidx.try_into() + .unwrap_or_else(|_| panic!("balloon: invalid queue id: {qidx}")), + )) + .map_err(|err| { + METRICS.event_fails.inc(); + BalloonError::InterruptError(err) + }) } /// Process device virtio queue(s). @@ -450,7 +433,7 @@ impl Balloon { self.queues[STATS_INDEX] .add_used(index, 0) .map_err(BalloonError::Queue)?; - self.signal_used_queue() + self.signal_used_queue(STATS_INDEX) } else { error!("Failed to update balloon stats, missing descriptor."); Ok(()) @@ -461,8 +444,8 @@ impl Balloon { pub fn update_size(&mut self, amount_mib: u32) -> Result<(), BalloonError> { if self.is_activated() { self.config_space.num_pages = mib_to_pages(amount_mib)?; - self.irq_trigger - .trigger_irq(IrqType::Config) + self.interrupt_trigger() + .trigger(VirtioInterruptType::Config) .map_err(BalloonError::InterruptError) } else { Err(BalloonError::DeviceNotActive) @@ -573,8 +556,12 @@ impl VirtioDevice for Balloon { &self.queue_evts } - fn interrupt_trigger(&self) -> &IrqTrigger { - &self.irq_trigger + fn interrupt_trigger(&self) -> &dyn VirtioInterrupt { + self.device_state + .active_state() + .expect("Device is not activated") + .interrupt + .deref() } fn read_config(&self, offset: u64, data: &mut [u8]) { @@ -601,13 +588,17 @@ impl VirtioDevice for Balloon { dst.copy_from_slice(data); } - fn activate(&mut self, mem: GuestMemoryMmap) -> Result<(), ActivateError> { + fn activate( + &mut self, + mem: GuestMemoryMmap, + interrupt: Arc, + ) -> Result<(), ActivateError> { for q in self.queues.iter_mut() { q.initialize(&mem) .map_err(ActivateError::QueueMemoryError)?; } - self.device_state = DeviceState::Activated(mem); + self.device_state = DeviceState::Activated(ActiveState { mem, interrupt }); if self.activate_evt.write(1).is_err() { METRICS.activate_fails.inc(); self.device_state = DeviceState::Inactive; @@ -636,7 +627,7 @@ pub(crate) mod tests { check_request_completion, invoke_handler_for_queue_event, set_request, }; use crate::devices::virtio::queue::{VIRTQ_DESC_F_NEXT, VIRTQ_DESC_F_WRITE}; - use crate::devices::virtio::test_utils::{VirtQueue, default_mem}; + use crate::devices::virtio::test_utils::{VirtQueue, default_interrupt, default_mem}; use crate::test_utils::single_region_mem; use crate::vstate::memory::GuestAddress; @@ -813,10 +804,11 @@ pub(crate) mod tests { fn test_invalid_request() { let mut balloon = Balloon::new(0, true, 0, false).unwrap(); let mem = default_mem(); + let interrupt = default_interrupt(); // Only initialize the inflate queue to demonstrate invalid request handling. let infq = VirtQueue::new(GuestAddress(0), &mem, 16); balloon.set_queue(INFLATE_INDEX, infq.create_queue()); - balloon.activate(mem.clone()).unwrap(); + balloon.activate(mem.clone(), interrupt).unwrap(); // Fill the second page with non-zero bytes. for i in 0..0x1000 { @@ -872,9 +864,10 @@ pub(crate) mod tests { fn test_inflate() { let mut balloon = Balloon::new(0, true, 0, false).unwrap(); let mem = default_mem(); + let interrupt = default_interrupt(); let infq = VirtQueue::new(GuestAddress(0), &mem, 16); balloon.set_queue(INFLATE_INDEX, infq.create_queue()); - balloon.activate(mem.clone()).unwrap(); + balloon.activate(mem.clone(), interrupt).unwrap(); // Fill the third page with non-zero bytes. for i in 0..0x1000 { @@ -942,9 +935,10 @@ pub(crate) mod tests { fn test_deflate() { let mut balloon = Balloon::new(0, true, 0, false).unwrap(); let mem = default_mem(); + let interrupt = default_interrupt(); let defq = VirtQueue::new(GuestAddress(0), &mem, 16); balloon.set_queue(DEFLATE_INDEX, defq.create_queue()); - balloon.activate(mem.clone()).unwrap(); + balloon.activate(mem.clone(), interrupt).unwrap(); let page_addr = 0x10; @@ -990,9 +984,10 @@ pub(crate) mod tests { fn test_stats() { let mut balloon = Balloon::new(0, true, 1, false).unwrap(); let mem = default_mem(); + let interrupt = default_interrupt(); let statsq = VirtQueue::new(GuestAddress(0), &mem, 16); balloon.set_queue(STATS_INDEX, statsq.create_queue()); - balloon.activate(mem.clone()).unwrap(); + balloon.activate(mem.clone(), interrupt).unwrap(); let page_addr = 0x100; @@ -1068,7 +1063,9 @@ pub(crate) mod tests { assert!(balloon.stats_desc_index.is_some()); balloon.process_stats_timer_event().unwrap(); assert!(balloon.stats_desc_index.is_none()); - assert!(balloon.irq_trigger.has_pending_irq(IrqType::Vring)); + assert!(balloon.interrupt_trigger().has_pending_interrupt( + VirtioInterruptType::Queue(STATS_INDEX.try_into().unwrap()) + )); }); } } @@ -1077,13 +1074,14 @@ pub(crate) mod tests { fn test_process_balloon_queues() { let mut balloon = Balloon::new(0x10, true, 0, false).unwrap(); let mem = default_mem(); + let interrupt = default_interrupt(); let infq = VirtQueue::new(GuestAddress(0), &mem, 16); let defq = VirtQueue::new(GuestAddress(0), &mem, 16); balloon.set_queue(INFLATE_INDEX, infq.create_queue()); balloon.set_queue(DEFLATE_INDEX, defq.create_queue()); - balloon.activate(mem).unwrap(); + balloon.activate(mem, interrupt).unwrap(); balloon.process_virtio_queues() } @@ -1091,7 +1089,8 @@ pub(crate) mod tests { fn test_update_stats_interval() { let mut balloon = Balloon::new(0, true, 0, false).unwrap(); let mem = default_mem(); - balloon.activate(mem).unwrap(); + let interrupt = default_interrupt(); + balloon.activate(mem, interrupt).unwrap(); assert_eq!( format!("{:?}", balloon.update_stats_polling_interval(1)), "Err(StatisticsStateChange)" @@ -1100,7 +1099,8 @@ pub(crate) mod tests { let mut balloon = Balloon::new(0, true, 1, false).unwrap(); let mem = default_mem(); - balloon.activate(mem).unwrap(); + let interrupt = default_interrupt(); + balloon.activate(mem, interrupt).unwrap(); assert_eq!( format!("{:?}", balloon.update_stats_polling_interval(0)), "Err(StatisticsStateChange)" @@ -1120,7 +1120,10 @@ pub(crate) mod tests { fn test_num_pages() { let mut balloon = Balloon::new(0, true, 0, false).unwrap(); // Switch the state to active. - balloon.device_state = DeviceState::Activated(single_region_mem(0x1)); + balloon.device_state = DeviceState::Activated(ActiveState { + mem: single_region_mem(0x1), + interrupt: default_interrupt(), + }); assert_eq!(balloon.num_pages(), 0); assert_eq!(balloon.actual_pages(), 0); diff --git a/src/vmm/src/devices/virtio/balloon/event_handler.rs b/src/vmm/src/devices/virtio/balloon/event_handler.rs index 56ff5c35047..cec643ef73a 100644 --- a/src/vmm/src/devices/virtio/balloon/event_handler.rs +++ b/src/vmm/src/devices/virtio/balloon/event_handler.rs @@ -136,7 +136,7 @@ pub mod tests { use super::*; use crate::devices::virtio::balloon::test_utils::set_request; - use crate::devices::virtio::test_utils::{VirtQueue, default_mem}; + use crate::devices::virtio::test_utils::{VirtQueue, default_interrupt, default_mem}; use crate::vstate::memory::GuestAddress; #[test] @@ -144,6 +144,7 @@ pub mod tests { let mut event_manager = EventManager::new().unwrap(); let mut balloon = Balloon::new(0, true, 10, false).unwrap(); let mem = default_mem(); + let interrupt = default_interrupt(); let infq = VirtQueue::new(GuestAddress(0), &mem, 16); balloon.set_queue(INFLATE_INDEX, infq.create_queue()); @@ -177,7 +178,11 @@ pub mod tests { } // Now activate the device. - balloon.lock().unwrap().activate(mem.clone()).unwrap(); + balloon + .lock() + .unwrap() + .activate(mem.clone(), interrupt) + .unwrap(); // Process the activate event. let ev_count = event_manager.run_with_timeout(50).unwrap(); assert_eq!(ev_count, 1); diff --git a/src/vmm/src/devices/virtio/balloon/mod.rs b/src/vmm/src/devices/virtio/balloon/mod.rs index 21f96d3ba56..6bdbfb26248 100644 --- a/src/vmm/src/devices/virtio/balloon/mod.rs +++ b/src/vmm/src/devices/virtio/balloon/mod.rs @@ -86,7 +86,7 @@ pub enum BalloonError { MalformedPayload, /// Error restoring the balloon device queues. QueueRestoreError, - /// Received stats querry when stats are disabled. + /// Received stats query when stats are disabled. StatisticsDisabled, /// Statistics cannot be enabled/disabled after activation. StatisticsStateChange, diff --git a/src/vmm/src/devices/virtio/balloon/persist.rs b/src/vmm/src/devices/virtio/balloon/persist.rs index 004fa27f8ca..a6634d07170 100644 --- a/src/vmm/src/devices/virtio/balloon/persist.rs +++ b/src/vmm/src/devices/virtio/balloon/persist.rs @@ -4,7 +4,6 @@ //! Defines the structures needed for saving/restoring balloon devices. use std::sync::Arc; -use std::sync::atomic::AtomicU32; use std::time::Duration; use serde::{Deserialize, Serialize}; @@ -13,9 +12,10 @@ use timerfd::{SetTimeFlags, TimerState}; use super::*; use crate::devices::virtio::TYPE_BALLOON; use crate::devices::virtio::balloon::device::{BalloonStats, ConfigSpace}; -use crate::devices::virtio::device::DeviceState; +use crate::devices::virtio::device::{ActiveState, DeviceState}; use crate::devices::virtio::persist::VirtioDeviceState; use crate::devices::virtio::queue::FIRECRACKER_MAX_QUEUE_SIZE; +use crate::devices::virtio::transport::VirtioInterrupt; use crate::snapshot::Persist; use crate::vstate::memory::GuestMemoryMmap; @@ -95,6 +95,8 @@ pub struct BalloonState { pub struct BalloonConstructorArgs { /// Pointer to guest memory. pub mem: GuestMemoryMmap, + /// Interrupt used from the device. + pub interrupt: Arc, pub restored_from_file: bool, } @@ -144,8 +146,6 @@ impl Persist<'_> for Balloon { FIRECRACKER_MAX_QUEUE_SIZE, ) .map_err(|_| Self::Error::QueueRestoreError)?; - balloon.irq_trigger.irq_status = - Arc::new(AtomicU32::new(state.virtio_state.interrupt_status)); balloon.avail_features = state.virtio_state.avail_features; balloon.acked_features = state.virtio_state.acked_features; balloon.latest_stats = state.latest_stats.create_stats(); @@ -155,7 +155,10 @@ impl Persist<'_> for Balloon { }; if state.virtio_state.activated { - balloon.device_state = DeviceState::Activated(constructor_args.mem); + balloon.device_state = DeviceState::Activated(ActiveState { + mem: constructor_args.mem, + interrupt: constructor_args.interrupt, + }); if balloon.stats_enabled() { // Restore the stats descriptor. @@ -178,12 +181,11 @@ impl Persist<'_> for Balloon { #[cfg(test)] mod tests { - use std::sync::atomic::Ordering; use super::*; use crate::devices::virtio::TYPE_BALLOON; use crate::devices::virtio::device::VirtioDevice; - use crate::devices::virtio::test_utils::default_mem; + use crate::devices::virtio::test_utils::{default_interrupt, default_mem}; use crate::snapshot::Snapshot; #[test] @@ -200,6 +202,7 @@ mod tests { let restored_balloon = Balloon::restore( BalloonConstructorArgs { mem: guest_mem, + interrupt: default_interrupt(), restored_from_file: true, }, &Snapshot::deserialize(&mut mem.as_slice()).unwrap(), @@ -213,11 +216,8 @@ mod tests { assert_eq!(restored_balloon.avail_features, balloon.avail_features); assert_eq!(restored_balloon.config_space, balloon.config_space); assert_eq!(restored_balloon.queues(), balloon.queues()); - assert_eq!( - restored_balloon.interrupt_status().load(Ordering::Relaxed), - balloon.interrupt_status().load(Ordering::Relaxed) - ); - assert_eq!(restored_balloon.is_activated(), balloon.is_activated()); + assert!(!restored_balloon.is_activated()); + assert!(!balloon.is_activated()); assert_eq!( restored_balloon.stats_polling_interval_s, diff --git a/src/vmm/src/devices/virtio/balloon/test_utils.rs b/src/vmm/src/devices/virtio/balloon/test_utils.rs index af0d7f5845e..2665d5dbd87 100644 --- a/src/vmm/src/devices/virtio/balloon/test_utils.rs +++ b/src/vmm/src/devices/virtio/balloon/test_utils.rs @@ -3,6 +3,8 @@ #![doc(hidden)] +#[cfg(test)] +use crate::devices::virtio::device::VirtioDevice; use crate::devices::virtio::test_utils::VirtQueue; #[cfg(test)] use crate::devices::virtio::{balloon::BALLOON_NUM_QUEUES, balloon::Balloon}; @@ -10,7 +12,7 @@ use crate::devices::virtio::{balloon::BALLOON_NUM_QUEUES, balloon::Balloon}; #[cfg(test)] pub fn invoke_handler_for_queue_event(b: &mut Balloon, queue_index: usize) { use crate::devices::virtio::balloon::{DEFLATE_INDEX, INFLATE_INDEX, STATS_INDEX}; - use crate::devices::virtio::device::IrqType; + use crate::devices::virtio::transport::VirtioInterruptType; assert!(queue_index < BALLOON_NUM_QUEUES); // Trigger the queue event. @@ -23,7 +25,11 @@ pub fn invoke_handler_for_queue_event(b: &mut Balloon, queue_index: usize) { _ => unreachable!(), }; // Validate the queue operation finished successfully. - assert!(b.irq_trigger.has_pending_irq(IrqType::Vring)); + let interrupt = b.interrupt_trigger(); + assert!( + interrupt + .has_pending_interrupt(VirtioInterruptType::Queue(queue_index.try_into().unwrap())) + ); } pub fn set_request(queue: &VirtQueue, idx: u16, addr: u64, len: u32, flags: u16) { diff --git a/src/vmm/src/devices/virtio/block/device.rs b/src/vmm/src/devices/virtio/block/device.rs index bf3043bcdd4..f2f797e60f3 100644 --- a/src/vmm/src/devices/virtio/block/device.rs +++ b/src/vmm/src/devices/virtio/block/device.rs @@ -1,6 +1,8 @@ // Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 +use std::sync::Arc; + use event_manager::{EventOps, Events, MutEventSubscriber}; use vmm_sys_util::eventfd::EventFd; @@ -8,8 +10,9 @@ use super::BlockError; use super::persist::{BlockConstructorArgs, BlockState}; use super::vhost_user::device::{VhostUserBlock, VhostUserBlockConfig}; use super::virtio::device::{VirtioBlock, VirtioBlockConfig}; -use crate::devices::virtio::device::{IrqTrigger, VirtioDevice}; +use crate::devices::virtio::device::VirtioDevice; use crate::devices::virtio::queue::Queue; +use crate::devices::virtio::transport::VirtioInterrupt; use crate::devices::virtio::{ActivateError, TYPE_BLOCK}; use crate::rate_limiter::BucketUpdate; use crate::snapshot::Persist; @@ -173,10 +176,10 @@ impl VirtioDevice for Block { } } - fn interrupt_trigger(&self) -> &IrqTrigger { + fn interrupt_trigger(&self) -> &dyn VirtioInterrupt { match self { - Self::Virtio(b) => &b.irq_trigger, - Self::VhostUser(b) => &b.irq_trigger, + Self::Virtio(b) => b.interrupt_trigger(), + Self::VhostUser(b) => b.interrupt_trigger(), } } @@ -194,10 +197,14 @@ impl VirtioDevice for Block { } } - fn activate(&mut self, mem: GuestMemoryMmap) -> Result<(), ActivateError> { + fn activate( + &mut self, + mem: GuestMemoryMmap, + interrupt: Arc, + ) -> Result<(), ActivateError> { match self { - Self::Virtio(b) => b.activate(mem), - Self::VhostUser(b) => b.activate(mem), + Self::Virtio(b) => b.activate(mem, interrupt), + Self::VhostUser(b) => b.activate(mem, interrupt), } } diff --git a/src/vmm/src/devices/virtio/block/persist.rs b/src/vmm/src/devices/virtio/block/persist.rs index 2d83c416d9f..57712a8fb3a 100644 --- a/src/vmm/src/devices/virtio/block/persist.rs +++ b/src/vmm/src/devices/virtio/block/persist.rs @@ -1,10 +1,13 @@ // Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 +use std::sync::Arc; + use serde::{Deserialize, Serialize}; use super::vhost_user::persist::VhostUserBlockState; use super::virtio::persist::VirtioBlockState; +use crate::devices::virtio::transport::VirtioInterrupt; use crate::vstate::memory::GuestMemoryMmap; /// Block device state. @@ -18,4 +21,5 @@ pub enum BlockState { #[derive(Debug)] pub struct BlockConstructorArgs { pub mem: GuestMemoryMmap, + pub interrupt: Arc, } diff --git a/src/vmm/src/devices/virtio/block/vhost_user/device.rs b/src/vmm/src/devices/virtio/block/vhost_user/device.rs index b0bf5a31e3f..f32249c1cf9 100644 --- a/src/vmm/src/devices/virtio/block/vhost_user/device.rs +++ b/src/vmm/src/devices/virtio/block/vhost_user/device.rs @@ -4,6 +4,7 @@ // Portions Copyright 2019 Intel Corporation. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 +use std::ops::Deref; use std::sync::Arc; use log::error; @@ -14,11 +15,12 @@ use vmm_sys_util::eventfd::EventFd; use super::{NUM_QUEUES, QUEUE_SIZE, VhostUserBlockError}; use crate::devices::virtio::block::CacheType; -use crate::devices::virtio::device::{DeviceState, IrqTrigger, IrqType, VirtioDevice}; +use crate::devices::virtio::device::{ActiveState, DeviceState, VirtioDevice}; use crate::devices::virtio::generated::virtio_blk::{VIRTIO_BLK_F_FLUSH, VIRTIO_BLK_F_RO}; use crate::devices::virtio::generated::virtio_config::VIRTIO_F_VERSION_1; use crate::devices::virtio::generated::virtio_ring::VIRTIO_RING_F_EVENT_IDX; use crate::devices::virtio::queue::Queue; +use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType}; use crate::devices::virtio::vhost_user::{VhostUserHandleBackend, VhostUserHandleImpl}; use crate::devices::virtio::vhost_user_metrics::{ VhostUserDeviceMetrics, VhostUserMetricsPerDevice, @@ -34,7 +36,7 @@ const BLOCK_CONFIG_SPACE_SIZE: u32 = 60; const AVAILABLE_FEATURES: u64 = (1 << VIRTIO_F_VERSION_1) | (1 << VIRTIO_RING_F_EVENT_IDX) - // vhost-user specific bit. Not defined in standart virtio spec. + // vhost-user specific bit. Not defined in standard virtio spec. // Specifies ability of frontend to negotiate protocol features. | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() // We always try to negotiate readonly with the backend. @@ -117,7 +119,6 @@ pub struct VhostUserBlockImpl { pub queues: Vec, pub queue_evts: [EventFd; u64_to_usize(NUM_QUEUES)], pub device_state: DeviceState, - pub irq_trigger: IrqTrigger, // Implementation specific fields. pub id: String, @@ -143,7 +144,6 @@ impl std::fmt::Debug for VhostUserBlockImpl { .field("queues", &self.queues) .field("queue_evts", &self.queue_evts) .field("device_state", &self.device_state) - .field("irq_trigger", &self.irq_trigger) .field("id", &self.id) .field("partuuid", &self.partuuid) .field("cache_type", &self.cache_type) @@ -203,7 +203,6 @@ impl VhostUserBlockImpl { let queue_evts = [EventFd::new(libc::EFD_NONBLOCK).map_err(VhostUserBlockError::EventFd)?; u64_to_usize(NUM_QUEUES)]; let device_state = DeviceState::Inactive; - let irq_trigger = IrqTrigger::new().map_err(VhostUserBlockError::IrqTrigger)?; // We negotiated features with backend. Now these acked_features // are available for guest driver to choose from. @@ -225,7 +224,6 @@ impl VhostUserBlockImpl { queues, queue_evts, device_state, - irq_trigger, id: config.drive_id, partuuid: config.partuuid, @@ -256,6 +254,12 @@ impl VhostUserBlockImpl { pub fn config_update(&mut self) -> Result<(), VhostUserBlockError> { let start_time = get_time_us(ClockType::Monotonic); + let interrupt = self + .device_state + .active_state() + .expect("Device is not initialized") + .interrupt + .clone(); // This buffer is used for config size check in vhost crate. let buffer = [0u8; BLOCK_CONFIG_SPACE_SIZE as usize]; @@ -270,9 +274,9 @@ impl VhostUserBlockImpl { ) .map_err(VhostUserBlockError::Vhost)?; self.config_space = new_config_space; - self.irq_trigger - .trigger_irq(IrqType::Config) - .map_err(VhostUserBlockError::IrqTrigger)?; + interrupt + .trigger(VirtioInterruptType::Config) + .map_err(VhostUserBlockError::Interrupt)?; let delta_us = get_time_us(ClockType::Monotonic) - start_time; self.metrics.config_change_time_us.store(delta_us); @@ -310,8 +314,12 @@ impl VirtioDevice for VhostUserBlock &self.queue_evts } - fn interrupt_trigger(&self) -> &IrqTrigger { - &self.irq_trigger + fn interrupt_trigger(&self) -> &dyn VirtioInterrupt { + self.device_state + .active_state() + .expect("Device is not initialized") + .interrupt + .deref() } fn read_config(&self, offset: u64, data: &mut [u8]) { @@ -330,7 +338,11 @@ impl VirtioDevice for VhostUserBlock // Other block config fields are immutable. } - fn activate(&mut self, mem: GuestMemoryMmap) -> Result<(), ActivateError> { + fn activate( + &mut self, + mem: GuestMemoryMmap, + interrupt: Arc, + ) -> Result<(), ActivateError> { for q in self.queues.iter_mut() { q.initialize(&mem) .map_err(ActivateError::QueueMemoryError)?; @@ -345,14 +357,14 @@ impl VirtioDevice for VhostUserBlock self.vu_handle.setup_backend( &mem, &[(0, &self.queues[0], &self.queue_evts[0])], - &self.irq_trigger, + interrupt.clone(), ) }) .map_err(|err| { self.metrics.activate_fails.inc(); ActivateError::VhostUser(err) })?; - self.device_state = DeviceState::Activated(mem); + self.device_state = DeviceState::Activated(ActiveState { mem, interrupt }); let delta_us = get_time_us(ClockType::Monotonic) - start_time; self.metrics.activate_time_us.store(delta_us); Ok(()) @@ -375,7 +387,8 @@ mod tests { use super::*; use crate::devices::virtio::block::virtio::device::FileEngineType; - use crate::devices::virtio::mmio::VIRTIO_MMIO_INT_CONFIG; + use crate::devices::virtio::test_utils::{default_interrupt, default_mem}; + use crate::devices::virtio::transport::mmio::VIRTIO_MMIO_INT_CONFIG; use crate::devices::virtio::vhost_user::tests::create_mem; use crate::test_utils::create_tmp_socket; use crate::vstate::memory::GuestAddress; @@ -651,6 +664,10 @@ mod tests { assert_eq!(vhost_block.config_space, vec![0x69, 0x69, 0x69]); // Testing [`config_update`] + vhost_block.device_state = DeviceState::Activated(ActiveState { + mem: default_mem(), + interrupt: default_interrupt(), + }); vhost_block.config_space = vec![]; vhost_block.config_update().unwrap(); assert_eq!(vhost_block.config_space, vec![0x69, 0x69, 0x69]); @@ -780,9 +797,10 @@ mod tests { file.set_len(region_size as u64).unwrap(); let regions = vec![(GuestAddress(0x0), region_size)]; let guest_memory = create_mem(file, ®ions); + let interrupt = default_interrupt(); // During actiavion of the device features, memory and queues should be set and activated. - vhost_block.activate(guest_memory).unwrap(); + vhost_block.activate(guest_memory, interrupt).unwrap(); assert!(unsafe { *vhost_block.vu_handle.vu.features_are_set.get() }); assert!(unsafe { *vhost_block.vu_handle.vu.memory_is_set.get() }); assert!(unsafe { *vhost_block.vu_handle.vu.vring_enabled.get() }); diff --git a/src/vmm/src/devices/virtio/block/vhost_user/mod.rs b/src/vmm/src/devices/virtio/block/vhost_user/mod.rs index 8d4d9f44261..0afaaed3400 100644 --- a/src/vmm/src/devices/virtio/block/vhost_user/mod.rs +++ b/src/vmm/src/devices/virtio/block/vhost_user/mod.rs @@ -28,5 +28,5 @@ pub enum VhostUserBlockError { /// Error opening eventfd: {0} EventFd(std::io::Error), /// Error creating irqfd: {0} - IrqTrigger(std::io::Error), + Interrupt(std::io::Error), } diff --git a/src/vmm/src/devices/virtio/block/virtio/device.rs b/src/vmm/src/devices/virtio/block/virtio/device.rs index b11c757d43c..aa28a325e1c 100644 --- a/src/vmm/src/devices/virtio/block/virtio/device.rs +++ b/src/vmm/src/devices/virtio/block/virtio/device.rs @@ -9,6 +9,7 @@ use std::cmp; use std::convert::From; use std::fs::{File, OpenOptions}; use std::io::{Seek, SeekFrom}; +use std::ops::Deref; use std::os::linux::fs::MetadataExt; use std::path::PathBuf; use std::sync::Arc; @@ -23,13 +24,14 @@ use super::request::*; use super::{BLOCK_QUEUE_SIZES, SECTOR_SHIFT, SECTOR_SIZE, VirtioBlockError, io as block_io}; use crate::devices::virtio::block::CacheType; use crate::devices::virtio::block::virtio::metrics::{BlockDeviceMetrics, BlockMetricsPerDevice}; -use crate::devices::virtio::device::{DeviceState, IrqTrigger, IrqType, VirtioDevice}; +use crate::devices::virtio::device::{ActiveState, DeviceState, VirtioDevice}; use crate::devices::virtio::generated::virtio_blk::{ VIRTIO_BLK_F_FLUSH, VIRTIO_BLK_F_RO, VIRTIO_BLK_ID_BYTES, }; use crate::devices::virtio::generated::virtio_config::VIRTIO_F_VERSION_1; use crate::devices::virtio::generated::virtio_ring::VIRTIO_RING_F_EVENT_IDX; use crate::devices::virtio::queue::Queue; +use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType}; use crate::devices::virtio::{ActivateError, TYPE_BLOCK}; use crate::logger::{IncMetric, error, warn}; use crate::rate_limiter::{BucketUpdate, RateLimiter}; @@ -249,7 +251,6 @@ pub struct VirtioBlock { pub queues: Vec, pub queue_evts: [EventFd; 1], pub device_state: DeviceState, - pub irq_trigger: IrqTrigger, // Implementation specific fields. pub id: String, @@ -322,7 +323,6 @@ impl VirtioBlock { queues, queue_evts, device_state: DeviceState::Inactive, - irq_trigger: IrqTrigger::new().map_err(VirtioBlockError::IrqTrigger)?, id: config.drive_id.clone(), partuuid: config.partuuid, @@ -388,7 +388,7 @@ impl VirtioBlock { queue: &mut Queue, index: u16, len: u32, - irq_trigger: &IrqTrigger, + interrupt: &dyn VirtioInterrupt, block_metrics: &BlockDeviceMetrics, ) { queue.add_used(index, len).unwrap_or_else(|err| { @@ -396,44 +396,52 @@ impl VirtioBlock { }); if queue.prepare_kick() { - irq_trigger.trigger_irq(IrqType::Vring).unwrap_or_else(|_| { - block_metrics.event_fails.inc(); - }); + interrupt + .trigger(VirtioInterruptType::Queue(index)) + .unwrap_or_else(|_| { + block_metrics.event_fails.inc(); + }); } } /// Device specific function for peaking inside a queue and processing descriptors. pub fn process_queue(&mut self, queue_index: usize) { // This is safe since we checked in the event handler that the device is activated. - let mem = self.device_state.mem().unwrap(); + let active_state = self.device_state.active_state().unwrap(); let queue = &mut self.queues[queue_index]; let mut used_any = false; while let Some(head) = queue.pop_or_enable_notification() { self.metrics.remaining_reqs_count.add(queue.len().into()); - let processing_result = match Request::parse(&head, mem, self.disk.nsectors) { - Ok(request) => { - if request.rate_limit(&mut self.rate_limiter) { - // Stop processing the queue and return this descriptor chain to the - // avail ring, for later processing. - queue.undo_pop(); - self.metrics.rate_limiter_throttled_events.inc(); - break; + let processing_result = + match Request::parse(&head, &active_state.mem, self.disk.nsectors) { + Ok(request) => { + if request.rate_limit(&mut self.rate_limiter) { + // Stop processing the queue and return this descriptor chain to the + // avail ring, for later processing. + queue.undo_pop(); + self.metrics.rate_limiter_throttled_events.inc(); + break; + } + + used_any = true; + request.process( + &mut self.disk, + head.index, + &active_state.mem, + &self.metrics, + ) } - - used_any = true; - request.process(&mut self.disk, head.index, mem, &self.metrics) - } - Err(err) => { - error!("Failed to parse available descriptor chain: {:?}", err); - self.metrics.execute_fails.inc(); - ProcessingResult::Executed(FinishedRequest { - num_bytes_to_mem: 0, - desc_idx: head.index, - }) - } - }; + Err(err) => { + error!("Failed to parse available descriptor chain: {:?}", err); + self.metrics.execute_fails.inc(); + ProcessingResult::Executed(FinishedRequest { + num_bytes_to_mem: 0, + desc_idx: head.index, + }) + } + }; match processing_result { ProcessingResult::Submitted => {} @@ -447,7 +455,7 @@ impl VirtioBlock { queue, head.index, finished.num_bytes_to_mem, - &self.irq_trigger, + active_state.interrupt.deref(), &self.metrics, ); } @@ -469,11 +477,11 @@ impl VirtioBlock { let engine = unwrap_async_file_engine_or_return!(&mut self.disk.file_engine); // This is safe since we checked in the event handler that the device is activated. - let mem = self.device_state.mem().unwrap(); + let active_state = self.device_state.active_state().unwrap(); let queue = &mut self.queues[0]; loop { - match engine.pop(mem) { + match engine.pop(&active_state.mem) { Err(error) => { error!("Failed to read completed io_uring entry: {:?}", error); break; @@ -492,13 +500,13 @@ impl VirtioBlock { ))), ), }; - let finished = pending.finish(mem, res, &self.metrics); + let finished = pending.finish(&active_state.mem, res, &self.metrics); Self::add_used_descriptor( queue, finished.desc_idx, finished.num_bytes_to_mem, - &self.irq_trigger, + active_state.interrupt.deref(), &self.metrics, ); } @@ -526,8 +534,12 @@ impl VirtioBlock { self.disk.update(disk_image_path, self.read_only)?; self.config_space.capacity = self.disk.nsectors.to_le(); // virtio_block_config_space(); - // Kick the driver to pick up the changes. - self.irq_trigger.trigger_irq(IrqType::Config).unwrap(); + // Kick the driver to pick up the changes. (But only if the device is already activated). + if self.is_activated() { + self.interrupt_trigger() + .trigger(VirtioInterruptType::Config) + .unwrap(); + } self.metrics.update_count.inc(); Ok(()) @@ -594,8 +606,12 @@ impl VirtioDevice for VirtioBlock { &self.queue_evts } - fn interrupt_trigger(&self) -> &IrqTrigger { - &self.irq_trigger + fn interrupt_trigger(&self) -> &dyn VirtioInterrupt { + self.device_state + .active_state() + .expect("Device is not initialized") + .interrupt + .deref() } fn read_config(&self, offset: u64, data: &mut [u8]) { @@ -624,7 +640,11 @@ impl VirtioDevice for VirtioBlock { dst.copy_from_slice(data); } - fn activate(&mut self, mem: GuestMemoryMmap) -> Result<(), ActivateError> { + fn activate( + &mut self, + mem: GuestMemoryMmap, + interrupt: Arc, + ) -> Result<(), ActivateError> { for q in self.queues.iter_mut() { q.initialize(&mem) .map_err(ActivateError::QueueMemoryError)?; @@ -641,7 +661,7 @@ impl VirtioDevice for VirtioBlock { self.metrics.activate_fails.inc(); return Err(ActivateError::EventFd); } - self.device_state = DeviceState::Activated(mem); + self.device_state = DeviceState::Activated(ActiveState { mem, interrupt }); Ok(()) } @@ -684,7 +704,7 @@ mod tests { simulate_queue_event, }; use crate::devices::virtio::queue::{VIRTQ_DESC_F_NEXT, VIRTQ_DESC_F_WRITE}; - use crate::devices::virtio::test_utils::{VirtQueue, default_mem}; + use crate::devices::virtio::test_utils::{VirtQueue, default_interrupt, default_mem}; use crate::rate_limiter::TokenType; use crate::vstate::memory::{Address, Bytes, GuestAddress}; @@ -826,7 +846,7 @@ mod tests { block.read_config(0, actual_config_space.as_mut_slice()); assert_eq!(actual_config_space, expected_config_space); - // If priviledged user writes to `/dev/mem`, in block config space - byte by byte. + // If privileged user writes to `/dev/mem`, in block config space - byte by byte. let expected_config_space = ConfigSpace { capacity: 0x1122334455667788, }; @@ -859,9 +879,10 @@ mod tests { for engine in [FileEngineType::Sync, FileEngineType::Async] { let mut block = default_block(engine); let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, 16); set_queue(&mut block, 0, vq.create_queue()); - block.activate(mem.clone()).unwrap(); + block.activate(mem.clone(), interrupt).unwrap(); read_blk_req_descriptors(&vq); let request_type_addr = GuestAddress(vq.dtable[0].addr.get()); @@ -887,9 +908,10 @@ mod tests { let mut block = default_block(engine); // Default mem size is 0x10000 let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, 16); set_queue(&mut block, 0, vq.create_queue()); - block.activate(mem.clone()).unwrap(); + block.activate(mem.clone(), interrupt).unwrap(); read_blk_req_descriptors(&vq); let request_type_addr = GuestAddress(vq.dtable[0].addr.get()); @@ -950,9 +972,10 @@ mod tests { for engine in [FileEngineType::Sync, FileEngineType::Async] { let mut block = default_block(engine); let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, 16); set_queue(&mut block, 0, vq.create_queue()); - block.activate(mem.clone()).unwrap(); + block.activate(mem.clone(), interrupt).unwrap(); read_blk_req_descriptors(&vq); let request_type_addr = GuestAddress(vq.dtable[0].addr.get()); @@ -1001,9 +1024,10 @@ mod tests { for engine in [FileEngineType::Sync, FileEngineType::Async] { let mut block = default_block(engine); let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, 16); set_queue(&mut block, 0, vq.create_queue()); - block.activate(mem.clone()).unwrap(); + block.activate(mem.clone(), interrupt).unwrap(); read_blk_req_descriptors(&vq); let request_type_addr = GuestAddress(vq.dtable[0].addr.get()); @@ -1033,9 +1057,10 @@ mod tests { for engine in [FileEngineType::Sync, FileEngineType::Async] { let mut block = default_block(engine); let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, 16); set_queue(&mut block, 0, vq.create_queue()); - block.activate(mem.clone()).unwrap(); + block.activate(mem.clone(), interrupt).unwrap(); read_blk_req_descriptors(&vq); vq.dtable[1].set(0xf000, 0x1000, VIRTQ_DESC_F_NEXT | VIRTQ_DESC_F_WRITE, 2); @@ -1069,9 +1094,10 @@ mod tests { for engine in [FileEngineType::Sync, FileEngineType::Async] { let mut block = default_block(engine); let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, 16); set_queue(&mut block, 0, vq.create_queue()); - block.activate(mem.clone()).unwrap(); + block.activate(mem.clone(), interrupt).unwrap(); read_blk_req_descriptors(&vq); let request_type_addr = GuestAddress(vq.dtable[0].addr.get()); @@ -1116,9 +1142,10 @@ mod tests { // Default mem size is 0x10000 let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, 16); set_queue(&mut block, 0, vq.create_queue()); - block.activate(mem.clone()).unwrap(); + block.activate(mem.clone(), interrupt).unwrap(); read_blk_req_descriptors(&vq); let request_type_addr = GuestAddress(vq.dtable[0].addr.get()); @@ -1355,9 +1382,10 @@ mod tests { { // Default mem size is 0x10000 let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, 16); set_queue(&mut block, 0, vq.create_queue()); - block.activate(mem.clone()).unwrap(); + block.activate(mem.clone(), interrupt).unwrap(); read_blk_req_descriptors(&vq); vq.dtable[1].set(0xff00, 0x1000, VIRTQ_DESC_F_NEXT | VIRTQ_DESC_F_WRITE, 2); @@ -1396,9 +1424,10 @@ mod tests { for engine in [FileEngineType::Sync, FileEngineType::Async] { let mut block = default_block(engine); let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, 16); set_queue(&mut block, 0, vq.create_queue()); - block.activate(mem.clone()).unwrap(); + block.activate(mem.clone(), interrupt).unwrap(); read_blk_req_descriptors(&vq); let request_type_addr = GuestAddress(vq.dtable[0].addr.get()); @@ -1442,9 +1471,10 @@ mod tests { for engine in [FileEngineType::Sync, FileEngineType::Async] { let mut block = default_block(engine); let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, 16); set_queue(&mut block, 0, vq.create_queue()); - block.activate(mem.clone()).unwrap(); + block.activate(mem.clone(), interrupt).unwrap(); read_blk_req_descriptors(&vq); let request_type_addr = GuestAddress(vq.dtable[0].addr.get()); @@ -1566,8 +1596,9 @@ mod tests { let mut block = default_block(FileEngineType::Async); let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, IO_URING_NUM_ENTRIES * 4); - block.activate(mem.clone()).unwrap(); + block.activate(mem.clone(), interrupt).unwrap(); // Run scenario that doesn't trigger FullSq BlockError: Add sq_size flush requests. add_flush_requests_batch(&mut block, &vq, IO_URING_NUM_ENTRIES); @@ -1599,8 +1630,9 @@ mod tests { let mut block = default_block(FileEngineType::Async); let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, IO_URING_NUM_ENTRIES * 4); - block.activate(mem.clone()).unwrap(); + block.activate(mem.clone(), interrupt).unwrap(); // Run scenario that triggers FullCqError. Push 2 * IO_URING_NUM_ENTRIES and wait for // completion. Then try to push another entry. @@ -1628,8 +1660,9 @@ mod tests { let mut block = default_block(engine); let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, 16); - block.activate(mem.clone()).unwrap(); + block.activate(mem.clone(), interrupt).unwrap(); // Add a batch of flush requests. add_flush_requests_batch(&mut block, &vq, 5); @@ -1646,9 +1679,10 @@ mod tests { for engine in [FileEngineType::Sync, FileEngineType::Async] { let mut block = default_block(engine); let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, 16); set_queue(&mut block, 0, vq.create_queue()); - block.activate(mem.clone()).unwrap(); + block.activate(mem.clone(), interrupt).unwrap(); read_blk_req_descriptors(&vq); let request_type_addr = GuestAddress(vq.dtable[0].addr.get()); @@ -1715,9 +1749,10 @@ mod tests { for engine in [FileEngineType::Sync, FileEngineType::Async] { let mut block = default_block(engine); let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, 16); set_queue(&mut block, 0, vq.create_queue()); - block.activate(mem.clone()).unwrap(); + block.activate(mem.clone(), interrupt).unwrap(); read_blk_req_descriptors(&vq); let request_type_addr = GuestAddress(vq.dtable[0].addr.get()); @@ -1797,6 +1832,7 @@ mod tests { fn test_update_disk_image() { for engine in [FileEngineType::Sync, FileEngineType::Async] { let mut block = default_block(engine); + block.activate(default_mem(), default_interrupt()).unwrap(); let f = TempFile::new().unwrap(); let path = f.as_path(); let mdata = metadata(path).unwrap(); diff --git a/src/vmm/src/devices/virtio/block/virtio/event_handler.rs b/src/vmm/src/devices/virtio/block/virtio/event_handler.rs index db69e23d7f0..03c09a01972 100644 --- a/src/vmm/src/devices/virtio/block/virtio/event_handler.rs +++ b/src/vmm/src/devices/virtio/block/virtio/event_handler.rs @@ -124,7 +124,7 @@ mod tests { }; use crate::devices::virtio::block::virtio::{VIRTIO_BLK_S_OK, VIRTIO_BLK_T_OUT}; use crate::devices::virtio::queue::VIRTQ_DESC_F_NEXT; - use crate::devices::virtio::test_utils::{VirtQueue, default_mem}; + use crate::devices::virtio::test_utils::{VirtQueue, default_interrupt, default_mem}; use crate::vstate::memory::{Bytes, GuestAddress}; #[test] @@ -132,6 +132,7 @@ mod tests { let mut event_manager = EventManager::new().unwrap(); let mut block = default_block(FileEngineType::default()); let mem = default_mem(); + let interrupt = default_interrupt(); let vq = VirtQueue::new(GuestAddress(0), &mem, 16); set_queue(&mut block, 0, vq.create_queue()); read_blk_req_descriptors(&vq); @@ -162,7 +163,11 @@ mod tests { assert_eq!(ev_count, 0); // Now activate the device. - block.lock().unwrap().activate(mem.clone()).unwrap(); + block + .lock() + .unwrap() + .activate(mem.clone(), interrupt) + .unwrap(); // Process the activate event. let ev_count = event_manager.run_with_timeout(50).unwrap(); assert_eq!(ev_count, 1); diff --git a/src/vmm/src/devices/virtio/block/virtio/mod.rs b/src/vmm/src/devices/virtio/block/virtio/mod.rs index 8ea59a5aba4..9e97d6d3897 100644 --- a/src/vmm/src/devices/virtio/block/virtio/mod.rs +++ b/src/vmm/src/devices/virtio/block/virtio/mod.rs @@ -57,8 +57,8 @@ pub enum VirtioBlockError { BackingFile(std::io::Error, String), /// Error opening eventfd: {0} EventFd(std::io::Error), - /// Error creating an irqfd: {0} - IrqTrigger(std::io::Error), + /// Error creating an interrupt: {0} + Interrupt(std::io::Error), /// Error coming from the rate limiter: {0} RateLimiter(std::io::Error), /// Persistence error: {0} diff --git a/src/vmm/src/devices/virtio/block/virtio/persist.rs b/src/vmm/src/devices/virtio/block/virtio/persist.rs index 8c6f2c2453d..57e4a11b9c1 100644 --- a/src/vmm/src/devices/virtio/block/virtio/persist.rs +++ b/src/vmm/src/devices/virtio/block/virtio/persist.rs @@ -3,9 +3,6 @@ //! Defines the structures needed for saving/restoring block devices. -use std::sync::Arc; -use std::sync::atomic::AtomicU32; - use device::ConfigSpace; use serde::{Deserialize, Serialize}; use vmm_sys_util::eventfd::EventFd; @@ -16,7 +13,7 @@ use crate::devices::virtio::TYPE_BLOCK; use crate::devices::virtio::block::persist::BlockConstructorArgs; use crate::devices::virtio::block::virtio::device::FileEngineType; use crate::devices::virtio::block::virtio::metrics::BlockMetricsPerDevice; -use crate::devices::virtio::device::{DeviceState, IrqTrigger}; +use crate::devices::virtio::device::{ActiveState, DeviceState}; use crate::devices::virtio::generated::virtio_blk::VIRTIO_BLK_F_RO; use crate::devices::virtio::persist::VirtioDeviceState; use crate::rate_limiter::RateLimiter; @@ -111,14 +108,14 @@ impl Persist<'_> for VirtioBlock { ) .map_err(VirtioBlockError::Persist)?; - let mut irq_trigger = IrqTrigger::new().map_err(VirtioBlockError::IrqTrigger)?; - irq_trigger.irq_status = Arc::new(AtomicU32::new(state.virtio_state.interrupt_status)); - let avail_features = state.virtio_state.avail_features; let acked_features = state.virtio_state.acked_features; let device_state = if state.virtio_state.activated { - DeviceState::Activated(constructor_args.mem) + DeviceState::Activated(ActiveState { + mem: constructor_args.mem, + interrupt: constructor_args.interrupt, + }) } else { DeviceState::Inactive }; @@ -136,7 +133,6 @@ impl Persist<'_> for VirtioBlock { queues, queue_evts, device_state, - irq_trigger, id: state.id.clone(), partuuid: state.partuuid.clone(), @@ -154,14 +150,12 @@ impl Persist<'_> for VirtioBlock { #[cfg(test)] mod tests { - use std::sync::atomic::Ordering; - use vmm_sys_util::tempfile::TempFile; use super::*; use crate::devices::virtio::block::virtio::device::VirtioBlockConfig; use crate::devices::virtio::device::VirtioDevice; - use crate::devices::virtio::test_utils::default_mem; + use crate::devices::virtio::test_utils::{default_interrupt, default_mem}; use crate::snapshot::Snapshot; #[test] @@ -233,7 +227,10 @@ mod tests { // Restore the block device. let restored_block = VirtioBlock::restore( - BlockConstructorArgs { mem: guest_mem }, + BlockConstructorArgs { + mem: guest_mem, + interrupt: default_interrupt(), + }, &Snapshot::deserialize(&mut mem.as_slice()).unwrap(), ) .unwrap(); @@ -243,11 +240,8 @@ mod tests { assert_eq!(restored_block.avail_features(), block.avail_features()); assert_eq!(restored_block.acked_features(), block.acked_features()); assert_eq!(restored_block.queues(), block.queues()); - assert_eq!( - restored_block.interrupt_status().load(Ordering::Relaxed), - block.interrupt_status().load(Ordering::Relaxed) - ); - assert_eq!(restored_block.is_activated(), block.is_activated()); + assert!(!block.is_activated()); + assert!(!restored_block.is_activated()); // Test that block specific fields are the same. assert_eq!(restored_block.disk.file_path, block.disk.file_path); diff --git a/src/vmm/src/devices/virtio/block/virtio/test_utils.rs b/src/vmm/src/devices/virtio/block/virtio/test_utils.rs index 02dd34fbce9..e4f23c6a038 100644 --- a/src/vmm/src/devices/virtio/block/virtio/test_utils.rs +++ b/src/vmm/src/devices/virtio/block/virtio/test_utils.rs @@ -17,9 +17,11 @@ use crate::devices::virtio::block::virtio::device::FileEngineType; use crate::devices::virtio::block::virtio::io::FileEngine; use crate::devices::virtio::block::virtio::{CacheType, VirtioBlock}; #[cfg(test)] -use crate::devices::virtio::device::IrqType; +use crate::devices::virtio::device::VirtioDevice; use crate::devices::virtio::queue::{Queue, VIRTQ_DESC_F_NEXT, VIRTQ_DESC_F_WRITE}; use crate::devices::virtio::test_utils::{VirtQueue, VirtqDesc}; +#[cfg(test)] +use crate::devices::virtio::transport::VirtioInterruptType; use crate::rate_limiter::RateLimiter; use crate::vmm_config::{RateLimiterConfig, TokenBucketConfig}; use crate::vstate::memory::{Bytes, GuestAddress}; @@ -77,12 +79,17 @@ pub fn rate_limiter(blk: &mut VirtioBlock) -> &RateLimiter { #[cfg(test)] pub fn simulate_queue_event(b: &mut VirtioBlock, maybe_expected_irq: Option) { // Trigger the queue event. + b.queue_evts[0].write(1).unwrap(); // Handle event. b.process_queue_event(); // Validate the queue operation finished successfully. if let Some(expected_irq) = maybe_expected_irq { - assert_eq!(b.irq_trigger.has_pending_irq(IrqType::Vring), expected_irq); + assert_eq!( + b.interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(0)), + expected_irq + ); } } @@ -98,7 +105,11 @@ pub fn simulate_async_completion_event(b: &mut VirtioBlock, expected_irq: bool) } // Validate if there are pending IRQs. - assert_eq!(b.irq_trigger.has_pending_irq(IrqType::Vring), expected_irq); + assert_eq!( + b.interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(0)), + expected_irq + ); } #[cfg(test)] diff --git a/src/vmm/src/devices/virtio/device.rs b/src/vmm/src/devices/virtio/device.rs index 62131e775f5..083cd1bb54f 100644 --- a/src/vmm/src/devices/virtio/device.rs +++ b/src/vmm/src/devices/virtio/device.rs @@ -7,23 +7,30 @@ use std::fmt; use std::sync::Arc; -use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::atomic::AtomicU32; use vmm_sys_util::eventfd::EventFd; use super::ActivateError; -use super::mmio::{VIRTIO_MMIO_INT_CONFIG, VIRTIO_MMIO_INT_VRING}; use super::queue::{Queue, QueueError}; +use super::transport::VirtioInterrupt; use crate::devices::virtio::AsAny; -use crate::logger::{error, warn}; +use crate::logger::warn; use crate::vstate::memory::GuestMemoryMmap; +/// State of an active VirtIO device +#[derive(Debug, Clone)] +pub struct ActiveState { + pub mem: GuestMemoryMmap, + pub interrupt: Arc, +} + /// Enum that indicates if a VirtioDevice is inactive or has been activated /// and memory attached to it. #[derive(Debug)] pub enum DeviceState { Inactive, - Activated(GuestMemoryMmap), + Activated(ActiveState), } impl DeviceState { @@ -35,55 +42,15 @@ impl DeviceState { } } - /// Gets the memory attached to the device if it is activated. - pub fn mem(&self) -> Option<&GuestMemoryMmap> { + /// Gets the memory and interrupt attached to the device if it is activated. + pub fn active_state(&self) -> Option<&ActiveState> { match self { - DeviceState::Activated(mem) => Some(mem), + DeviceState::Activated(state) => Some(state), DeviceState::Inactive => None, } } } -/// The 2 types of interrupt sources in MMIO transport. -#[derive(Debug)] -pub enum IrqType { - /// Interrupt triggered by change in config. - Config, - /// Interrupt triggered by used vring buffers. - Vring, -} - -/// Helper struct that is responsible for triggering guest IRQs -#[derive(Debug)] -pub struct IrqTrigger { - pub(crate) irq_status: Arc, - pub(crate) irq_evt: EventFd, -} - -impl IrqTrigger { - pub fn new() -> std::io::Result { - Ok(Self { - irq_status: Arc::new(AtomicU32::new(0)), - irq_evt: EventFd::new(libc::EFD_NONBLOCK)?, - }) - } - - pub fn trigger_irq(&self, irq_type: IrqType) -> Result<(), std::io::Error> { - let irq = match irq_type { - IrqType::Config => VIRTIO_MMIO_INT_CONFIG, - IrqType::Vring => VIRTIO_MMIO_INT_VRING, - }; - self.irq_status.fetch_or(irq, Ordering::SeqCst); - - self.irq_evt.write(1).map_err(|err| { - error!("Failed to send irq to the guest: {:?}", err); - err - })?; - - Ok(()) - } -} - /// Trait for virtio devices to be driven by a virtio transport. /// /// The lifecycle of a virtio device is to be moved to a virtio transport, which will then query the @@ -121,10 +88,10 @@ pub trait VirtioDevice: AsAny + Send { /// Returns the current device interrupt status. fn interrupt_status(&self) -> Arc { - Arc::clone(&self.interrupt_trigger().irq_status) + self.interrupt_trigger().status() } - fn interrupt_trigger(&self) -> &IrqTrigger; + fn interrupt_trigger(&self) -> &dyn VirtioInterrupt; /// The set of feature bits shifted by `page * 32`. fn avail_features_by_page(&self, page: u32) -> u32 { @@ -170,7 +137,11 @@ pub trait VirtioDevice: AsAny + Send { fn write_config(&mut self, offset: u64, data: &[u8]); /// Performs the formal activation for a device, which can be verified also with `is_activated`. - fn activate(&mut self, mem: GuestMemoryMmap) -> Result<(), ActivateError>; + fn activate( + &mut self, + mem: GuestMemoryMmap, + interrupt: Arc, + ) -> Result<(), ActivateError>; /// Checks if the resources of this device are activated. fn is_activated(&self) -> bool; @@ -200,47 +171,6 @@ impl fmt::Debug for dyn VirtioDevice { pub(crate) mod tests { use super::*; - impl IrqTrigger { - pub fn has_pending_irq(&self, irq_type: IrqType) -> bool { - if let Ok(num_irqs) = self.irq_evt.read() { - if num_irqs == 0 { - return false; - } - - let irq_status = self.irq_status.load(Ordering::SeqCst); - return matches!( - (irq_status, irq_type), - (VIRTIO_MMIO_INT_CONFIG, IrqType::Config) - | (VIRTIO_MMIO_INT_VRING, IrqType::Vring) - ); - } - - false - } - } - - #[test] - fn irq_trigger() { - let irq_trigger = IrqTrigger::new().unwrap(); - assert_eq!(irq_trigger.irq_status.load(Ordering::SeqCst), 0); - - // Check that there are no pending irqs. - assert!(!irq_trigger.has_pending_irq(IrqType::Config)); - assert!(!irq_trigger.has_pending_irq(IrqType::Vring)); - - // Check that trigger_irq() correctly generates irqs. - irq_trigger.trigger_irq(IrqType::Config).unwrap(); - assert!(irq_trigger.has_pending_irq(IrqType::Config)); - irq_trigger.irq_status.store(0, Ordering::SeqCst); - irq_trigger.trigger_irq(IrqType::Vring).unwrap(); - assert!(irq_trigger.has_pending_irq(IrqType::Vring)); - - // Check trigger_irq() failure case (irq_evt is full). - irq_trigger.irq_evt.write(u64::MAX - 1).unwrap(); - irq_trigger.trigger_irq(IrqType::Config).unwrap_err(); - irq_trigger.trigger_irq(IrqType::Vring).unwrap_err(); - } - #[derive(Debug)] struct MockVirtioDevice { acked_features: u64, @@ -275,7 +205,7 @@ pub(crate) mod tests { todo!() } - fn interrupt_trigger(&self) -> &IrqTrigger { + fn interrupt_trigger(&self) -> &dyn VirtioInterrupt { todo!() } @@ -287,7 +217,11 @@ pub(crate) mod tests { todo!() } - fn activate(&mut self, _mem: GuestMemoryMmap) -> Result<(), ActivateError> { + fn activate( + &mut self, + _mem: GuestMemoryMmap, + _interrupt: Arc, + ) -> Result<(), ActivateError> { todo!() } diff --git a/src/vmm/src/devices/virtio/mod.rs b/src/vmm/src/devices/virtio/mod.rs index f298d28e9bd..0ac3b660397 100644 --- a/src/vmm/src/devices/virtio/mod.rs +++ b/src/vmm/src/devices/virtio/mod.rs @@ -18,12 +18,12 @@ pub mod device; pub mod generated; mod iov_deque; pub mod iovec; -pub mod mmio; pub mod net; pub mod persist; pub mod queue; pub mod rng; pub mod test_utils; +pub mod transport; pub mod vhost_user; pub mod vhost_user_metrics; pub mod vsock; diff --git a/src/vmm/src/devices/virtio/net/device.rs b/src/vmm/src/devices/virtio/net/device.rs index fff04d1da1a..e8c0135263c 100755 --- a/src/vmm/src/devices/virtio/net/device.rs +++ b/src/vmm/src/devices/virtio/net/device.rs @@ -8,6 +8,7 @@ use std::collections::VecDeque; use std::mem::{self}; use std::net::Ipv4Addr; +use std::ops::Deref; use std::sync::{Arc, Mutex}; use libc::{EAGAIN, iovec}; @@ -15,7 +16,7 @@ use log::error; use vmm_sys_util::eventfd::EventFd; use super::NET_QUEUE_MAX_SIZE; -use crate::devices::virtio::device::{DeviceState, IrqTrigger, IrqType, VirtioDevice}; +use crate::devices::virtio::device::{ActiveState, DeviceState, VirtioDevice}; use crate::devices::virtio::generated::virtio_config::VIRTIO_F_VERSION_1; use crate::devices::virtio::generated::virtio_net::{ VIRTIO_NET_F_CSUM, VIRTIO_NET_F_GUEST_CSUM, VIRTIO_NET_F_GUEST_TSO4, VIRTIO_NET_F_GUEST_TSO6, @@ -32,6 +33,7 @@ use crate::devices::virtio::net::{ MAX_BUFFER_SIZE, NET_QUEUE_SIZES, NetError, NetQueue, RX_INDEX, TX_INDEX, generated, }; use crate::devices::virtio::queue::{DescriptorChain, Queue}; +use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType}; use crate::devices::virtio::{ActivateError, TYPE_NET}; use crate::devices::{DeviceError, report_net_event_fail}; use crate::dumbo::pdu::arp::ETH_IPV4_FRAME_LEN; @@ -249,8 +251,6 @@ pub struct Net { tx_frame_headers: [u8; frame_hdr_len()], - pub(crate) irq_trigger: IrqTrigger, - pub(crate) config_space: ConfigSpace, pub(crate) guest_mac: Option, @@ -313,7 +313,6 @@ impl Net { tx_rate_limiter, rx_frame_buf: [0u8; MAX_BUFFER_SIZE], tx_frame_headers: [0u8; frame_hdr_len()], - irq_trigger: IrqTrigger::new().map_err(NetError::EventFd)?, config_space, guest_mac, device_state: DeviceState::Inactive, @@ -392,14 +391,14 @@ impl Net { /// https://docs.oasis-open.org/virtio/virtio/v1.1/csprd01/virtio-v1.1-csprd01.html#x1-320005 /// 2.6.7.1 Driver Requirements: Used Buffer Notification Suppression fn try_signal_queue(&mut self, queue_type: NetQueue) -> Result<(), DeviceError> { - let queue = match queue_type { - NetQueue::Rx => &mut self.queues[RX_INDEX], - NetQueue::Tx => &mut self.queues[TX_INDEX], + let qidx = match queue_type { + NetQueue::Rx => RX_INDEX, + NetQueue::Tx => TX_INDEX, }; - if queue.prepare_kick() { - self.irq_trigger - .trigger_irq(IrqType::Vring) + if self.queues[qidx].prepare_kick() { + self.interrupt_trigger() + .trigger(VirtioInterruptType::Queue(qidx.try_into().unwrap())) .map_err(|err| { self.metrics.event_fails.inc(); DeviceError::FailedSignalingIrq(err) @@ -463,7 +462,7 @@ impl Net { /// Parse available RX `DescriptorChains` from the queue pub fn parse_rx_descriptors(&mut self) { // This is safe since we checked in the event handler that the device is activated. - let mem = self.device_state.mem().unwrap(); + let mem = &self.device_state.active_state().unwrap().mem; let queue = &mut self.queues[RX_INDEX]; while let Some(head) = queue.pop_or_enable_notification() { let index = head.index; @@ -680,7 +679,7 @@ impl Net { fn process_tx(&mut self) -> Result<(), DeviceError> { // This is safe since we checked in the event handler that the device is activated. - let mem = self.device_state.mem().unwrap(); + let mem = &self.device_state.active_state().unwrap().mem; // The MMDS network stack works like a state machine, based on synchronous calls, and // without being added to any event loop. If any frame is accepted by the MMDS, we also @@ -962,9 +961,14 @@ impl VirtioDevice for Net { &self.queue_evts } - fn interrupt_trigger(&self) -> &IrqTrigger { - &self.irq_trigger + fn interrupt_trigger(&self) -> &dyn VirtioInterrupt { + self.device_state + .active_state() + .expect("Device is not implemented") + .interrupt + .deref() } + fn read_config(&self, offset: u64, data: &mut [u8]) { if let Some(config_space_bytes) = self.config_space.as_slice().get(u64_to_usize(offset)..) { let len = config_space_bytes.len().min(data.len()); @@ -993,7 +997,11 @@ impl VirtioDevice for Net { self.metrics.mac_address_updates.inc(); } - fn activate(&mut self, mem: GuestMemoryMmap) -> Result<(), ActivateError> { + fn activate( + &mut self, + mem: GuestMemoryMmap, + interrupt: Arc, + ) -> Result<(), ActivateError> { for q in self.queues.iter_mut() { q.initialize(&mem) .map_err(ActivateError::QueueMemoryError)?; @@ -1017,7 +1025,7 @@ impl VirtioDevice for Net { self.metrics.activate_fails.inc(); return Err(ActivateError::EventFd); } - self.device_state = DeviceState::Activated(mem); + self.device_state = DeviceState::Activated(ActiveState { mem, interrupt }); Ok(()) } @@ -1394,7 +1402,12 @@ pub mod tests { // Check that the used queue has advanced. assert_eq!(th.rxq.used.idx.get(), 4); - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(RX_INDEX as u16)) + ); + // Check that the invalid descriptor chains have been discarded th.rxq.check_used_elem(0, 0, 0); th.rxq.check_used_elem(1, 3, 0); @@ -1451,7 +1464,11 @@ pub mod tests { assert!(th.net().rx_buffer.used_descriptors == 0); // Check that the used queue has advanced. assert_eq!(th.rxq.used.idx.get(), 1); - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(RX_INDEX as u16)) + ); // Check that the frame has been written successfully to the Rx descriptor chain. header_set_num_buffers(frame.as_mut_slice(), 1); th.rxq @@ -1514,7 +1531,11 @@ pub mod tests { assert!(th.net().rx_buffer.used_bytes == 0); // Check that the used queue has advanced. assert_eq!(th.rxq.used.idx.get(), 2); - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(RX_INDEX as u16)) + ); // Check that the 1st frame was written successfully to the 1st Rx descriptor chain. header_set_num_buffers(frame_1.as_mut_slice(), 1); th.rxq @@ -1572,7 +1593,11 @@ pub mod tests { assert!(th.net().rx_buffer.used_bytes == 0); // Check that the used queue has advanced. assert_eq!(th.rxq.used.idx.get(), 2); - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(RX_INDEX as u16)) + ); // 2 chains should be used for the packet. header_set_num_buffers(frame.as_mut_slice(), 2); @@ -1637,7 +1662,11 @@ pub mod tests { // Check that the used queue advanced. assert_eq!(th.txq.used.idx.get(), 1); - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(TX_INDEX as u16)) + ); th.txq.check_used_elem(0, 0, 0); // Check that the frame was skipped. assert!(!tap_traffic_simulator.pop_rx_packet(&mut [])); @@ -1660,7 +1689,11 @@ pub mod tests { // Check that the used queue advanced. assert_eq!(th.txq.used.idx.get(), 1); - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(TX_INDEX as u16)) + ); th.txq.check_used_elem(0, 0, 0); // Check that the frame was skipped. assert!(!tap_traffic_simulator.pop_rx_packet(&mut [])); @@ -1687,7 +1720,11 @@ pub mod tests { // Check that the used queue advanced. assert_eq!(th.txq.used.idx.get(), 1); - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(TX_INDEX as u16)) + ); th.txq.check_used_elem(0, 0, 0); // Check that the frame was skipped. assert!(!tap_traffic_simulator.pop_rx_packet(&mut [])); @@ -1710,7 +1747,11 @@ pub mod tests { // Check that the used queue advanced. assert_eq!(th.txq.used.idx.get(), 1); - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(TX_INDEX as u16)) + ); th.txq.check_used_elem(0, 0, 0); // Check that the frame was skipped. assert!(!tap_traffic_simulator.pop_rx_packet(&mut [])); @@ -1749,7 +1790,11 @@ pub mod tests { // Check that the used queue advanced. assert_eq!(th.txq.used.idx.get(), 4); - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(TX_INDEX as u16)) + ); th.txq.check_used_elem(3, 4, 0); // Check that the valid frame was sent to the tap. let mut buf = vec![0; 1000]; @@ -1780,7 +1825,11 @@ pub mod tests { // Check that the used queue advanced. assert_eq!(th.txq.used.idx.get(), 1); - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(TX_INDEX as u16)) + ); th.txq.check_used_elem(0, 3, 0); // Check that the frame was sent to the tap. let mut buf = vec![0; 1000]; @@ -1809,7 +1858,11 @@ pub mod tests { // Check that the used queue advanced. assert_eq!(th.txq.used.idx.get(), 1); - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(TX_INDEX as u16)) + ); th.txq.check_used_elem(0, 0, 0); // dropping th would double close the tap fd, so leak it @@ -1840,7 +1893,11 @@ pub mod tests { // Check that the used queue advanced. assert_eq!(th.txq.used.idx.get(), 2); - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(TX_INDEX as u16)) + ); th.txq.check_used_elem(0, 0, 0); th.txq.check_used_elem(1, 3, 0); // Check that the first frame was sent to the tap. @@ -2192,7 +2249,11 @@ pub mod tests { assert_eq!(th.net().metrics.rx_rate_limiter_throttled.count(), 1); assert!(th.net().rx_buffer.used_descriptors != 0); // assert that no operation actually completed (limiter blocked it) - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(RX_INDEX as u16)) + ); // make sure the data is still queued for processing assert_eq!(th.rxq.used.idx.get(), 0); } @@ -2220,7 +2281,11 @@ pub mod tests { // validate the rate_limiter is no longer blocked assert!(!th.net().rx_rate_limiter.is_blocked()); // make sure the virtio queue operation completed this time - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(RX_INDEX as u16)) + ); // make sure the data queue advanced assert_eq!(th.rxq.used.idx.get(), 1); th.rxq @@ -2317,14 +2382,22 @@ pub mod tests { assert!(th.net().metrics.rx_rate_limiter_throttled.count() >= 1); assert!(th.net().rx_buffer.used_descriptors != 0); // assert that no operation actually completed (limiter blocked it) - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(RX_INDEX as u16)) + ); // make sure the data is still queued for processing assert_eq!(th.rxq.used.idx.get(), 0); // trigger the RX handler again, this time it should do the limiter fast path exit th.simulate_event(NetEvent::Tap); // assert that no operation actually completed, that the limiter blocked it - assert!(!&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + !th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(RX_INDEX as u16)) + ); // make sure the data is still queued for processing assert_eq!(th.rxq.used.idx.get(), 0); } @@ -2337,7 +2410,11 @@ pub mod tests { { th.simulate_event(NetEvent::RxRateLimiter); // make sure the virtio queue operation completed this time - assert!(&th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + th.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(RX_INDEX as u16)) + ); // make sure the data queue advanced assert_eq!(th.rxq.used.idx.get(), 1); th.rxq @@ -2407,7 +2484,14 @@ pub mod tests { assert_eq!(net.queue_events().len(), NET_QUEUE_SIZES.len()); // Test interrupts. - assert!(!&net.irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + !net.interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(RX_INDEX as u16)) + ); + assert!( + !net.interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(TX_INDEX as u16)) + ); } #[test] diff --git a/src/vmm/src/devices/virtio/net/persist.rs b/src/vmm/src/devices/virtio/net/persist.rs index 5f2d6f560b4..961b56556c8 100644 --- a/src/vmm/src/devices/virtio/net/persist.rs +++ b/src/vmm/src/devices/virtio/net/persist.rs @@ -4,7 +4,6 @@ //! Defines the structures needed for saving/restoring net devices. use std::io; -use std::sync::atomic::AtomicU32; use std::sync::{Arc, Mutex}; use serde::{Deserialize, Serialize}; @@ -12,8 +11,9 @@ use serde::{Deserialize, Serialize}; use super::device::{Net, RxBuffers}; use super::{NET_NUM_QUEUES, NET_QUEUE_MAX_SIZE, RX_INDEX, TapError}; use crate::devices::virtio::TYPE_NET; -use crate::devices::virtio::device::DeviceState; +use crate::devices::virtio::device::{ActiveState, DeviceState}; use crate::devices::virtio::persist::{PersistError as VirtioStateError, VirtioDeviceState}; +use crate::devices::virtio::transport::VirtioInterrupt; use crate::mmds::data_store::Mmds; use crate::mmds::ns::MmdsNetworkStack; use crate::mmds::persist::MmdsNetworkStackState; @@ -71,6 +71,8 @@ pub struct NetState { pub struct NetConstructorArgs { /// Pointer to guest memory. pub mem: GuestMemoryMmap, + /// Interrupt for the device. + pub interrupt: Arc, /// Pointer to the MMDS data store. pub mmds: Option>>, } @@ -148,7 +150,6 @@ impl Persist<'_> for Net { NET_NUM_QUEUES, NET_QUEUE_MAX_SIZE, )?; - net.irq_trigger.irq_status = Arc::new(AtomicU32::new(state.virtio_state.interrupt_status)); net.avail_features = state.virtio_state.avail_features; net.acked_features = state.virtio_state.acked_features; @@ -158,7 +159,10 @@ impl Persist<'_> for Net { .set_offload(supported_flags) .map_err(NetPersistError::TapSetOffload)?; - net.device_state = DeviceState::Activated(constructor_args.mem); + net.device_state = DeviceState::Activated(ActiveState { + mem: constructor_args.mem, + interrupt: constructor_args.interrupt, + }); // Recreate `Net::rx_buffer`. We do it by re-parsing the RX queue. We're temporarily // rolling back `next_avail` in the RX queue and call `parse_rx_descriptors`. @@ -174,12 +178,11 @@ impl Persist<'_> for Net { #[cfg(test)] mod tests { - use std::sync::atomic::Ordering; use super::*; use crate::devices::virtio::device::VirtioDevice; use crate::devices::virtio::net::test_utils::{default_net, default_net_no_mmds}; - use crate::devices::virtio::test_utils::default_mem; + use crate::devices::virtio::test_utils::{default_interrupt, default_mem}; use crate::snapshot::Snapshot; fn validate_save_and_restore(net: Net, mmds_ds: Option>>) { @@ -212,6 +215,7 @@ mod tests { match Net::restore( NetConstructorArgs { mem: guest_mem, + interrupt: default_interrupt(), mmds: mmds_ds, }, &Snapshot::deserialize(&mut mem.as_slice()).unwrap(), @@ -221,10 +225,6 @@ mod tests { assert_eq!(restored_net.device_type(), TYPE_NET); assert_eq!(restored_net.avail_features(), virtio_state.avail_features); assert_eq!(restored_net.acked_features(), virtio_state.acked_features); - assert_eq!( - restored_net.interrupt_status().load(Ordering::Relaxed), - virtio_state.interrupt_status - ); assert_eq!(restored_net.is_activated(), virtio_state.activated); // Test that net specific fields are the same. diff --git a/src/vmm/src/devices/virtio/net/test_utils.rs b/src/vmm/src/devices/virtio/net/test_utils.rs index 5762123be68..59c2817aa6b 100644 --- a/src/vmm/src/devices/virtio/net/test_utils.rs +++ b/src/vmm/src/devices/virtio/net/test_utils.rs @@ -104,7 +104,7 @@ impl TapTrafficSimulator { let send_addr_ptr = &mut storage as *mut libc::sockaddr_storage; - // SAFETY: `sock_addr` is a valid pointer and safe to derference. + // SAFETY: `sock_addr` is a valid pointer and safe to dereference. unsafe { let sock_addr: *mut libc::sockaddr_ll = send_addr_ptr.cast::(); (*sock_addr).sll_family = libc::sa_family_t::try_from(libc::AF_PACKET).unwrap(); @@ -223,7 +223,7 @@ pub fn if_index(tap: &Tap) -> i32 { /// Enable the tap interface. pub fn enable(tap: &Tap) { - // Disable IPv6 router advertisment requests + // Disable IPv6 router advertisement requests Command::new("sh") .arg("-c") .arg(format!( @@ -311,7 +311,7 @@ pub mod test { use event_manager::{EventManager, SubscriberId, SubscriberOps}; use crate::check_metric_after_block; - use crate::devices::virtio::device::{IrqType, VirtioDevice}; + use crate::devices::virtio::device::VirtioDevice; use crate::devices::virtio::net::device::vnet_hdr_len; use crate::devices::virtio::net::generated::ETH_HLEN; use crate::devices::virtio::net::test_utils::{ @@ -319,7 +319,8 @@ pub mod test { }; use crate::devices::virtio::net::{MAX_BUFFER_SIZE, Net, RX_INDEX, TX_INDEX}; use crate::devices::virtio::queue::{VIRTQ_DESC_F_NEXT, VIRTQ_DESC_F_WRITE}; - use crate::devices::virtio::test_utils::{VirtQueue, VirtqDesc}; + use crate::devices::virtio::test_utils::{VirtQueue, VirtqDesc, default_interrupt}; + use crate::devices::virtio::transport::VirtioInterruptType; use crate::logger::IncMetric; use crate::vstate::memory::{Address, Bytes, GuestAddress, GuestMemoryMmap}; @@ -378,7 +379,12 @@ pub mod test { } pub fn activate_net(&mut self) { - self.net.lock().unwrap().activate(self.mem.clone()).unwrap(); + let interrupt = default_interrupt(); + self.net + .lock() + .unwrap() + .activate(self.mem.clone(), interrupt) + .unwrap(); // Process the activate event. let ev_count = self.event_manager.run_with_timeout(100).unwrap(); assert_eq!(ev_count, 1); @@ -455,7 +461,11 @@ pub mod test { old_used_descriptors + 1 ); - assert!(&self.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + self.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(RX_INDEX as u16)) + ); frame } @@ -481,7 +491,11 @@ pub mod test { ); // Check that the expected frame was sent to the Rx queue eventually. assert_eq!(self.rxq.used.idx.get(), used_idx + 1); - assert!(&self.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!( + self.net() + .interrupt_trigger() + .has_pending_interrupt(VirtioInterruptType::Queue(RX_INDEX as u16)) + ); self.rxq .check_used_elem(used_idx, 0, expected_frame.len().try_into().unwrap()); self.rxq.dtable[0].check_data(expected_frame); diff --git a/src/vmm/src/devices/virtio/persist.rs b/src/vmm/src/devices/virtio/persist.rs index 7c861352317..06095052fae 100644 --- a/src/vmm/src/devices/virtio/persist.rs +++ b/src/vmm/src/devices/virtio/persist.rs @@ -10,10 +10,11 @@ use std::sync::{Arc, Mutex}; use serde::{Deserialize, Serialize}; use super::queue::QueueError; +use super::transport::mmio::IrqTrigger; use crate::devices::virtio::device::VirtioDevice; use crate::devices::virtio::generated::virtio_ring::VIRTIO_RING_F_EVENT_IDX; -use crate::devices::virtio::mmio::MmioTransport; use crate::devices::virtio::queue::Queue; +use crate::devices::virtio::transport::mmio::MmioTransport; use crate::snapshot::Persist; use crate::vstate::memory::{GuestAddress, GuestMemoryMmap}; @@ -121,8 +122,6 @@ pub struct VirtioDeviceState { pub acked_features: u64, /// List of queues. pub queues: Vec, - /// The MMIO interrupt status. - pub interrupt_status: u32, /// Flag for activated status. pub activated: bool, } @@ -135,7 +134,6 @@ impl VirtioDeviceState { avail_features: device.avail_features(), acked_features: device.acked_features(), queues: device.queues().iter().map(Persist::save).collect(), - interrupt_status: device.interrupt_status().load(Ordering::Relaxed), activated: device.is_activated(), } } @@ -207,6 +205,7 @@ pub struct MmioTransportState { queue_select: u32, device_status: u32, config_generation: u32, + interrupt_status: u32, } /// Auxiliary structure for initializing the transport when resuming from a snapshot. @@ -214,6 +213,8 @@ pub struct MmioTransportState { pub struct MmioTransportConstructorArgs { /// Pointer to guest memory. pub mem: GuestMemoryMmap, + /// Interrupt to use for the device + pub interrupt: Arc, /// Device associated with the current MMIO state. pub device: Arc>, /// Is device backed by vhost-user. @@ -232,6 +233,7 @@ impl Persist<'_> for MmioTransport { queue_select: self.queue_select, device_status: self.device_status, config_generation: self.config_generation, + interrupt_status: self.interrupt.irq_status.load(Ordering::SeqCst), } } @@ -241,6 +243,7 @@ impl Persist<'_> for MmioTransport { ) -> Result { let mut transport = MmioTransport::new( constructor_args.mem, + constructor_args.interrupt, constructor_args.device, constructor_args.is_vhost_user, ); @@ -249,6 +252,10 @@ impl Persist<'_> for MmioTransport { transport.queue_select = state.queue_select; transport.device_status = state.device_status; transport.config_generation = state.config_generation; + transport + .interrupt + .irq_status + .store(state.interrupt_status, Ordering::SeqCst); Ok(transport) } } @@ -261,10 +268,10 @@ mod tests { use crate::devices::virtio::block::virtio::VirtioBlock; use crate::devices::virtio::block::virtio::device::FileEngineType; use crate::devices::virtio::block::virtio::test_utils::default_block_with_path; - use crate::devices::virtio::mmio::tests::DummyDevice; use crate::devices::virtio::net::Net; use crate::devices::virtio::net::test_utils::default_net; use crate::devices::virtio::test_utils::default_mem; + use crate::devices::virtio::transport::mmio::tests::DummyDevice; use crate::devices::virtio::vsock::{Vsock, VsockUnixBackend}; use crate::snapshot::Snapshot; @@ -385,7 +392,7 @@ mod tests { self.queue_select == other.queue_select && self.device_status == other.device_status && self.config_generation == other.config_generation && - self.interrupt_status.load(Ordering::SeqCst) == other.interrupt_status.load(Ordering::SeqCst) && + self.interrupt.irq_status.load(Ordering::SeqCst) == other.interrupt.irq_status.load(Ordering::SeqCst) && // Only checking equality of device type, actual device (de)ser is tested by that // device's tests. self_dev_type == other.device().lock().unwrap().device_type() @@ -394,6 +401,7 @@ mod tests { fn generic_mmiotransport_persistence_test( mmio_transport: MmioTransport, + interrupt: Arc, mem: GuestMemoryMmap, device: Arc>, ) { @@ -403,6 +411,7 @@ mod tests { let restore_args = MmioTransportConstructorArgs { mem, + interrupt, device, is_vhost_user: false, }; @@ -415,8 +424,14 @@ mod tests { assert_eq!(restored_mmio_transport, mmio_transport); } - fn create_default_block() -> (MmioTransport, GuestMemoryMmap, Arc>) { + fn create_default_block() -> ( + MmioTransport, + Arc, + GuestMemoryMmap, + Arc>, + ) { let mem = default_mem(); + let interrupt = Arc::new(IrqTrigger::new()); // Create backing file. let f = TempFile::new().unwrap(); @@ -426,25 +441,34 @@ mod tests { FileEngineType::default(), ); let block = Arc::new(Mutex::new(block)); - let mmio_transport = MmioTransport::new(mem.clone(), block.clone(), false); + let mmio_transport = + MmioTransport::new(mem.clone(), interrupt.clone(), block.clone(), false); - (mmio_transport, mem, block) + (mmio_transport, interrupt, mem, block) } - fn create_default_net() -> (MmioTransport, GuestMemoryMmap, Arc>) { + fn create_default_net() -> ( + MmioTransport, + Arc, + GuestMemoryMmap, + Arc>, + ) { let mem = default_mem(); + let interrupt = Arc::new(IrqTrigger::new()); let net = Arc::new(Mutex::new(default_net())); - let mmio_transport = MmioTransport::new(mem.clone(), net.clone(), false); + let mmio_transport = MmioTransport::new(mem.clone(), interrupt.clone(), net.clone(), false); - (mmio_transport, mem, net) + (mmio_transport, interrupt, mem, net) } fn default_vsock() -> ( MmioTransport, + Arc, GuestMemoryMmap, Arc>>, ) { let mem = default_mem(); + let interrupt = Arc::new(IrqTrigger::new()); let guest_cid = 52; let mut temp_uds_path = TempFile::new().unwrap(); @@ -454,26 +478,27 @@ mod tests { let backend = VsockUnixBackend::new(guest_cid, uds_path).unwrap(); let vsock = Vsock::new(guest_cid, backend).unwrap(); let vsock = Arc::new(Mutex::new(vsock)); - let mmio_transport = MmioTransport::new(mem.clone(), vsock.clone(), false); + let mmio_transport = + MmioTransport::new(mem.clone(), interrupt.clone(), vsock.clone(), false); - (mmio_transport, mem, vsock) + (mmio_transport, interrupt, mem, vsock) } #[test] fn test_block_over_mmiotransport_persistence() { - let (mmio_transport, mem, block) = create_default_block(); - generic_mmiotransport_persistence_test(mmio_transport, mem, block); + let (mmio_transport, interrupt, mem, block) = create_default_block(); + generic_mmiotransport_persistence_test(mmio_transport, interrupt, mem, block); } #[test] fn test_net_over_mmiotransport_persistence() { - let (mmio_transport, mem, net) = create_default_net(); - generic_mmiotransport_persistence_test(mmio_transport, mem, net); + let (mmio_transport, interrupt, mem, net) = create_default_net(); + generic_mmiotransport_persistence_test(mmio_transport, interrupt, mem, net); } #[test] fn test_vsock_over_mmiotransport_persistence() { - let (mmio_transport, mem, vsock) = default_vsock(); - generic_mmiotransport_persistence_test(mmio_transport, mem, vsock); + let (mmio_transport, interrupt, mem, vsock) = default_vsock(); + generic_mmiotransport_persistence_test(mmio_transport, interrupt, mem, vsock); } } diff --git a/src/vmm/src/devices/virtio/queue.rs b/src/vmm/src/devices/virtio/queue.rs index efe42bfc3dc..686d3ee3da3 100644 --- a/src/vmm/src/devices/virtio/queue.rs +++ b/src/vmm/src/devices/virtio/queue.rs @@ -20,7 +20,7 @@ pub(super) const FIRECRACKER_MAX_QUEUE_SIZE: u16 = 256; // GuestMemoryMmap::read_obj_from_addr() will be used to fetch the descriptor, // which has an explicit constraint that the entire descriptor doesn't -// cross the page boundary. Otherwise the descriptor may be splitted into +// cross the page boundary. Otherwise the descriptor may be split into // two mmap regions which causes failure of GuestMemoryMmap::read_obj_from_addr(). // // The Virtio Spec 1.0 defines the alignment of VirtIO descriptor is 16 bytes, diff --git a/src/vmm/src/devices/virtio/rng/device.rs b/src/vmm/src/devices/virtio/rng/device.rs index 97ac8676e0a..937e113d096 100644 --- a/src/vmm/src/devices/virtio/rng/device.rs +++ b/src/vmm/src/devices/virtio/rng/device.rs @@ -2,8 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use std::io; +use std::ops::Deref; use std::sync::Arc; -use std::sync::atomic::AtomicU32; use aws_lc_rs::rand; use vm_memory::GuestMemoryError; @@ -12,11 +12,12 @@ use vmm_sys_util::eventfd::EventFd; use super::metrics::METRICS; use super::{RNG_NUM_QUEUES, RNG_QUEUE}; use crate::devices::DeviceError; -use crate::devices::virtio::device::{DeviceState, IrqTrigger, IrqType, VirtioDevice}; +use crate::devices::virtio::device::{ActiveState, DeviceState, VirtioDevice}; use crate::devices::virtio::generated::virtio_config::VIRTIO_F_VERSION_1; use crate::devices::virtio::iov_deque::IovDequeError; use crate::devices::virtio::iovec::IoVecBufferMut; use crate::devices::virtio::queue::{FIRECRACKER_MAX_QUEUE_SIZE, Queue}; +use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType}; use crate::devices::virtio::{ActivateError, TYPE_RNG}; use crate::logger::{IncMetric, debug, error}; use crate::rate_limiter::{RateLimiter, TokenType}; @@ -47,7 +48,6 @@ pub struct Entropy { device_state: DeviceState, pub(crate) queues: Vec, queue_events: Vec, - irq_trigger: IrqTrigger, // Device specific fields rate_limiter: RateLimiter, @@ -69,7 +69,6 @@ impl Entropy { let queue_events = (0..RNG_NUM_QUEUES) .map(|_| EventFd::new(libc::EFD_NONBLOCK)) .collect::, io::Error>>()?; - let irq_trigger = IrqTrigger::new()?; Ok(Self { avail_features: 1 << VIRTIO_F_VERSION_1, @@ -78,7 +77,6 @@ impl Entropy { device_state: DeviceState::Inactive, queues, queue_events, - irq_trigger, rate_limiter, buffer: IoVecBufferMut::new()?, }) @@ -89,8 +87,8 @@ impl Entropy { } fn signal_used_queue(&self) -> Result<(), DeviceError> { - self.irq_trigger - .trigger_irq(IrqType::Vring) + self.interrupt_trigger() + .trigger(VirtioInterruptType::Queue(RNG_QUEUE.try_into().unwrap())) .map_err(DeviceError::FailedSignalingIrq) } @@ -132,7 +130,7 @@ impl Entropy { let mut used_any = false; while let Some(desc) = self.queues[RNG_QUEUE].pop() { // This is safe since we checked in the event handler that the device is activated. - let mem = self.device_state.mem().unwrap(); + let mem = &self.device_state.active_state().unwrap().mem; let index = desc.index; METRICS.entropy_event_count.inc(); @@ -236,12 +234,12 @@ impl Entropy { self.acked_features = features; } - pub(crate) fn set_irq_status(&mut self, status: u32) { - self.irq_trigger.irq_status = Arc::new(AtomicU32::new(status)); - } - - pub(crate) fn set_activated(&mut self, mem: GuestMemoryMmap) { - self.device_state = DeviceState::Activated(mem); + pub(crate) fn set_activated( + &mut self, + mem: GuestMemoryMmap, + interrupt: Arc, + ) { + self.device_state = DeviceState::Activated(ActiveState { mem, interrupt }); } pub(crate) fn activate_event(&self) -> &EventFd { @@ -266,8 +264,12 @@ impl VirtioDevice for Entropy { &self.queue_events } - fn interrupt_trigger(&self) -> &IrqTrigger { - &self.irq_trigger + fn interrupt_trigger(&self) -> &dyn VirtioInterrupt { + self.device_state + .active_state() + .expect("Device is not initialized") + .interrupt + .deref() } fn avail_features(&self) -> u64 { @@ -290,7 +292,11 @@ impl VirtioDevice for Entropy { self.device_state.is_activated() } - fn activate(&mut self, mem: GuestMemoryMmap) -> Result<(), ActivateError> { + fn activate( + &mut self, + mem: GuestMemoryMmap, + interrupt: Arc, + ) -> Result<(), ActivateError> { for q in self.queues.iter_mut() { q.initialize(&mem) .map_err(ActivateError::QueueMemoryError)?; @@ -300,7 +306,7 @@ impl VirtioDevice for Entropy { METRICS.activate_fails.inc(); ActivateError::EventFd })?; - self.device_state = DeviceState::Activated(mem); + self.device_state = DeviceState::Activated(ActiveState { mem, interrupt }); Ok(()) } } diff --git a/src/vmm/src/devices/virtio/rng/persist.rs b/src/vmm/src/devices/virtio/rng/persist.rs index 2f2519b4962..75db947c9c7 100644 --- a/src/vmm/src/devices/virtio/rng/persist.rs +++ b/src/vmm/src/devices/virtio/rng/persist.rs @@ -3,12 +3,15 @@ //! Defines the structures needed for saving/restoring entropy devices. +use std::sync::Arc; + use serde::{Deserialize, Serialize}; use crate::devices::virtio::TYPE_RNG; use crate::devices::virtio::persist::{PersistError as VirtioStateError, VirtioDeviceState}; use crate::devices::virtio::queue::FIRECRACKER_MAX_QUEUE_SIZE; use crate::devices::virtio::rng::{Entropy, EntropyError, RNG_NUM_QUEUES}; +use crate::devices::virtio::transport::VirtioInterrupt; use crate::rate_limiter::RateLimiter; use crate::rate_limiter::persist::RateLimiterState; use crate::snapshot::Persist; @@ -21,11 +24,14 @@ pub struct EntropyState { } #[derive(Debug)] -pub struct EntropyConstructorArgs(GuestMemoryMmap); +pub struct EntropyConstructorArgs { + mem: GuestMemoryMmap, + interrupt: Arc, +} impl EntropyConstructorArgs { - pub fn new(mem: GuestMemoryMmap) -> Self { - Self(mem) + pub fn new(mem: GuestMemoryMmap, interrupt: Arc) -> Self { + Self { mem, interrupt } } } @@ -56,7 +62,7 @@ impl Persist<'_> for Entropy { state: &Self::State, ) -> Result { let queues = state.virtio_state.build_queues_checked( - &constructor_args.0, + &constructor_args.mem, TYPE_RNG, RNG_NUM_QUEUES, FIRECRACKER_MAX_QUEUE_SIZE, @@ -66,9 +72,8 @@ impl Persist<'_> for Entropy { let mut entropy = Entropy::new_with_queues(queues, rate_limiter)?; entropy.set_avail_features(state.virtio_state.avail_features); entropy.set_acked_features(state.virtio_state.acked_features); - entropy.set_irq_status(state.virtio_state.interrupt_status); if state.virtio_state.activated { - entropy.set_activated(constructor_args.0); + entropy.set_activated(constructor_args.mem, constructor_args.interrupt); } Ok(entropy) @@ -77,11 +82,11 @@ impl Persist<'_> for Entropy { #[cfg(test)] mod tests { - use std::sync::atomic::Ordering; use super::*; use crate::devices::virtio::device::VirtioDevice; use crate::devices::virtio::rng::device::ENTROPY_DEV_ID; + use crate::devices::virtio::test_utils::default_interrupt; use crate::devices::virtio::test_utils::test::create_virtio_mem; use crate::snapshot::Snapshot; @@ -94,19 +99,16 @@ mod tests { let guest_mem = create_virtio_mem(); let restored = Entropy::restore( - EntropyConstructorArgs(guest_mem), + EntropyConstructorArgs::new(guest_mem, default_interrupt()), &Snapshot::deserialize(&mut mem.as_slice()).unwrap(), ) .unwrap(); assert_eq!(restored.device_type(), TYPE_RNG); assert_eq!(restored.id(), ENTROPY_DEV_ID); - assert_eq!(restored.is_activated(), entropy.is_activated()); + assert!(!restored.is_activated()); + assert!(!entropy.is_activated()); assert_eq!(restored.avail_features(), entropy.avail_features()); assert_eq!(restored.acked_features(), entropy.acked_features()); - assert_eq!( - restored.interrupt_status().load(Ordering::Relaxed), - entropy.interrupt_status().load(Ordering::Relaxed) - ); } } diff --git a/src/vmm/src/devices/virtio/test_utils.rs b/src/vmm/src/devices/virtio/test_utils.rs index 8642d0a85f4..861394c1c7d 100644 --- a/src/vmm/src/devices/virtio/test_utils.rs +++ b/src/vmm/src/devices/virtio/test_utils.rs @@ -6,9 +6,12 @@ use std::fmt::Debug; use std::marker::PhantomData; use std::mem; +use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use crate::devices::virtio::queue::Queue; +use crate::devices::virtio::transport::VirtioInterrupt; +use crate::devices::virtio::transport::mmio::IrqTrigger; use crate::test_utils::single_region_mem; use crate::utils::{align_up, u64_to_usize}; use crate::vstate::memory::{Address, Bytes, GuestAddress, GuestMemoryMmap}; @@ -28,6 +31,11 @@ pub fn default_mem() -> GuestMemoryMmap { single_region_mem(0x10000) } +/// Creates a default ['IrqTrigger'] interrupt for a VirtIO device. +pub fn default_interrupt() -> Arc { + Arc::new(IrqTrigger::new()) +} + #[derive(Debug)] pub struct InputData { pub data: Vec, @@ -323,7 +331,7 @@ pub(crate) mod test { use crate::devices::virtio::device::VirtioDevice; use crate::devices::virtio::net::MAX_BUFFER_SIZE; use crate::devices::virtio::queue::{Queue, VIRTQ_DESC_F_NEXT}; - use crate::devices::virtio::test_utils::{VirtQueue, VirtqDesc}; + use crate::devices::virtio::test_utils::{VirtQueue, VirtqDesc, default_interrupt}; use crate::test_utils::single_region_mem; use crate::vstate::memory::{Address, GuestAddress, GuestMemoryMmap}; @@ -414,7 +422,12 @@ pub(crate) mod test { /// Activate the device pub fn activate_device(&mut self, mem: &'a GuestMemoryMmap) { - self.device.lock().unwrap().activate(mem.clone()).unwrap(); + let interrupt = default_interrupt(); + self.device + .lock() + .unwrap() + .activate(mem.clone(), interrupt) + .unwrap(); // Process the activate event let ev_count = self.event_manager.run_with_timeout(100).unwrap(); assert_eq!(ev_count, 1); diff --git a/src/vmm/src/devices/virtio/mmio.rs b/src/vmm/src/devices/virtio/transport/mmio.rs similarity index 84% rename from src/vmm/src/devices/virtio/mmio.rs rename to src/vmm/src/devices/virtio/transport/mmio.rs index 12ee54bfb0a..5557c4c500e 100644 --- a/src/vmm/src/devices/virtio/mmio.rs +++ b/src/vmm/src/devices/virtio/transport/mmio.rs @@ -9,7 +9,10 @@ use std::fmt::Debug; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, Mutex, MutexGuard}; -use crate::devices::virtio::device::{IrqType, VirtioDevice}; +use vmm_sys_util::eventfd::EventFd; + +use super::{VirtioInterrupt, VirtioInterruptType}; +use crate::devices::virtio::device::VirtioDevice; use crate::devices::virtio::device_status; use crate::devices::virtio::queue::Queue; use crate::logger::{error, warn}; @@ -55,7 +58,7 @@ pub struct MmioTransport { pub(crate) device_status: u32, pub(crate) config_generation: u32, mem: GuestMemoryMmap, - pub(crate) interrupt_status: Arc, + pub(crate) interrupt: Arc, pub is_vhost_user: bool, } @@ -63,11 +66,10 @@ impl MmioTransport { /// Constructs a new MMIO transport for the given virtio device. pub fn new( mem: GuestMemoryMmap, + interrupt: Arc, device: Arc>, is_vhost_user: bool, ) -> MmioTransport { - let interrupt_status = device.lock().expect("Poisoned lock").interrupt_status(); - MmioTransport { device, features_select: 0, @@ -76,7 +78,7 @@ impl MmioTransport { device_status: device_status::INIT, config_generation: 0, mem, - interrupt_status, + interrupt, is_vhost_user, } } @@ -151,7 +153,7 @@ impl MmioTransport { self.features_select = 0; self.acked_features_select = 0; self.queue_select = 0; - self.interrupt_status.store(0, Ordering::SeqCst); + self.interrupt.irq_status.store(0, Ordering::SeqCst); self.device_status = device_status::INIT; // . Keep interrupt_evt and queue_evts as is. There may be pending notifications in those // eventfds, but nothing will happen other than supurious wakeups. @@ -187,7 +189,9 @@ impl MmioTransport { let device_activated = self.locked_device().is_activated(); if !device_activated && self.are_queues_valid() { // temporary variable needed for borrow checker - let activate_result = self.locked_device().activate(self.mem.clone()); + let activate_result = self + .locked_device() + .activate(self.mem.clone(), self.interrupt.clone()); if let Err(err) = activate_result { self.device_status |= DEVICE_NEEDS_RESET; @@ -196,7 +200,7 @@ impl MmioTransport { let _ = self .locked_device() .interrupt_trigger() - .trigger_irq(IrqType::Config); + .trigger(VirtioInterruptType::Config); error!("Failed to activate virtio device: {}", err) } @@ -270,7 +274,7 @@ impl MmioTransport { // `VIRTIO_MMIO_INT_CONFIG` or not to understand if we need to send // `VIRTIO_MMIO_INT_CONFIG` or // `VIRTIO_MMIO_INT_VRING`. - let is = self.interrupt_status.load(Ordering::SeqCst); + let is = self.interrupt.irq_status.load(Ordering::SeqCst); if !self.is_vhost_user { is } else if is == VIRTIO_MMIO_INT_CONFIG { @@ -331,7 +335,7 @@ impl MmioTransport { 0x44 => self.update_queue_field(|q| q.ready = v == 1), 0x64 => { if self.check_device_status(device_status::DRIVER_OK, 0) { - self.interrupt_status.fetch_and(!v, Ordering::SeqCst); + self.interrupt.irq_status.fetch_and(!v, Ordering::SeqCst); } } 0x70 => self.set_device_status(v), @@ -363,13 +367,105 @@ impl MmioTransport { } } +/// The 2 types of interrupt sources in MMIO transport. +#[derive(Debug)] +pub enum IrqType { + /// Interrupt triggered by change in config. + Config, + /// Interrupt triggered by used vring buffers. + Vring, +} + +impl From for IrqType { + fn from(interrupt_type: VirtioInterruptType) -> Self { + match interrupt_type { + VirtioInterruptType::Config => IrqType::Config, + VirtioInterruptType::Queue(_) => IrqType::Vring, + } + } +} + +/// Helper struct that is responsible for triggering guest IRQs +#[derive(Debug)] +pub struct IrqTrigger { + pub(crate) irq_status: Arc, + pub(crate) irq_evt: EventFd, +} + +impl Default for IrqTrigger { + fn default() -> Self { + Self::new() + } +} + +impl VirtioInterrupt for IrqTrigger { + fn trigger(&self, interrupt_type: VirtioInterruptType) -> Result<(), std::io::Error> { + match interrupt_type { + VirtioInterruptType::Config => self.trigger_irq(IrqType::Config), + VirtioInterruptType::Queue(_) => self.trigger_irq(IrqType::Vring), + } + } + + fn notifier(&self, _interrupt_type: VirtioInterruptType) -> Option<&EventFd> { + Some(&self.irq_evt) + } + + fn status(&self) -> Arc { + self.irq_status.clone() + } + + #[cfg(test)] + fn has_pending_interrupt(&self, interrupt_type: VirtioInterruptType) -> bool { + if let Ok(num_irqs) = self.irq_evt.read() { + if num_irqs == 0 { + return false; + } + + let irq_status = self.irq_status.load(Ordering::SeqCst); + return matches!( + (irq_status, interrupt_type.into()), + (VIRTIO_MMIO_INT_CONFIG, IrqType::Config) | (VIRTIO_MMIO_INT_VRING, IrqType::Vring) + ); + } + + false + } +} + +impl IrqTrigger { + pub fn new() -> Self { + Self { + irq_status: Arc::new(AtomicU32::new(0)), + irq_evt: EventFd::new(libc::EFD_NONBLOCK) + .expect("Could not create EventFd for IrqTrigger"), + } + } + + fn trigger_irq(&self, irq_type: IrqType) -> Result<(), std::io::Error> { + let irq = match irq_type { + IrqType::Config => VIRTIO_MMIO_INT_CONFIG, + IrqType::Vring => VIRTIO_MMIO_INT_VRING, + }; + self.irq_status.fetch_or(irq, Ordering::SeqCst); + + self.irq_evt.write(1).map_err(|err| { + error!("Failed to send irq to the guest: {:?}", err); + err + })?; + + Ok(()) + } +} + #[cfg(test)] pub(crate) mod tests { + + use std::ops::Deref; + use vmm_sys_util::eventfd::EventFd; use super::*; use crate::devices::virtio::ActivateError; - use crate::devices::virtio::device::IrqTrigger; use crate::devices::virtio::device_status::DEVICE_NEEDS_RESET; use crate::test_utils::single_region_mem; use crate::utils::byte_order::{read_le_u32, write_le_u32}; @@ -380,7 +476,7 @@ pub(crate) mod tests { pub(crate) struct DummyDevice { acked_features: u64, avail_features: u64, - interrupt_trigger: IrqTrigger, + interrupt_trigger: Option>, queue_evts: Vec, queues: Vec, device_activated: bool, @@ -393,7 +489,7 @@ pub(crate) mod tests { DummyDevice { acked_features: 0, avail_features: 0, - interrupt_trigger: IrqTrigger::new().unwrap(), + interrupt_trigger: None, queue_evts: vec![ EventFd::new(libc::EFD_NONBLOCK).unwrap(), EventFd::new(libc::EFD_NONBLOCK).unwrap(), @@ -439,8 +535,11 @@ pub(crate) mod tests { &self.queue_evts } - fn interrupt_trigger(&self) -> &IrqTrigger { - &self.interrupt_trigger + fn interrupt_trigger(&self) -> &dyn VirtioInterrupt { + self.interrupt_trigger + .as_ref() + .expect("Device is not activated") + .deref() } fn read_config(&self, offset: u64, data: &mut [u8]) { @@ -453,8 +552,13 @@ pub(crate) mod tests { } } - fn activate(&mut self, _: GuestMemoryMmap) -> Result<(), ActivateError> { + fn activate( + &mut self, + _: GuestMemoryMmap, + interrupt: Arc, + ) -> Result<(), ActivateError> { self.device_activated = true; + self.interrupt_trigger = Some(interrupt); if self.activate_should_error { Err(ActivateError::EventFd) } else { @@ -476,10 +580,11 @@ pub(crate) mod tests { #[test] fn test_new() { let m = single_region_mem(0x1000); + let interrupt = Arc::new(IrqTrigger::new()); let mut dummy = DummyDevice::new(); // Validate reset is no-op. assert!(dummy.reset().is_none()); - let mut d = MmioTransport::new(m, Arc::new(Mutex::new(dummy)), false); + let mut d = MmioTransport::new(m, interrupt, Arc::new(Mutex::new(dummy)), false); // We just make sure here that the implementation of a mmio device behaves as we expect, // given a known virtio device implementation (the dummy device). @@ -508,7 +613,13 @@ pub(crate) mod tests { #[test] fn test_bus_device_read() { let m = single_region_mem(0x1000); - let mut d = MmioTransport::new(m, Arc::new(Mutex::new(DummyDevice::new())), false); + let interrupt = Arc::new(IrqTrigger::new()); + let mut d = MmioTransport::new( + m, + interrupt, + Arc::new(Mutex::new(DummyDevice::new())), + false, + ); let mut buf = vec![0xff, 0, 0xfe, 0]; let buf_copy = buf.to_vec(); @@ -555,17 +666,18 @@ pub(crate) mod tests { d.bus_read(0x44, &mut buf[..]); assert_eq!(read_le_u32(&buf[..]), u32::from(false)); - d.interrupt_status.store(111, Ordering::SeqCst); + d.interrupt.irq_status.store(111, Ordering::SeqCst); d.bus_read(0x60, &mut buf[..]); assert_eq!(read_le_u32(&buf[..]), 111); d.is_vhost_user = true; - d.interrupt_status.store(0, Ordering::SeqCst); + d.interrupt.irq_status.store(0, Ordering::SeqCst); d.bus_read(0x60, &mut buf[..]); assert_eq!(read_le_u32(&buf[..]), VIRTIO_MMIO_INT_VRING); d.is_vhost_user = true; - d.interrupt_status + d.interrupt + .irq_status .store(VIRTIO_MMIO_INT_CONFIG, Ordering::SeqCst); d.bus_read(0x60, &mut buf[..]); assert_eq!(read_le_u32(&buf[..]), VIRTIO_MMIO_INT_CONFIG); @@ -597,8 +709,9 @@ pub(crate) mod tests { #[allow(clippy::cognitive_complexity)] fn test_bus_device_write() { let m = single_region_mem(0x1000); + let interrupt = Arc::new(IrqTrigger::new()); let dummy_dev = Arc::new(Mutex::new(DummyDevice::new())); - let mut d = MmioTransport::new(m, dummy_dev.clone(), false); + let mut d = MmioTransport::new(m, interrupt, dummy_dev.clone(), false); let mut buf = vec![0; 5]; write_le_u32(&mut buf[..4], 1); @@ -725,10 +838,10 @@ pub(crate) mod tests { | device_status::DRIVER_OK, ); - d.interrupt_status.store(0b10_1010, Ordering::Relaxed); + d.interrupt.irq_status.store(0b10_1010, Ordering::Relaxed); write_le_u32(&mut buf[..], 0b111); d.bus_write(0x64, &buf[..]); - assert_eq!(d.interrupt_status.load(Ordering::Relaxed), 0b10_1000); + assert_eq!(d.interrupt.irq_status.load(Ordering::Relaxed), 0b10_1000); // Write to an invalid address in generic register range. write_le_u32(&mut buf[..], 0xf); @@ -759,7 +872,13 @@ pub(crate) mod tests { #[test] fn test_bus_device_activate() { let m = single_region_mem(0x1000); - let mut d = MmioTransport::new(m, Arc::new(Mutex::new(DummyDevice::new())), false); + let interrupt = Arc::new(IrqTrigger::new()); + let mut d = MmioTransport::new( + m, + interrupt, + Arc::new(Mutex::new(DummyDevice::new())), + false, + ); assert!(!d.are_queues_valid()); assert!(!d.locked_device().is_activated()); @@ -838,11 +957,12 @@ pub(crate) mod tests { #[test] fn test_bus_device_activate_failure() { let m = single_region_mem(0x1000); + let interrupt = Arc::new(IrqTrigger::new()); let device = DummyDevice { activate_should_error: true, ..DummyDevice::new() }; - let mut d = MmioTransport::new(m, Arc::new(Mutex::new(device)), false); + let mut d = MmioTransport::new(m, interrupt, Arc::new(Mutex::new(device)), false); set_device_status(&mut d, device_status::ACKNOWLEDGE); set_device_status(&mut d, device_status::ACKNOWLEDGE | device_status::DRIVER); @@ -861,10 +981,6 @@ pub(crate) mod tests { d.bus_write(0x44, &buf[..]); } assert!(d.are_queues_valid()); - assert_eq!( - d.locked_device().interrupt_status().load(Ordering::SeqCst), - 0 - ); set_device_status( &mut d, @@ -885,7 +1001,8 @@ pub(crate) mod tests { assert_eq!( d.locked_device() .interrupt_trigger() - .irq_evt + .notifier(VirtioInterruptType::Config) + .unwrap() .read() .unwrap(), 1 @@ -934,7 +1051,13 @@ pub(crate) mod tests { #[test] fn test_bus_device_reset() { let m = single_region_mem(0x1000); - let mut d = MmioTransport::new(m, Arc::new(Mutex::new(DummyDevice::new())), false); + let interrupt = Arc::new(IrqTrigger::new()); + let mut d = MmioTransport::new( + m, + interrupt, + Arc::new(Mutex::new(DummyDevice::new())), + false, + ); let mut buf = [0; 4]; assert!(!d.are_queues_valid()); @@ -984,4 +1107,30 @@ pub(crate) mod tests { dummy_dev.ack_features_by_page(0, 8); assert_eq!(dummy_dev.acked_features(), 24); } + + #[test] + fn irq_trigger() { + let irq_trigger = IrqTrigger::new(); + assert_eq!(irq_trigger.irq_status.load(Ordering::SeqCst), 0); + + // Check that there are no pending irqs. + assert!(!irq_trigger.has_pending_interrupt(VirtioInterruptType::Config)); + assert!(!irq_trigger.has_pending_interrupt(VirtioInterruptType::Queue(0))); + + // Check that trigger_irq() correctly generates irqs. + irq_trigger.trigger(VirtioInterruptType::Config).unwrap(); + assert!(irq_trigger.has_pending_interrupt(VirtioInterruptType::Config)); + irq_trigger.irq_status.store(0, Ordering::SeqCst); + irq_trigger.trigger(VirtioInterruptType::Queue(0)).unwrap(); + assert!(irq_trigger.has_pending_interrupt(VirtioInterruptType::Queue(0))); + + // Check trigger_irq() failure case (irq_evt is full). + irq_trigger.irq_evt.write(u64::MAX - 1).unwrap(); + irq_trigger + .trigger(VirtioInterruptType::Config) + .unwrap_err(); + irq_trigger + .trigger(VirtioInterruptType::Queue(0)) + .unwrap_err(); + } } diff --git a/src/vmm/src/devices/virtio/transport/mod.rs b/src/vmm/src/devices/virtio/transport/mod.rs new file mode 100644 index 00000000000..d41ad943aa2 --- /dev/null +++ b/src/vmm/src/devices/virtio/transport/mod.rs @@ -0,0 +1,37 @@ +// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::sync::Arc; +use std::sync::atomic::AtomicU32; + +use vmm_sys_util::eventfd::EventFd; + +/// MMIO transport for VirtIO devices +pub mod mmio; + +/// Represents the types of interrupts used by VirtIO devices +#[derive(Debug, Clone)] +pub enum VirtioInterruptType { + /// Interrupt for VirtIO configuration changes + Config, + /// Interrupts for new events in a queue. + Queue(u16), +} + +/// API of interrupt types used by VirtIO devices +pub trait VirtioInterrupt: std::fmt::Debug + Send + Sync { + /// Trigger a VirtIO interrupt. + fn trigger(&self, interrupt_type: VirtioInterruptType) -> Result<(), std::io::Error>; + + /// Get the `EventFd` (if any) that backs the underlying interrupt. + fn notifier(&self, _interrupt_type: VirtioInterruptType) -> Option<&EventFd> { + None + } + + /// Get the current device interrupt status. + fn status(&self) -> Arc; + + /// Returns true if there is any pending interrupt + #[cfg(test)] + fn has_pending_interrupt(&self, interrupt_type: VirtioInterruptType) -> bool; +} diff --git a/src/vmm/src/devices/virtio/vhost_user.rs b/src/vmm/src/devices/virtio/vhost_user.rs index 83174fbc4d3..53e479ef652 100644 --- a/src/vmm/src/devices/virtio/vhost_user.rs +++ b/src/vmm/src/devices/virtio/vhost_user.rs @@ -6,6 +6,7 @@ use std::os::fd::AsRawFd; use std::os::unix::net::UnixStream; +use std::sync::Arc; use vhost::vhost_user::message::*; use vhost::vhost_user::{Frontend, VhostUserFrontend}; @@ -13,8 +14,8 @@ use vhost::{Error as VhostError, VhostBackend, VhostUserMemoryRegionInfo, VringC use vm_memory::{Address, Error as MmapError, GuestMemory, GuestMemoryError, GuestMemoryRegion}; use vmm_sys_util::eventfd::EventFd; -use crate::devices::virtio::device::IrqTrigger; use crate::devices::virtio::queue::Queue; +use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType}; use crate::vstate::memory::GuestMemoryMmap; /// vhost-user error. @@ -400,7 +401,7 @@ impl VhostUserHandleImpl { &mut self, mem: &GuestMemoryMmap, queues: &[(usize, &Queue, &EventFd)], - irq_trigger: &IrqTrigger, + interrupt: Arc, ) -> Result<(), VhostUserError> { // Provide the memory table to the backend. self.update_mem_table(mem)?; @@ -442,7 +443,17 @@ impl VhostUserHandleImpl { // No matter the queue, we set irq_evt for signaling the guest that buffers were // consumed. self.vu - .set_vring_call(*queue_index, &irq_trigger.irq_evt) + .set_vring_call( + *queue_index, + interrupt + .notifier(VirtioInterruptType::Queue( + (*queue_index).try_into().unwrap_or_else(|_| { + panic!("vhost-user: invalid queue index: {}", *queue_index) + }), + )) + .as_ref() + .unwrap(), + ) .map_err(VhostUserError::VhostUserSetVringCall)?; self.vu @@ -467,6 +478,7 @@ pub(crate) mod tests { use vmm_sys_util::tempfile::TempFile; use super::*; + use crate::devices::virtio::test_utils::default_interrupt; use crate::test_utils::create_tmp_socket; use crate::vstate::memory; use crate::vstate::memory::GuestAddress; @@ -899,11 +911,11 @@ pub(crate) mod tests { queue.initialize(&guest_memory).unwrap(); let event_fd = EventFd::new(0).unwrap(); - let irq_trigger = IrqTrigger::new().unwrap(); let queues = [(0, &queue, &event_fd)]; - vuh.setup_backend(&guest_memory, &queues, &irq_trigger) + let interrupt = default_interrupt(); + vuh.setup_backend(&guest_memory, &queues, interrupt.clone()) .unwrap(); // VhostUserHandleImpl should correctly send memory and queues information to @@ -927,7 +939,11 @@ pub(crate) mod tests { log_addr: None, }, base: queue.avail_ring_idx_get(), - call: irq_trigger.irq_evt.as_raw_fd(), + call: interrupt + .notifier(VirtioInterruptType::Queue(0u16)) + .as_ref() + .unwrap() + .as_raw_fd(), kick: event_fd.as_raw_fd(), enable: true, }; diff --git a/src/vmm/src/devices/virtio/vsock/device.rs b/src/vmm/src/devices/virtio/vsock/device.rs index aa114f6cccb..c9daf19fd94 100644 --- a/src/vmm/src/devices/virtio/vsock/device.rs +++ b/src/vmm/src/devices/virtio/vsock/device.rs @@ -6,7 +6,7 @@ // found in the THIRD-PARTY file. //! This is the `VirtioDevice` implementation for our vsock device. It handles the virtio-level -//! device logic: feature negociation, device configuration, and device activation. +//! device logic: feature negotiation, device configuration, and device activation. //! //! We aim to conform to the VirtIO v1.1 spec: //! https://docs.oasis-open.org/virtio/virtio/v1.1/virtio-v1.1.html @@ -21,6 +21,8 @@ //! - a backend FD. use std::fmt::Debug; +use std::ops::Deref; +use std::sync::Arc; use log::{error, warn}; use vmm_sys_util::eventfd::EventFd; @@ -30,9 +32,10 @@ use super::defs::uapi; use super::packet::{VSOCK_PKT_HDR_SIZE, VsockPacketRx, VsockPacketTx}; use super::{VsockBackend, defs}; use crate::devices::virtio::ActivateError; -use crate::devices::virtio::device::{DeviceState, IrqTrigger, IrqType, VirtioDevice}; +use crate::devices::virtio::device::{ActiveState, DeviceState, VirtioDevice}; use crate::devices::virtio::generated::virtio_config::{VIRTIO_F_IN_ORDER, VIRTIO_F_VERSION_1}; use crate::devices::virtio::queue::Queue as VirtQueue; +use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType}; use crate::devices::virtio::vsock::VsockError; use crate::devices::virtio::vsock::metrics::METRICS; use crate::logger::IncMetric; @@ -61,7 +64,6 @@ pub struct Vsock { pub(crate) backend: B, pub(crate) avail_features: u64, pub(crate) acked_features: u64, - pub(crate) irq_trigger: IrqTrigger, // This EventFd is the only one initially registered for a vsock device, and is used to convert // a VirtioDevice::activate call into an EventHandler read event which allows the other events // (queue and backend related) to be registered post virtio device activation. That's @@ -102,7 +104,6 @@ where backend, avail_features: AVAIL_FEATURES, acked_features: 0, - irq_trigger: IrqTrigger::new().map_err(VsockError::EventFd)?, activate_evt: EventFd::new(libc::EFD_NONBLOCK).map_err(VsockError::EventFd)?, device_state: DeviceState::Inactive, rx_packet: VsockPacketRx::new()?, @@ -136,9 +137,14 @@ where /// Signal the guest driver that we've used some virtio buffers that it had previously made /// available. - pub fn signal_used_queue(&self) -> Result<(), DeviceError> { - self.irq_trigger - .trigger_irq(IrqType::Vring) + pub fn signal_used_queue(&self, qidx: usize) -> Result<(), DeviceError> { + self.device_state + .active_state() + .expect("Device is not initialized") + .interrupt + .trigger(VirtioInterruptType::Queue(qidx.try_into().unwrap_or_else( + |_| panic!("vsock: invalid queue index: {qidx}"), + ))) .map_err(DeviceError::FailedSignalingIrq) } @@ -147,7 +153,7 @@ where /// otherwise. pub fn process_rx(&mut self) -> bool { // This is safe since we checked in the event handler that the device is activated. - let mem = self.device_state.mem().unwrap(); + let mem = &self.device_state.active_state().unwrap().mem; let mut have_used = false; @@ -200,7 +206,7 @@ where /// ring, and `false` otherwise. pub fn process_tx(&mut self) -> bool { // This is safe since we checked in the event handler that the device is activated. - let mem = self.device_state.mem().unwrap(); + let mem = &self.device_state.active_state().unwrap().mem; let mut have_used = false; @@ -242,7 +248,7 @@ where // remain but their CID is updated to reflect the current guest_cid. pub fn send_transport_reset_event(&mut self) -> Result<(), DeviceError> { // This is safe since we checked in the caller function that the device is activated. - let mem = self.device_state.mem().unwrap(); + let mem = &self.device_state.active_state().unwrap().mem; let head = self.queues[EVQ_INDEX].pop().ok_or_else(|| { METRICS.ev_queue_event_fails.inc(); @@ -258,7 +264,7 @@ where error!("Failed to add used descriptor {}: {}", head.index, err); }); - self.signal_used_queue()?; + self.signal_used_queue(EVQ_INDEX)?; Ok(()) } @@ -296,8 +302,12 @@ where &self.queue_events } - fn interrupt_trigger(&self) -> &IrqTrigger { - &self.irq_trigger + fn interrupt_trigger(&self) -> &dyn VirtioInterrupt { + self.device_state + .active_state() + .expect("Device is not initialized") + .interrupt + .deref() } fn read_config(&self, offset: u64, data: &mut [u8]) { @@ -329,7 +339,11 @@ where ); } - fn activate(&mut self, mem: GuestMemoryMmap) -> Result<(), ActivateError> { + fn activate( + &mut self, + mem: GuestMemoryMmap, + interrupt: Arc, + ) -> Result<(), ActivateError> { for q in self.queues.iter_mut() { q.initialize(&mem) .map_err(ActivateError::QueueMemoryError)?; @@ -348,7 +362,7 @@ where return Err(ActivateError::EventFd); } - self.device_state = DeviceState::Activated(mem); + self.device_state = DeviceState::Activated(ActiveState { mem, interrupt }); Ok(()) } @@ -431,6 +445,8 @@ mod tests { // } // Test a correct activation. - ctx.device.activate(ctx.mem.clone()).unwrap(); + ctx.device + .activate(ctx.mem.clone(), ctx.interrupt.clone()) + .unwrap(); } } diff --git a/src/vmm/src/devices/virtio/vsock/event_handler.rs b/src/vmm/src/devices/virtio/vsock/event_handler.rs index 632148546e5..e1b2876a0f3 100755 --- a/src/vmm/src/devices/virtio/vsock/event_handler.rs +++ b/src/vmm/src/devices/virtio/vsock/event_handler.rs @@ -190,9 +190,10 @@ where Self::PROCESS_EVQ => raise_irq = self.handle_evq_event(evset), Self::PROCESS_NOTIFY_BACKEND => raise_irq = self.notify_backend(evset), _ => warn!("Unexpected vsock event received: {:?}", source), - } + }; if raise_irq { - self.signal_used_queue().unwrap_or_default(); + self.signal_used_queue(source as usize) + .expect("vsock: Could not trigger device interrupt"); } } else { warn!( @@ -233,7 +234,7 @@ mod tests { { let test_ctx = TestContext::new(); let mut ctx = test_ctx.create_event_handler_context(); - ctx.mock_activate(test_ctx.mem.clone()); + ctx.mock_activate(test_ctx.mem.clone(), test_ctx.interrupt.clone()); ctx.device.backend.set_pending_rx(false); ctx.signal_txq_event(); @@ -250,7 +251,7 @@ mod tests { { let test_ctx = TestContext::new(); let mut ctx = test_ctx.create_event_handler_context(); - ctx.mock_activate(test_ctx.mem.clone()); + ctx.mock_activate(test_ctx.mem.clone(), test_ctx.interrupt.clone()); ctx.device.backend.set_pending_rx(true); ctx.signal_txq_event(); @@ -266,7 +267,7 @@ mod tests { { let test_ctx = TestContext::new(); let mut ctx = test_ctx.create_event_handler_context(); - ctx.mock_activate(test_ctx.mem.clone()); + ctx.mock_activate(test_ctx.mem.clone(), test_ctx.interrupt.clone()); ctx.device.backend.set_pending_rx(false); ctx.device.backend.set_tx_err(Some(VsockError::NoData)); @@ -282,7 +283,7 @@ mod tests { { let test_ctx = TestContext::new(); let mut ctx = test_ctx.create_event_handler_context(); - ctx.mock_activate(test_ctx.mem.clone()); + ctx.mock_activate(test_ctx.mem.clone(), test_ctx.interrupt.clone()); // Invalidate the descriptor chain, by setting its length to 0. ctx.guest_txvq.dtable[0].len.set(0); @@ -299,7 +300,7 @@ mod tests { { let test_ctx = TestContext::new(); let mut ctx = test_ctx.create_event_handler_context(); - ctx.mock_activate(test_ctx.mem.clone()); + ctx.mock_activate(test_ctx.mem.clone(), test_ctx.interrupt.clone()); assert!(!ctx.device.handle_txq_event(EventSet::IN)); } @@ -314,7 +315,7 @@ mod tests { { let test_ctx = TestContext::new(); let mut ctx = test_ctx.create_event_handler_context(); - ctx.mock_activate(test_ctx.mem.clone()); + ctx.mock_activate(test_ctx.mem.clone(), test_ctx.interrupt.clone()); ctx.device.backend.set_pending_rx(true); ctx.device.backend.set_rx_err(Some(VsockError::NoData)); @@ -331,7 +332,7 @@ mod tests { { let test_ctx = TestContext::new(); let mut ctx = test_ctx.create_event_handler_context(); - ctx.mock_activate(test_ctx.mem.clone()); + ctx.mock_activate(test_ctx.mem.clone(), test_ctx.interrupt.clone()); ctx.device.backend.set_pending_rx(true); ctx.signal_rxq_event(); @@ -344,7 +345,7 @@ mod tests { { let test_ctx = TestContext::new(); let mut ctx = test_ctx.create_event_handler_context(); - ctx.mock_activate(test_ctx.mem.clone()); + ctx.mock_activate(test_ctx.mem.clone(), test_ctx.interrupt.clone()); // Invalidate the descriptor chain, by setting its length to 0. ctx.guest_rxvq.dtable[0].len.set(0); @@ -360,7 +361,7 @@ mod tests { { let test_ctx = TestContext::new(); let mut ctx = test_ctx.create_event_handler_context(); - ctx.mock_activate(test_ctx.mem.clone()); + ctx.mock_activate(test_ctx.mem.clone(), test_ctx.interrupt.clone()); ctx.device.backend.set_pending_rx(false); assert!(!ctx.device.handle_rxq_event(EventSet::IN)); } @@ -385,7 +386,7 @@ mod tests { { let test_ctx = TestContext::new(); let mut ctx = test_ctx.create_event_handler_context(); - ctx.mock_activate(test_ctx.mem.clone()); + ctx.mock_activate(test_ctx.mem.clone(), test_ctx.interrupt.clone()); ctx.device.backend.set_pending_rx(true); ctx.device.notify_backend(EventSet::IN); @@ -404,7 +405,7 @@ mod tests { { let test_ctx = TestContext::new(); let mut ctx = test_ctx.create_event_handler_context(); - ctx.mock_activate(test_ctx.mem.clone()); + ctx.mock_activate(test_ctx.mem.clone(), test_ctx.interrupt.clone()); ctx.device.backend.set_pending_rx(false); ctx.device.notify_backend(EventSet::IN); @@ -447,7 +448,7 @@ mod tests { { let mut ctx = test_ctx.create_event_handler_context(); - // When modifiyng the buffer descriptor, make sure the len field is altered in the + // When modifying the buffer descriptor, make sure the len field is altered in the // vsock packet header descriptor as well. if desc_idx == 1 { // The vsock packet len field has offset 24 in the header. @@ -575,7 +576,7 @@ mod tests { vsock .lock() .unwrap() - .activate(test_ctx.mem.clone()) + .activate(test_ctx.mem.clone(), test_ctx.interrupt.clone()) .unwrap(); // Process the activate event. let ev_count = event_manager.run_with_timeout(50).unwrap(); diff --git a/src/vmm/src/devices/virtio/vsock/persist.rs b/src/vmm/src/devices/virtio/vsock/persist.rs index fce6affae69..9d2fd61d9d5 100644 --- a/src/vmm/src/devices/virtio/vsock/persist.rs +++ b/src/vmm/src/devices/virtio/vsock/persist.rs @@ -5,14 +5,14 @@ use std::fmt::Debug; use std::sync::Arc; -use std::sync::atomic::AtomicU32; use serde::{Deserialize, Serialize}; use super::*; -use crate::devices::virtio::device::DeviceState; +use crate::devices::virtio::device::{ActiveState, DeviceState}; use crate::devices::virtio::persist::VirtioDeviceState; use crate::devices::virtio::queue::FIRECRACKER_MAX_QUEUE_SIZE; +use crate::devices::virtio::transport::VirtioInterrupt; use crate::devices::virtio::vsock::TYPE_VSOCK; use crate::snapshot::Persist; use crate::vstate::memory::GuestMemoryMmap; @@ -29,7 +29,7 @@ pub struct VsockState { /// The Vsock frontend serializable state. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct VsockFrontendState { - /// Context IDentifier. + /// Context Identifier. pub cid: u64, virtio_state: VirtioDeviceState, } @@ -53,6 +53,8 @@ pub struct VsockUdsState { pub struct VsockConstructorArgs { /// Pointer to guest memory. pub mem: GuestMemoryMmap, + /// Interrupt to use for the device. + pub interrupt: Arc, /// The vsock Unix Backend. pub backend: B, } @@ -121,10 +123,11 @@ where vsock.acked_features = state.virtio_state.acked_features; vsock.avail_features = state.virtio_state.avail_features; - vsock.irq_trigger.irq_status = - Arc::new(AtomicU32::new(state.virtio_state.interrupt_status)); vsock.device_state = if state.virtio_state.activated { - DeviceState::Activated(constructor_args.mem) + DeviceState::Activated(ActiveState { + mem: constructor_args.mem, + interrupt: constructor_args.interrupt, + }) } else { DeviceState::Inactive }; @@ -137,6 +140,7 @@ pub(crate) mod tests { use super::device::AVAIL_FEATURES; use super::*; use crate::devices::virtio::device::VirtioDevice; + use crate::devices::virtio::test_utils::default_interrupt; use crate::devices::virtio::vsock::defs::uapi; use crate::devices::virtio::vsock::test_utils::{TestBackend, TestContext}; use crate::snapshot::Snapshot; @@ -189,6 +193,7 @@ pub(crate) mod tests { let mut restored_device = Vsock::restore( VsockConstructorArgs { mem: ctx.mem.clone(), + interrupt: default_interrupt(), backend: match restored_state.backend { VsockBackendState::Uds(uds_state) => { assert_eq!(uds_state.path, "test".to_owned()); diff --git a/src/vmm/src/devices/virtio/vsock/test_utils.rs b/src/vmm/src/devices/virtio/vsock/test_utils.rs index 804f0442559..0db293466c6 100644 --- a/src/vmm/src/devices/virtio/vsock/test_utils.rs +++ b/src/vmm/src/devices/virtio/vsock/test_utils.rs @@ -5,6 +5,7 @@ #![doc(hidden)] use std::os::unix::io::{AsRawFd, RawFd}; +use std::sync::Arc; use vmm_sys_util::epoll::EventSet; use vmm_sys_util::eventfd::EventFd; @@ -12,7 +13,8 @@ use vmm_sys_util::eventfd::EventFd; use super::packet::{VsockPacketRx, VsockPacketTx}; use crate::devices::virtio::device::VirtioDevice; use crate::devices::virtio::queue::{VIRTQ_DESC_F_NEXT, VIRTQ_DESC_F_WRITE}; -use crate::devices::virtio::test_utils::VirtQueue as GuestQ; +use crate::devices::virtio::test_utils::{VirtQueue as GuestQ, default_interrupt}; +use crate::devices::virtio::transport::VirtioInterrupt; use crate::devices::virtio::vsock::device::{RXQ_INDEX, TXQ_INDEX}; use crate::devices::virtio::vsock::packet::VSOCK_PKT_HDR_SIZE; use crate::devices::virtio::vsock::{ @@ -117,6 +119,7 @@ impl VsockBackend for TestBackend {} pub struct TestContext { pub cid: u64, pub mem: GuestMemoryMmap, + pub interrupt: Arc, pub mem_size: usize, pub device: Vsock, } @@ -129,6 +132,7 @@ impl TestContext { Self { cid: CID, mem, + interrupt: default_interrupt(), mem_size: MEM_SIZE, device: Vsock::new(CID, TestBackend::new()).unwrap(), } @@ -191,9 +195,9 @@ pub struct EventHandlerContext<'a> { } impl EventHandlerContext<'_> { - pub fn mock_activate(&mut self, mem: GuestMemoryMmap) { + pub fn mock_activate(&mut self, mem: GuestMemoryMmap, interrupt: Arc) { // Artificially activate the device. - self.device.activate(mem).unwrap(); + self.device.activate(mem, interrupt).unwrap(); } pub fn signal_txq_event(&mut self) {