diff --git a/Cargo.lock b/Cargo.lock index 0ea518f3a49..5a66c51b66e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -248,6 +248,19 @@ version = "2.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2261d10cca569e4643e526d8dc2e62e433cc8aba21ab764233731f8d369bf394" +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "serde", + "tap", + "wyz", +] + [[package]] name = "bumpalo" version = "3.19.0" @@ -648,6 +661,12 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + [[package]] name = "gdbstub" version = "0.7.7" @@ -1140,6 +1159,12 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + [[package]] name = "rand" version = "0.9.2" @@ -1414,6 +1439,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + [[package]] name = "thiserror" version = "1.0.69" @@ -1682,6 +1713,7 @@ dependencies = [ "base64", "bincode", "bitflags 2.9.4", + "bitvec", "byteorder", "crc64", "criterion", @@ -2008,6 +2040,15 @@ version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + [[package]] name = "zerocopy" version = "0.8.27" diff --git a/docs/device-api.md b/docs/device-api.md index 01e470f1d64..0b8035651f8 100644 --- a/docs/device-api.md +++ b/docs/device-api.md @@ -104,6 +104,7 @@ specification: | `MemoryHotplugConfig` | total_size_mib | O | O | O | O | O | O | O | **R** | | | slot_size_mib | O | O | O | O | O | O | O | **R** | | | block_size_mi | O | O | O | O | O | O | O | **R** | +| `MemoryHotplugSizeUpdate` | requested_size_mib | O | O | O | O | O | O | O | **R** | \* `Drive`'s `drive_id`, `is_root_device` and `partuuid` can be configured by either virtio-block or vhost-user-block devices. 
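For context on the table row above: a minimal sketch of the `PATCH /hotplug/memory` body that `MemoryHotplugSizeUpdate` describes, assuming only the `requested_size_mib` field shown in the table; the `serde_json` usage is illustrative and not part of this change.

// Illustrative sketch (not part of the patch): builds the JSON body for
// PATCH /hotplug/memory as documented by the MemoryHotplugSizeUpdate row above.
fn main() {
    let body = serde_json::json!({ "requested_size_mib": 2048 });
    // The body is sent to Firecracker's API socket as a PATCH request.
    println!("PATCH /hotplug/memory {body}");
}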
diff --git a/src/firecracker/examples/uffd/on_demand_handler.rs b/src/firecracker/examples/uffd/on_demand_handler.rs index 3be958b3578..3101aa19253 100644 --- a/src/firecracker/examples/uffd/on_demand_handler.rs +++ b/src/firecracker/examples/uffd/on_demand_handler.rs @@ -87,7 +87,7 @@ fn main() { } } userfaultfd::Event::Remove { start, end } => { - uffd_handler.mark_range_removed(start as u64, end as u64) + uffd_handler.unregister_range(start, end) } _ => panic!("Unexpected event on userfaultfd"), } diff --git a/src/firecracker/examples/uffd/uffd_utils.rs b/src/firecracker/examples/uffd/uffd_utils.rs index 97c6150b65b..30aa16f040a 100644 --- a/src/firecracker/examples/uffd/uffd_utils.rs +++ b/src/firecracker/examples/uffd/uffd_utils.rs @@ -9,7 +9,7 @@ dead_code )] -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::ffi::c_void; use std::fs::File; use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd}; @@ -54,7 +54,6 @@ pub struct UffdHandler { pub page_size: usize, backing_buffer: *const u8, uffd: Uffd, - removed_pages: HashSet<u64>, } impl UffdHandler { @@ -125,7 +124,6 @@ impl UffdHandler { page_size, backing_buffer, uffd, - removed_pages: HashSet::new(), } } @@ -133,24 +131,18 @@ impl UffdHandler { self.uffd.read_event() } - pub fn mark_range_removed(&mut self, start: u64, end: u64) { - let pfn_start = start / self.page_size as u64; - let pfn_end = end / self.page_size as u64; - - for pfn in pfn_start..pfn_end { - self.removed_pages.insert(pfn); - } + pub fn unregister_range(&mut self, start: *mut c_void, end: *mut c_void) { + // SAFETY: start and end are valid and provided by UFFD + let len = unsafe { end.offset_from_unsigned(start) }; + self.uffd + .unregister(start, len) + .expect("range should be valid"); } pub fn serve_pf(&mut self, addr: *mut u8, len: usize) -> bool { // Find the start of the page that the current faulting address belongs to. 
let dst = (addr as usize & !(self.page_size - 1)) as *mut libc::c_void; let fault_page_addr = dst as u64; - let fault_pfn = fault_page_addr / self.page_size as u64; - - if self.removed_pages.contains(&fault_pfn) { - return self.zero_out(fault_page_addr); - } for region in self.mem_regions.iter() { if region.contains(fault_page_addr) { @@ -193,14 +185,6 @@ impl UffdHandler { true } - - fn zero_out(&mut self, addr: u64) -> bool { - match unsafe { self.uffd.zeropage(addr as *mut _, self.page_size, true) } { - Ok(_) => true, - Err(Error::ZeropageFailed(error)) if error as i32 == libc::EAGAIN => false, - r => panic!("Unexpected zeropage result: {:?}", r), - } - } } #[derive(Debug)] diff --git a/src/firecracker/src/api_server/parsed_request.rs b/src/firecracker/src/api_server/parsed_request.rs index 287742ede41..3d21695ce3e 100644 --- a/src/firecracker/src/api_server/parsed_request.rs +++ b/src/firecracker/src/api_server/parsed_request.rs @@ -28,7 +28,7 @@ use super::request::snapshot::{parse_patch_vm_state, parse_put_snapshot}; use super::request::version::parse_get_version; use super::request::vsock::parse_put_vsock; use crate::api_server::request::hotplug::memory::{ - parse_get_memory_hotplug, parse_put_memory_hotplug, + parse_get_memory_hotplug, parse_patch_memory_hotplug, parse_put_memory_hotplug, }; use crate::api_server::request::serial::parse_put_serial; @@ -119,6 +119,9 @@ impl TryFrom<&Request> for ParsedRequest { parse_patch_net(body, path_tokens.next()) } (Method::Patch, "vm", Some(body)) => parse_patch_vm_state(body), + (Method::Patch, "hotplug", Some(body)) if path_tokens.next() == Some("memory") => { + parse_patch_memory_hotplug(body) + } (Method::Patch, _, None) => method_to_error(Method::Patch), (method, unknown_uri, _) => Err(RequestError::InvalidPathMethod( unknown_uri.to_string(), diff --git a/src/firecracker/src/api_server/request/hotplug/memory.rs b/src/firecracker/src/api_server/request/hotplug/memory.rs index 4bdeec73a6d..5ec514ca964 100644 --- a/src/firecracker/src/api_server/request/hotplug/memory.rs +++ b/src/firecracker/src/api_server/request/hotplug/memory.rs @@ -4,7 +4,7 @@ use micro_http::Body; use vmm::logger::{IncMetric, METRICS}; use vmm::rpc_interface::VmmAction; -use vmm::vmm_config::memory_hotplug::MemoryHotplugConfig; +use vmm::vmm_config::memory_hotplug::{MemoryHotplugConfig, MemoryHotplugSizeUpdate}; use crate::api_server::parsed_request::{ParsedRequest, RequestError}; @@ -23,11 +23,23 @@ pub(crate) fn parse_get_memory_hotplug() -> Result<ParsedRequest, RequestError> { Ok(ParsedRequest::new_sync(VmmAction::GetMemoryHotplugStatus)) } +pub(crate) fn parse_patch_memory_hotplug(body: &Body) -> Result<ParsedRequest, RequestError> { + METRICS.patch_api_requests.hotplug_memory_count.inc(); + let config = + serde_json::from_slice::<MemoryHotplugSizeUpdate>(body.raw()).inspect_err(|_| { + METRICS.patch_api_requests.hotplug_memory_fails.inc(); + })?; + Ok(ParsedRequest::new_sync(VmmAction::UpdateMemoryHotplugSize( + config, + ))) +} + #[cfg(test)] mod tests { use vmm::devices::virtio::mem::{ VIRTIO_MEM_DEFAULT_BLOCK_SIZE_MIB, VIRTIO_MEM_DEFAULT_SLOT_SIZE_MIB, }; + use vmm::vmm_config::memory_hotplug::MemoryHotplugSizeUpdate; use super::*; use crate::api_server::parsed_request::tests::vmm_action_from_request; @@ -80,4 +92,27 @@ mod tests { VmmAction::GetMemoryHotplugStatus ); } + + #[test] + fn test_parse_patch_memory_hotplug_request() { + parse_patch_memory_hotplug(&Body::new("invalid_payload")).unwrap_err(); + + // PATCH with invalid fields. 
+ let body = r#"{ + "requested_size_mib": "bar" + }"#; + parse_patch_memory_hotplug(&Body::new(body)).unwrap_err(); + + // PATCH with valid input fields. + let body = r#"{ + "requested_size_mib": 2048 + }"#; + let expected_config = MemoryHotplugSizeUpdate { + requested_size_mib: 2048, + }; + assert_eq!( + vmm_action_from_request(parse_patch_memory_hotplug(&Body::new(body)).unwrap()), + VmmAction::UpdateMemoryHotplugSize(expected_config) + ); + } } diff --git a/src/firecracker/swagger/firecracker.yaml b/src/firecracker/swagger/firecracker.yaml index c5011a79fd7..bddc4942c2a 100644 --- a/src/firecracker/swagger/firecracker.yaml +++ b/src/firecracker/swagger/firecracker.yaml @@ -549,6 +549,26 @@ paths: description: Internal server error schema: $ref: "#/definitions/Error" + patch: + summary: Updates the size of the hotpluggable memory region + operationId: patchMemoryHotplug + description: + Updates the size of the hotpluggable memory region. The guest will plug and unplug memory to + hit the requested memory. + parameters: + - name: body + in: body + description: Hotpluggable memory size update + required: true + schema: + $ref: "#/definitions/MemoryHotplugSizeUpdate" + responses: + 204: + description: Hotpluggable memory configured + default: + description: Internal server error + schema: + $ref: "#/definitions/Error" get: summary: Retrieves the status of the hotpluggable memory operationId: getMemoryHotplug @@ -1422,6 +1442,15 @@ definitions: description: (Logical) Block size for the hotpluggable memory in MiB. This will determine the logical granularity of hot-plug memory for the guest. Refer to the device documentation on how to tune this value. + MemoryHotplugSizeUpdate: + type: object + description: + An update to the size of the hotpluggable memory region. + properties: + requested_size_mib: + type: integer + description: New target region size. 
+ MemoryHotplugStatus: type: object description: diff --git a/src/vmm/Cargo.toml b/src/vmm/Cargo.toml index d2601f4a305..aa6f9a2fc47 100644 --- a/src/vmm/Cargo.toml +++ b/src/vmm/Cargo.toml @@ -23,6 +23,7 @@ aws-lc-rs = { version = "1.14.0", features = ["bindgen"] } base64 = "0.22.1" bincode = { version = "2.0.1", features = ["serde"] } bitflags = "2.9.4" +bitvec = { version = "1.0.1", features = ["atomic", "serde"] } byteorder = "1.5.0" crc64 = "2.0.0" derive_more = { version = "2.0.1", default-features = false, features = [ diff --git a/src/vmm/src/builder.rs b/src/vmm/src/builder.rs index cb901a78c63..5556c62e44f 100644 --- a/src/vmm/src/builder.rs +++ b/src/vmm/src/builder.rs @@ -32,7 +32,7 @@ use crate::device_manager::{ use crate::devices::acpi::vmgenid::VmGenIdError; use crate::devices::virtio::balloon::Balloon; use crate::devices::virtio::block::device::Block; -use crate::devices::virtio::mem::VirtioMem; +use crate::devices::virtio::mem::{VIRTIO_MEM_GUEST_ADDRESS, VirtioMem}; use crate::devices::virtio::net::Net; use crate::devices::virtio::rng::Entropy; use crate::devices::virtio::vsock::{Vsock, VsockUnixBackend}; @@ -44,6 +44,7 @@ use crate::persist::{MicrovmState, MicrovmStateError}; use crate::resources::VmResources; use crate::seccomp::BpfThreadMap; use crate::snapshot::Persist; +use crate::utils::mib_to_bytes; use crate::vmm_config::instance_info::InstanceInfo; use crate::vmm_config::machine_config::MachineConfigError; use crate::vmm_config::memory_hotplug::MemoryHotplugConfig; @@ -172,6 +173,18 @@ pub fn build_microvm_for_boot( let (mut vcpus, vcpus_exit_evt) = vm.create_vcpus(vm_resources.machine_config.vcpu_count)?; vm.register_dram_memory_regions(guest_memory)?; + // Allocate memory as soon as possible to make hotpluggable memory available to all consumers, + // before they clone the GuestMemoryMmap object + if let Some(memory_hotplug) = &vm_resources.memory_hotplug { + let hotplug_memory_region = vm_resources + .allocate_memory_region( + VIRTIO_MEM_GUEST_ADDRESS, + mib_to_bytes(memory_hotplug.total_size_mib), + ) + .map_err(StartMicrovmError::GuestMemory)?; + vm.register_hotpluggable_memory_region(hotplug_memory_region)?; + } + let mut device_manager = DeviceManager::new( event_manager, &vcpus_exit_evt, diff --git a/src/vmm/src/devices/virtio/mem/device.rs b/src/vmm/src/devices/virtio/mem/device.rs index c8bcb6cbf53..8c53df4751f 100644 --- a/src/vmm/src/devices/virtio/mem/device.rs +++ b/src/vmm/src/devices/virtio/mem/device.rs @@ -2,35 +2,38 @@ // SPDX-License-Identifier: Apache-2.0 use std::io; -use std::ops::Deref; +use std::ops::{Deref, Range}; use std::sync::Arc; use std::sync::atomic::AtomicU32; +use bitvec::vec::BitVec; use log::info; use serde::{Deserialize, Serialize}; use vm_memory::{ - Address, GuestAddress, GuestMemory, GuestMemoryError, GuestMemoryRegion, GuestUsize, + Address, Bytes, GuestAddress, GuestMemory, GuestMemoryError, GuestMemoryRegion, GuestUsize, }; use vmm_sys_util::eventfd::EventFd; use super::{MEM_NUM_QUEUES, MEM_QUEUE}; -use crate::devices::DeviceError; use crate::devices::virtio::ActivateError; use crate::devices::virtio::device::{ActiveState, DeviceState, VirtioDevice}; use crate::devices::virtio::generated::virtio_config::VIRTIO_F_VERSION_1; use crate::devices::virtio::generated::virtio_ids::VIRTIO_ID_MEM; use crate::devices::virtio::generated::virtio_mem::{ - VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE, virtio_mem_config, + self, VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE, virtio_mem_config, }; use crate::devices::virtio::iov_deque::IovDequeError; use 
crate::devices::virtio::mem::metrics::METRICS; +use crate::devices::virtio::mem::request::{BlockRangeState, Request, RequestedRange, Response}; use crate::devices::virtio::mem::{VIRTIO_MEM_DEV_ID, VIRTIO_MEM_GUEST_ADDRESS}; -use crate::devices::virtio::queue::{FIRECRACKER_MAX_QUEUE_SIZE, InvalidAvailIdx, Queue}; +use crate::devices::virtio::queue::{ + DescriptorChain, FIRECRACKER_MAX_QUEUE_SIZE, InvalidAvailIdx, Queue, QueueError, +}; use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType}; use crate::logger::{IncMetric, debug, error}; use crate::utils::{bytes_to_mib, mib_to_bytes, u64_to_usize, usize_to_u64}; use crate::vstate::interrupts::InterruptError; -use crate::vstate::memory::{ByteValued, GuestMemoryMmap, GuestRegionMmap}; +use crate::vstate::memory::{ByteValued, GuestMemoryExtension, GuestMemoryMmap, GuestRegionMmap}; use crate::vstate::vm::VmError; use crate::{Vm, impl_device_type}; @@ -43,6 +46,36 @@ pub enum VirtioMemError { EventFd(#[from] io::Error), /// Received error while sending an interrupt: {0} InterruptError(#[from] InterruptError), + /// Size {0} is invalid: it must be a multiple of the block size and no greater than the total size + InvalidSize(u64), + /// Device is not active + DeviceNotActive, + /// Descriptor is write-only + UnexpectedWriteOnlyDescriptor, + /// Error writing virtio descriptor + DescriptorWriteFailed, + /// Error reading virtio descriptor + DescriptorReadFailed, + /// Unknown request type: {0:?} + UnknownRequestType(u32), + /// Descriptor chain is too short + DescriptorChainTooShort, + /// Descriptor is too small + DescriptorLengthTooSmall, + /// Descriptor is read-only + UnexpectedReadOnlyDescriptor, + /// Error popping from virtio queue: {0} + InvalidAvailIdx(#[from] InvalidAvailIdx), + /// Error adding used queue: {0} + QueueError(#[from] QueueError), + /// Invalid requested range: {0:?}. + InvalidRange(RequestedRange), + /// The requested range cannot be plugged because it's {0:?}. + PlugRequestBlockStateInvalid(BlockRangeState), + /// Plug request rejected as plugged_size would be greater than requested_size + PlugRequestIsTooBig, + /// The requested range cannot be unplugged because it's {0:?}. + UnplugRequestBlockStateInvalid(BlockRangeState), } #[derive(Debug)] @@ -60,6 +93,8 @@ pub struct VirtioMem { // Device specific fields pub(crate) config: virtio_mem_config, pub(crate) slot_size: usize, + // Bitmap to track which blocks are plugged + pub(crate) plugged_blocks: BitVec, vm: Arc<Vm>, } @@ -93,8 +128,15 @@ impl VirtioMem { block_size: mib_to_bytes(block_size_mib) as u64, ..Default::default() }; + let plugged_blocks = BitVec::repeat(false, total_size_mib / block_size_mib); - Self::from_state(vm, queues, config, mib_to_bytes(slot_size_mib)) + Self::from_state( + vm, + queues, + config, + mib_to_bytes(slot_size_mib), + plugged_blocks, + ) } pub fn from_state( @@ -102,6 +144,7 @@ impl VirtioMem { queues: Vec<Queue>, config: virtio_mem_config, slot_size: usize, + plugged_blocks: BitVec, ) -> Result<Self, VirtioMemError> { let activate_event = EventFd::new(libc::EFD_NONBLOCK)?; let queue_events = (0..MEM_NUM_QUEUES) @@ -118,6 +161,7 @@ impl VirtioMem { config, vm, slot_size, + plugged_blocks, }) } @@ -125,6 +169,10 @@ impl VirtioMem { VIRTIO_MEM_DEV_ID } + pub fn guest_address(&self) -> GuestAddress { + GuestAddress(self.config.addr) + } + /// Gets the total hotpluggable size. 
pub fn total_size_mib(&self) -> usize { bytes_to_mib(u64_to_usize(self.config.region_size)) @@ -166,8 +214,251 @@ impl VirtioMem { .map_err(VirtioMemError::InterruptError) } + fn guest_memory(&self) -> &GuestMemoryMmap { + &self.device_state.active_state().unwrap().mem + } + + fn nb_blocks_to_len(&self, nb_blocks: usize) -> usize { + nb_blocks * u64_to_usize(self.config.block_size) + } + + fn is_range_plugged(&self, range: &RequestedRange) -> BlockRangeState { + let plugged_count = self.plugged_blocks[self.unchecked_block_range(range)].count_ones(); + + match plugged_count { + nb_blocks if nb_blocks == range.nb_blocks => BlockRangeState::Plugged, + 0 => BlockRangeState::Unplugged, + _ => BlockRangeState::Mixed, + } + } + + fn parse_request( + &self, + avail_desc: &DescriptorChain, + ) -> Result<(Request, GuestAddress, u16), VirtioMemError> { + // The head contains the request type which MUST be readable. + if avail_desc.is_write_only() { + return Err(VirtioMemError::UnexpectedWriteOnlyDescriptor); + } + + if (avail_desc.len as usize) < size_of::<virtio_mem::virtio_mem_req>() { + return Err(VirtioMemError::DescriptorLengthTooSmall); + } + + let request: virtio_mem::virtio_mem_req = self + .guest_memory() + .read_obj(avail_desc.addr) + .map_err(|_| VirtioMemError::DescriptorReadFailed)?; + + let resp_desc = avail_desc + .next_descriptor() + .ok_or(VirtioMemError::DescriptorChainTooShort)?; + + // The response MUST always be writable. + if !resp_desc.is_write_only() { + return Err(VirtioMemError::UnexpectedReadOnlyDescriptor); + } + + if (resp_desc.len as usize) < std::mem::size_of::<virtio_mem::virtio_mem_resp>() { + return Err(VirtioMemError::DescriptorLengthTooSmall); + } + + Ok((request.into(), resp_desc.addr, avail_desc.index)) + } + + fn write_response( + &mut self, + resp: Response, + resp_addr: GuestAddress, + used_idx: u16, + ) -> Result<(), VirtioMemError> { + debug!("virtio-mem: Response: {:?}", resp); + self.guest_memory() + .write_obj(virtio_mem::virtio_mem_resp::from(resp), resp_addr) + .map_err(|_| VirtioMemError::DescriptorWriteFailed)?; + self.queues[MEM_QUEUE] + .add_used( + used_idx, + u32::try_from(std::mem::size_of::<virtio_mem::virtio_mem_resp>()).unwrap(), + ) + .map_err(VirtioMemError::QueueError) + } + + fn validate_range(&self, range: &RequestedRange) -> Result<(), VirtioMemError> { + // Ensure the range is aligned + if !range + .addr + .raw_value() + .is_multiple_of(self.config.block_size) + { + return Err(VirtioMemError::InvalidRange(*range)); + } + + if range.nb_blocks == 0 { + return Err(VirtioMemError::InvalidRange(*range)); + } + + // Ensure the start addr is within the usable region + let start_off = range + .addr + .checked_offset_from(GuestAddress(self.config.addr)) + .filter(|&off| off < self.config.usable_region_size) + .ok_or(VirtioMemError::InvalidRange(*range))?; + + // Ensure the end offset (exclusive) is within the usable region + start_off + .checked_add(usize_to_u64(self.nb_blocks_to_len(range.nb_blocks))) + .filter(|&end_off| end_off <= self.config.usable_region_size) + .ok_or(VirtioMemError::InvalidRange(*range))?; + + Ok(()) + } + + fn unchecked_block_range(&self, range: &RequestedRange) -> Range<usize> { + let start_block = u64_to_usize((range.addr.0 - self.config.addr) / self.config.block_size); + + start_block..(start_block + range.nb_blocks) + } + + fn do_plug_request(&mut self, range: &RequestedRange) -> Result<(), VirtioMemError> { + self.validate_range(range)?; + + if self.config.plugged_size + usize_to_u64(self.nb_blocks_to_len(range.nb_blocks)) + > self.config.requested_size + { + return 
Err(VirtioMemError::PlugRequestIsTooBig); + } + + match self.is_range_plugged(range) { + BlockRangeState::Unplugged => self.plug_range(range, true), + state => Err(VirtioMemError::PlugRequestBlockStateInvalid(state)), + } + } + + fn handle_plug_request( + &mut self, + range: &RequestedRange, + resp_addr: GuestAddress, + used_idx: u16, + ) -> Result<(), VirtioMemError> { + METRICS.plug_count.inc(); + let _metric = METRICS.plug_agg.record_latency_metrics(); + + let response = self.do_plug_request(range).map_or_else( + |err| { + METRICS.plug_fails.inc(); + error!("virtio-mem: Failed to plug range: {}", err); + Response::error() + }, + |_| { + METRICS + .plug_bytes + .add(usize_to_u64(self.nb_blocks_to_len(range.nb_blocks))); + Response::ack() + }, + ); + self.write_response(response, resp_addr, used_idx) + } + + fn do_unplug_request(&mut self, range: &RequestedRange) -> Result<(), VirtioMemError> { + self.validate_range(range)?; + + match self.is_range_plugged(range) { + BlockRangeState::Plugged => self.plug_range(range, false), + state => Err(VirtioMemError::UnplugRequestBlockStateInvalid(state)), + } + } + + fn handle_unplug_request( + &mut self, + range: &RequestedRange, + resp_addr: GuestAddress, + used_idx: u16, + ) -> Result<(), VirtioMemError> { + METRICS.unplug_count.inc(); + let _metric = METRICS.unplug_agg.record_latency_metrics(); + let response = self.do_unplug_request(range).map_or_else( + |err| { + METRICS.unplug_fails.inc(); + error!("virtio-mem: Failed to unplug range: {}", err); + Response::error() + }, + |_| { + METRICS + .unplug_bytes + .add(usize_to_u64(self.nb_blocks_to_len(range.nb_blocks))); + Response::ack() + }, + ); + self.write_response(response, resp_addr, used_idx) + } + + fn handle_unplug_all_request( + &mut self, + resp_addr: GuestAddress, + used_idx: u16, + ) -> Result<(), VirtioMemError> { + METRICS.unplug_all_count.inc(); + let _metric = METRICS.unplug_all_agg.record_latency_metrics(); + let range = RequestedRange { + addr: GuestAddress(self.config.addr), + nb_blocks: self.plugged_blocks.len(), + }; + let response = self.plug_range(&range, false).map_or_else( + |err| { + METRICS.unplug_all_fails.inc(); + error!("virtio-mem: Failed to unplug all: {}", err); + Response::error() + }, + |_| { + self.config.usable_region_size = 0; + Response::ack() + }, + ); + self.write_response(response, resp_addr, used_idx) + } + + fn handle_state_request( + &mut self, + range: &RequestedRange, + resp_addr: GuestAddress, + used_idx: u16, + ) -> Result<(), VirtioMemError> { + METRICS.state_count.inc(); + let _metric = METRICS.state_agg.record_latency_metrics(); + let response = self.validate_range(range).map_or_else( + |err| { + METRICS.state_fails.inc(); + error!("virtio-mem: Failed to retrieve state of range: {}", err); + Response::error() + }, + |_| Response::ack_with_state(self.is_range_plugged(range)), + ); + self.write_response(response, resp_addr, used_idx) + } + fn process_mem_queue(&mut self) -> Result<(), VirtioMemError> { - info!("TODO: Received mem queue event, but it's not implemented."); + while let Some(desc) = self.queues[MEM_QUEUE].pop()? 
{ + let (req, resp_addr, used_idx) = self.parse_request(&desc)?; + debug!("virtio-mem: Request: {:?}", req); + // Handle request and write response + match req { + Request::State(ref range) => self.handle_state_request(range, resp_addr, used_idx), + Request::Plug(ref range) => self.handle_plug_request(range, resp_addr, used_idx), + Request::Unplug(ref range) => { + self.handle_unplug_request(range, resp_addr, used_idx) + } + Request::UnplugAll => self.handle_unplug_all_request(resp_addr, used_idx), + Request::Unsupported(t) => Err(VirtioMemError::UnknownRequestType(t)), + }?; + } + + self.queues[MEM_QUEUE].advance_used_ring_idx(); + self.signal_used_queue()?; + Ok(()) } @@ -200,6 +491,70 @@ impl VirtioMem { pub(crate) fn activate_event(&self) -> &EventFd { &self.activate_event } + + fn plug_range(&mut self, range: &RequestedRange, plug: bool) -> Result<(), VirtioMemError> { + // Update internal state + let block_range = self.unchecked_block_range(range); + let plugged_blocks_slice = &mut self.plugged_blocks[block_range]; + let plugged_before = plugged_blocks_slice.count_ones(); + plugged_blocks_slice.fill(plug); + let plugged_after = plugged_blocks_slice.count_ones(); + self.config.plugged_size -= usize_to_u64(self.nb_blocks_to_len(plugged_before)); + self.config.plugged_size += usize_to_u64(self.nb_blocks_to_len(plugged_after)); + + // If unplugging, discard the range + if !plug { + if let Err(err) = self + .guest_memory() + .discard_range(range.addr, self.nb_blocks_to_len(range.nb_blocks)) + { + // Failure to discard is not fatal and is not reported to the driver. It only + // gets logged. + METRICS.unplug_discard_fails.inc(); + error!("virtio-mem: Failed to discard memory range: {}", err); + } + } + + // TODO: update KVM slots to plug/unplug them + + Ok(()) + } + + /// Updates the requested size of the virtio-mem device. 
+ pub fn update_requested_size( + &mut self, + requested_size_mib: usize, + ) -> Result<(), VirtioMemError> { + let requested_size = usize_to_u64(mib_to_bytes(requested_size_mib)); + if !self.is_activated() { + return Err(VirtioMemError::DeviceNotActive); + } + + if requested_size % self.config.block_size != 0 { + return Err(VirtioMemError::InvalidSize(requested_size)); + } + if requested_size > self.config.region_size { + return Err(VirtioMemError::InvalidSize(requested_size)); + } + + // usable_region_size can only be increased + if self.config.usable_region_size < requested_size { + self.config.usable_region_size = + requested_size.next_multiple_of(usize_to_u64(self.slot_size)); + debug!( + "virtio-mem: Updated usable size to {} bytes", + self.config.usable_region_size + ); + } + + self.config.requested_size = requested_size; + debug!( + "virtio-mem: Updated requested size to {} bytes", + requested_size + ); + self.interrupt_trigger() + .trigger(VirtioInterruptType::Config) + .map_err(VirtioMemError::InterruptError) + } } impl VirtioDevice for VirtioMem { @@ -293,10 +648,34 @@ impl VirtioDevice for VirtioMem { #[cfg(test)] pub(crate) mod test_utils { use super::*; + use crate::devices::virtio::test_utils::test::VirtioTestDevice; + use crate::vmm_config::machine_config::HugePageConfig; + use crate::vstate::memory; use crate::vstate::vm::tests::setup_vm_with_memory; + impl VirtioTestDevice for VirtioMem { + fn set_queues(&mut self, queues: Vec<Queue>) { + self.queues = queues; + } + + fn num_queues() -> usize { + MEM_NUM_QUEUES + } + } + pub(crate) fn default_virtio_mem() -> VirtioMem { - let (_, vm) = setup_vm_with_memory(0x1000); + let (_, mut vm) = setup_vm_with_memory(0x1000); + vm.register_hotpluggable_memory_region( + memory::anonymous( + std::iter::once((VIRTIO_MEM_GUEST_ADDRESS, mib_to_bytes(1024))), + false, + HugePageConfig::None, + ) + .unwrap() + .pop() + .unwrap(), + ) + .unwrap(); let vm = Arc::new(vm); VirtioMem::new(vm, 1024, 2, 128).unwrap() } } @@ -306,11 +685,15 @@ mod tests { use std::ptr::null_mut; + use vm_memory::mmap::MmapRegionBuilder; use super::*; use crate::devices::virtio::device::VirtioDevice; use crate::devices::virtio::mem::device::test_utils::default_virtio_mem; + use crate::devices::virtio::queue::VIRTQ_DESC_F_WRITE; + use crate::devices::virtio::test_utils::test::VirtioTestHelper; use crate::vstate::vm::tests::setup_vm_with_memory; #[test] @@ -351,7 +734,18 @@ mod tests { usable_region_size, ..Default::default() }; - let mem = VirtioMem::from_state(vm, queues, config, mib_to_bytes(slot_size_mib)).unwrap(); + let plugged_blocks = BitVec::repeat( + false, + mib_to_bytes(region_size_mib) / mib_to_bytes(block_size_mib), + ); + let mem = VirtioMem::from_state( + vm, + queues, + config, + mib_to_bytes(slot_size_mib), + plugged_blocks, + ) + .unwrap(); assert_eq!(mem.total_size_mib(), region_size_mib); assert_eq!(mem.block_size_mib(), block_size_mib); assert_eq!(mem.slot_size_mib(), slot_size_mib); @@ -434,4 +828,478 @@ mod tests { } ); } + + #[allow(clippy::cast_possible_truncation)] + const REQ_SIZE: u32 = std::mem::size_of::<virtio_mem::virtio_mem_req>() as u32; + #[allow(clippy::cast_possible_truncation)] + const RESP_SIZE: u32 = std::mem::size_of::<virtio_mem::virtio_mem_resp>() as u32; + + fn test_helper<'a>( + mut dev: VirtioMem, + mem: &'a GuestMemoryMmap, + ) -> VirtioTestHelper<'a, VirtioMem> { + dev.set_acked_features(dev.avail_features); + + let mut th = VirtioTestHelper::<VirtioMem>::new(mem, dev); + 
th.activate_device(mem); + th + } + + fn emulate_request( + th: &mut VirtioTestHelper<VirtioMem>, + mem: &GuestMemoryMmap, + req: Request, + ) -> Response { + th.add_desc_chain( + MEM_QUEUE, + 0, + &[(0, REQ_SIZE, 0), (1, RESP_SIZE, VIRTQ_DESC_F_WRITE)], + ); + mem.write_obj( + virtio_mem::virtio_mem_req::from(req), + th.desc_address(MEM_QUEUE, 0), + ) + .unwrap(); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + mem.read_obj::<virtio_mem::virtio_mem_resp>(th.desc_address(MEM_QUEUE, 1)) + .unwrap() + .into() + } + + #[test] + fn test_event_fail_descriptor_chain_too_short() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain(MEM_QUEUE, 0, &[(0, REQ_SIZE, 0)]); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn test_event_fail_descriptor_length_too_small() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain(MEM_QUEUE, 0, &[(0, 1, 0)]); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn test_event_fail_unexpected_writeonly_descriptor() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain(MEM_QUEUE, 0, &[(0, REQ_SIZE, VIRTQ_DESC_F_WRITE)]); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn test_event_fail_unexpected_readonly_descriptor() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain(MEM_QUEUE, 0, &[(0, REQ_SIZE, 0), (1, RESP_SIZE, 0)]); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn test_event_fail_response_descriptor_length_too_small() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain( + MEM_QUEUE, + 0, + &[(0, REQ_SIZE, 0), (1, 1, VIRTQ_DESC_F_WRITE)], + ); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn 
test_update_requested_size_device_not_active() { + let mut mem_dev = default_virtio_mem(); + let result = mem_dev.update_requested_size(512); + assert!(matches!(result, Err(VirtioMemError::DeviceNotActive))); + } + + #[test] + fn test_update_requested_size_invalid_size() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + // Size not multiple of block size + let result = th.device().update_requested_size(3); + assert!(matches!(result, Err(VirtioMemError::InvalidSize(_)))); + + // Size too large + let result = th.device().update_requested_size(2048); + assert!(matches!(result, Err(VirtioMemError::InvalidSize(_)))); + } + + #[test] + fn test_update_requested_size_success() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + th.device().update_requested_size(512).unwrap(); + assert_eq!(th.device().requested_size_mib(), 512); + } + + #[test] + fn test_plug_request_success() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024).unwrap(); + let addr = th.device().guest_address(); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + let plug_count = METRICS.plug_count.count(); + let plug_bytes = METRICS.plug_bytes.count(); + let plug_fails = METRICS.plug_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 2); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails); + assert_eq!(METRICS.plug_count.count(), plug_count + 1); + assert_eq!(METRICS.plug_bytes.count(), plug_bytes + (2 << 20)); + assert_eq!(METRICS.plug_fails.count(), plug_fails); + } + + #[test] + fn test_plug_request_too_big() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(2).unwrap(); + let addr = th.device().guest_address(); + + let plug_count = METRICS.plug_count.count(); + let plug_bytes = METRICS.plug_bytes.count(); + let plug_fails = METRICS.plug_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 2 }), + ); + assert!(resp.is_error()); + + assert_eq!(METRICS.plug_count.count(), plug_count + 1); + assert_eq!(METRICS.plug_bytes.count(), plug_bytes); + assert_eq!(METRICS.plug_fails.count(), plug_fails + 1); + } + + #[test] + fn test_plug_request_already_plugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024).unwrap(); + let addr = th.device().guest_address(); + + // First plug succeeds + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + + // Second plug fails + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_error()); + } + + #[test] + fn test_unplug_request_success() { + let mut mem_dev = default_virtio_mem(); + let 
guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024).unwrap(); + let addr = th.device().guest_address(); + + let unplug_count = METRICS.unplug_count.count(); + let unplug_bytes = METRICS.unplug_bytes.count(); + let unplug_fails = METRICS.unplug_fails.count(); + + // First plug + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 2); + + // Then unplug + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Unplug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 0); + + assert_eq!(METRICS.unplug_count.count(), unplug_count + 1); + assert_eq!(METRICS.unplug_bytes.count(), unplug_bytes + (2 << 20)); + assert_eq!(METRICS.unplug_fails.count(), unplug_fails); + } + + #[test] + fn test_unplug_request_not_plugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024).unwrap(); + let addr = th.device().guest_address(); + + let unplug_count = METRICS.unplug_count.count(); + let unplug_bytes = METRICS.unplug_bytes.count(); + let unplug_fails = METRICS.unplug_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Unplug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_error()); + + assert_eq!(METRICS.unplug_count.count(), unplug_count + 1); + assert_eq!(METRICS.unplug_bytes.count(), unplug_bytes); + assert_eq!(METRICS.unplug_fails.count(), unplug_fails + 1); + } + + #[test] + fn test_unplug_all_request() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024).unwrap(); + let addr = th.device().guest_address(); + + let unplug_all_count = METRICS.unplug_all_count.count(); + let unplug_all_fails = METRICS.unplug_all_fails.count(); + + // Plug some blocks + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 2 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 4); + + // Unplug all + let resp = emulate_request(&mut th, &guest_mem, Request::UnplugAll); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 0); + + assert_eq!(METRICS.unplug_all_count.count(), unplug_all_count + 1); + assert_eq!(METRICS.unplug_all_fails.count(), unplug_all_fails); + } + + #[test] + fn test_state_request_unplugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024).unwrap(); + let addr = th.device().guest_address(); + + let state_count = METRICS.state_count.count(); + let state_fails = METRICS.state_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 1 }), + ); + assert_eq!(resp, Response::ack_with_state(BlockRangeState::Unplugged)); + + assert_eq!(METRICS.state_count.count(), state_count + 1); + assert_eq!(METRICS.state_fails.count(), state_fails); + } + + #[test] + fn test_state_request_plugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + 
th.device().update_requested_size(1024).unwrap(); + let addr = th.device().guest_address(); + + // Plug first + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + + // Check state + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 1 }), + ); + assert_eq!(resp, Response::ack_with_state(BlockRangeState::Plugged)); + } + + #[test] + fn test_state_request_mixed() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024).unwrap(); + let addr = th.device().guest_address(); + + // Plug first block only + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + + // Check state of 2 blocks (one plugged, one unplugged) + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 2 }), + ); + assert_eq!(resp, Response::ack_with_state(BlockRangeState::Mixed)); + } + + #[test] + fn test_invalid_range_unaligned() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024).unwrap(); + let addr = th.device().guest_address().unchecked_add(1); + + let state_count = METRICS.state_count.count(); + let state_fails = METRICS.state_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_error()); + + assert_eq!(METRICS.state_count.count(), state_count + 1); + assert_eq!(METRICS.state_fails.count(), state_fails + 1); + } + + #[test] + fn test_invalid_range_zero_blocks() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024).unwrap(); + let addr = th.device().guest_address(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 0 }), + ); + assert!(resp.is_error()); + } + + #[test] + fn test_invalid_range_out_of_bounds() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(4).unwrap(); + let addr = th.device().guest_address(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { + addr, + nb_blocks: 1024, + }), + ); + assert!(resp.is_error()); + } + + #[test] + fn test_unsupported_request() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain( + MEM_QUEUE, + 0, + &[(0, REQ_SIZE, 0), (1, RESP_SIZE, VIRTQ_DESC_F_WRITE)], + ); + guest_mem + .write_obj( + virtio_mem::virtio_mem_req::from(Request::Unsupported(999)), + th.desc_address(MEM_QUEUE, 0), + ) + .unwrap(); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } }
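The device tests above exercise the plug, unplug, unplug-all, and state requests end to end. The block accounting they verify reduces to bit-counting over the `plugged_blocks` bitmap; a standalone sketch of that accounting, using the `bitvec` crate added to `Cargo.toml` (names and sizes here are illustrative, not taken from the device code):

use bitvec::vec::BitVec;

// Illustrative only: mirrors how `plug_range` fills a sub-slice of the
// `plugged_blocks` bitmap and derives the change in `plugged_size` from the
// set-bit counts before and after the fill.
fn fill_range(bits: &mut BitVec, start: usize, nb_blocks: usize, plug: bool) -> (usize, usize) {
    let slice = &mut bits[start..start + nb_blocks];
    let before = slice.count_ones();
    slice.fill(plug);
    (before, slice.count_ones())
}

fn main() {
    // A 1024 MiB region with 2 MiB blocks is tracked by 512 bits, all unplugged.
    let mut plugged: BitVec = BitVec::repeat(false, 512);
    let (before, after) = fill_range(&mut plugged, 0, 4, true);
    // plugged_size changes by (after - before) * block_size bytes.
    assert_eq!((before, after), (0, 4));
}

diff --git a/src/vmm/src/devices/virtio/mem/metrics.rs b/src/vmm/src/devices/virtio/mem/metrics.rs 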
index 443e9a8b8f1..d69255d44ec 100644 --- a/src/vmm/src/devices/virtio/mem/metrics.rs +++ b/src/vmm/src/devices/virtio/mem/metrics.rs @@ -45,6 +45,36 @@ pub(super) struct VirtioMemDeviceMetrics { pub queue_event_fails: SharedIncMetric, /// Number of queue events handled pub queue_event_count: SharedIncMetric, + /// Latency of Plug operations + pub plug_agg: LatencyAggregateMetrics, + /// Number of Plug operations + pub plug_count: SharedIncMetric, + /// Number of plugged bytes + pub plug_bytes: SharedIncMetric, + /// Number of Plug operations failed + pub plug_fails: SharedIncMetric, + /// Latency of Unplug operations + pub unplug_agg: LatencyAggregateMetrics, + /// Number of Unplug operations + pub unplug_count: SharedIncMetric, + /// Number of unplugged bytes + pub unplug_bytes: SharedIncMetric, + /// Number of Unplug operations failed + pub unplug_fails: SharedIncMetric, + /// Number of discards failed for an Unplug or UnplugAll operation + pub unplug_discard_fails: SharedIncMetric, + /// Latency of UnplugAll operations + pub unplug_all_agg: LatencyAggregateMetrics, + /// Number of UnplugAll operations + pub unplug_all_count: SharedIncMetric, + /// Number of UnplugAll operations failed + pub unplug_all_fails: SharedIncMetric, + /// Latency of State operations + pub state_agg: LatencyAggregateMetrics, + /// Number of State operations + pub state_count: SharedIncMetric, + /// Number of State operations failed + pub state_fails: SharedIncMetric, } impl VirtioMemDeviceMetrics { @@ -54,6 +84,21 @@ impl VirtioMemDeviceMetrics { activate_fails: SharedIncMetric::new(), queue_event_fails: SharedIncMetric::new(), queue_event_count: SharedIncMetric::new(), + plug_agg: LatencyAggregateMetrics::new(), + plug_count: SharedIncMetric::new(), + plug_bytes: SharedIncMetric::new(), + plug_fails: SharedIncMetric::new(), + unplug_agg: LatencyAggregateMetrics::new(), + unplug_count: SharedIncMetric::new(), + unplug_bytes: SharedIncMetric::new(), + unplug_fails: SharedIncMetric::new(), + unplug_discard_fails: SharedIncMetric::new(), + unplug_all_agg: LatencyAggregateMetrics::new(), + unplug_all_count: SharedIncMetric::new(), + unplug_all_fails: SharedIncMetric::new(), + state_agg: LatencyAggregateMetrics::new(), + state_count: SharedIncMetric::new(), + state_fails: SharedIncMetric::new(), } } } @@ -66,13 +111,8 @@ pub mod tests { #[test] fn test_memory_hotplug_metrics() { let mem_metrics: VirtioMemDeviceMetrics = VirtioMemDeviceMetrics::new(); - let mem_metrics_local: String = serde_json::to_string(&mem_metrics).unwrap(); - // the 1st serialize flushes the metrics and resets values to 0 so that - // we can compare the values with local metrics. 
- serde_json::to_string(&METRICS).unwrap(); - let mem_metrics_global: String = serde_json::to_string(&METRICS).unwrap(); - assert_eq!(mem_metrics_local, mem_metrics_global); mem_metrics.queue_event_count.inc(); assert_eq!(mem_metrics.queue_event_count.count(), 1); + let _ = serde_json::to_string(&mem_metrics).unwrap(); } } diff --git a/src/vmm/src/devices/virtio/mem/mod.rs b/src/vmm/src/devices/virtio/mem/mod.rs index 1c9e98f98a6..5c76afc4c24 100644 --- a/src/vmm/src/devices/virtio/mem/mod.rs +++ b/src/vmm/src/devices/virtio/mem/mod.rs @@ -5,6 +5,7 @@ mod device; mod event_handler; pub mod metrics; pub mod persist; +mod request; use vm_memory::GuestAddress; diff --git a/src/vmm/src/devices/virtio/mem/persist.rs b/src/vmm/src/devices/virtio/mem/persist.rs index e48246de500..09f41680e32 100644 --- a/src/vmm/src/devices/virtio/mem/persist.rs +++ b/src/vmm/src/devices/virtio/mem/persist.rs @@ -5,6 +5,7 @@ use std::sync::Arc; +use bitvec::vec::BitVec; use serde::{Deserialize, Serialize}; use vm_memory::Address; @@ -17,6 +18,7 @@ use crate::devices::virtio::mem::{ use crate::devices::virtio::persist::{PersistError as VirtioStateError, VirtioDeviceState}; use crate::devices::virtio::queue::FIRECRACKER_MAX_QUEUE_SIZE; use crate::snapshot::Persist; +use crate::utils::usize_to_u64; use crate::vstate::memory::{GuestMemoryMmap, GuestRegionMmap}; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -25,9 +27,9 @@ pub struct VirtioMemState { region_size: u64, block_size: u64, usable_region_size: u64, - plugged_size: u64, requested_size: u64, slot_size: usize, + plugged_blocks: BitVec, } #[derive(Debug)] @@ -60,7 +62,7 @@ impl Persist<'_> for VirtioMem { region_size: self.config.region_size, block_size: self.config.block_size, usable_region_size: self.config.usable_region_size, - plugged_size: self.config.plugged_size, + plugged_blocks: self.plugged_blocks.clone(), requested_size: self.config.requested_size, slot_size: self.slot_size, } @@ -82,13 +84,18 @@ impl Persist<'_> for VirtioMem { region_size: state.region_size, block_size: state.block_size, usable_region_size: state.usable_region_size, - plugged_size: state.plugged_size, + plugged_size: usize_to_u64(state.plugged_blocks.count_ones()) * state.block_size, requested_size: state.requested_size, ..Default::default() }; - let mut virtio_mem = - VirtioMem::from_state(constructor_args.vm, queues, config, state.slot_size)?; + let mut virtio_mem = VirtioMem::from_state( + constructor_args.vm, + queues, + config, + state.slot_size, + state.plugged_blocks.clone(), + )?; virtio_mem.set_avail_features(state.virtio_state.avail_features); virtio_mem.set_acked_features(state.virtio_state.acked_features); @@ -111,7 +118,7 @@ mod tests { assert_eq!(state.region_size, dev.config.region_size); assert_eq!(state.block_size, dev.config.block_size); assert_eq!(state.usable_region_size, dev.config.usable_region_size); - assert_eq!(state.plugged_size, dev.config.plugged_size); + assert_eq!(state.plugged_blocks, dev.plugged_blocks); assert_eq!(state.requested_size, dev.config.requested_size); assert_eq!(state.slot_size, dev.slot_size); } diff --git a/src/vmm/src/devices/virtio/mem/request.rs b/src/vmm/src/devices/virtio/mem/request.rs new file mode 100644 index 00000000000..a55bdb2bbf6 --- /dev/null +++ b/src/vmm/src/devices/virtio/mem/request.rs @@ -0,0 +1,230 @@ +// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use vm_memory::{Address, ByteValued, GuestAddress}; + +use crate::devices::virtio::generated::virtio_mem; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct RequestedRange { + pub(crate) addr: GuestAddress, + pub(crate) nb_blocks: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum Request { + Plug(RequestedRange), + Unplug(RequestedRange), + UnplugAll, + State(RequestedRange), + Unsupported(u32), +} + +// SAFETY: `virtio_mem_req` is a #[repr(C)] bindgen-generated struct with explicit padding +// fields, so it has no uninitialized bytes and any byte pattern is a valid value. +unsafe impl ByteValued for virtio_mem::virtio_mem_req {} + +impl From<virtio_mem::virtio_mem_req> for Request { + fn from(req: virtio_mem::virtio_mem_req) -> Self { + match req.type_.into() { + // SAFETY: union type is checked in the match + virtio_mem::VIRTIO_MEM_REQ_PLUG => unsafe { + Request::Plug(RequestedRange { + addr: GuestAddress(req.u.plug.addr), + nb_blocks: req.u.plug.nb_blocks.into(), + }) + }, + // SAFETY: union type is checked in the match + virtio_mem::VIRTIO_MEM_REQ_UNPLUG => unsafe { + Request::Unplug(RequestedRange { + addr: GuestAddress(req.u.unplug.addr), + nb_blocks: req.u.unplug.nb_blocks.into(), + }) + }, + virtio_mem::VIRTIO_MEM_REQ_UNPLUG_ALL => Request::UnplugAll, + // SAFETY: union type is checked in the match + virtio_mem::VIRTIO_MEM_REQ_STATE => unsafe { + Request::State(RequestedRange { + addr: GuestAddress(req.u.state.addr), + nb_blocks: req.u.state.nb_blocks.into(), + }) + }, + t => Request::Unsupported(t), + } + } +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum ResponseType { + Ack, + Nack, + Busy, + Error, +} + +impl From<ResponseType> for u16 { + fn from(code: ResponseType) -> Self { + match code { + ResponseType::Ack => virtio_mem::VIRTIO_MEM_RESP_ACK, + ResponseType::Nack => virtio_mem::VIRTIO_MEM_RESP_NACK, + ResponseType::Busy => virtio_mem::VIRTIO_MEM_RESP_BUSY, + ResponseType::Error => virtio_mem::VIRTIO_MEM_RESP_ERROR, + } + .try_into() + .unwrap() + } +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum BlockRangeState { + Plugged, + Unplugged, + Mixed, +} + +impl From<BlockRangeState> for virtio_mem::virtio_mem_resp_state { + fn from(code: BlockRangeState) -> Self { + virtio_mem::virtio_mem_resp_state { + state: match code { + BlockRangeState::Plugged => virtio_mem::VIRTIO_MEM_STATE_PLUGGED, + BlockRangeState::Unplugged => virtio_mem::VIRTIO_MEM_STATE_UNPLUGGED, + BlockRangeState::Mixed => virtio_mem::VIRTIO_MEM_STATE_MIXED, + } + .try_into() + .unwrap(), + } + } +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Response { + pub resp_type: ResponseType, + // Only for State requests + pub state: Option<BlockRangeState>, +} + +impl Response { + pub(crate) fn error() -> Self { + Response { + resp_type: ResponseType::Error, + state: None, + } + } + + pub(crate) fn ack() -> Self { + Response { + resp_type: ResponseType::Ack, + state: None, + } + } + + pub(crate) fn ack_with_state(state: BlockRangeState) -> Self { + Response { + resp_type: ResponseType::Ack, + state: Some(state), + } + } + + pub(crate) fn is_ack(&self) -> bool { + self.resp_type == ResponseType::Ack + } + + pub(crate) fn is_error(&self) -> bool { + self.resp_type == ResponseType::Error + } +} + +// SAFETY: Plain data structures +unsafe impl ByteValued for virtio_mem::virtio_mem_resp {} + +impl From<Response> for virtio_mem::virtio_mem_resp { + fn from(resp: Response) -> Self { + let mut out = virtio_mem::virtio_mem_resp { + type_: resp.resp_type.into(), + ..Default::default() + }; + if let Some(state) = resp.state { + out.u.state = state.into(); + } + out + } +} + +#[cfg(test)] +mod test_util { + use super::*; + + // Implement the 
reverse conversions to use in test code. + + impl From<Request> for virtio_mem::virtio_mem_req { + fn from(req: Request) -> virtio_mem::virtio_mem_req { + match req { + Request::Plug(r) => virtio_mem::virtio_mem_req { + type_: virtio_mem::VIRTIO_MEM_REQ_PLUG.try_into().unwrap(), + u: virtio_mem::virtio_mem_req__bindgen_ty_1 { + plug: virtio_mem::virtio_mem_req_plug { + addr: r.addr.raw_value(), + nb_blocks: r.nb_blocks.try_into().unwrap(), + ..Default::default() + }, + }, + ..Default::default() + }, + Request::Unplug(r) => virtio_mem::virtio_mem_req { + type_: virtio_mem::VIRTIO_MEM_REQ_UNPLUG.try_into().unwrap(), + u: virtio_mem::virtio_mem_req__bindgen_ty_1 { + unplug: virtio_mem::virtio_mem_req_unplug { + addr: r.addr.raw_value(), + nb_blocks: r.nb_blocks.try_into().unwrap(), + ..Default::default() + }, + }, + ..Default::default() + }, + Request::UnplugAll => virtio_mem::virtio_mem_req { + type_: virtio_mem::VIRTIO_MEM_REQ_UNPLUG_ALL.try_into().unwrap(), + ..Default::default() + }, + Request::State(r) => virtio_mem::virtio_mem_req { + type_: virtio_mem::VIRTIO_MEM_REQ_STATE.try_into().unwrap(), + u: virtio_mem::virtio_mem_req__bindgen_ty_1 { + state: virtio_mem::virtio_mem_req_state { + addr: r.addr.raw_value(), + nb_blocks: r.nb_blocks.try_into().unwrap(), + ..Default::default() + }, + }, + ..Default::default() + }, + Request::Unsupported(t) => virtio_mem::virtio_mem_req { + type_: t.try_into().unwrap(), + ..Default::default() + }, + } + } + } + + impl From<virtio_mem::virtio_mem_resp> for Response { + fn from(resp: virtio_mem::virtio_mem_resp) -> Self { + Response { + resp_type: match resp.type_.into() { + virtio_mem::VIRTIO_MEM_RESP_ACK => ResponseType::Ack, + virtio_mem::VIRTIO_MEM_RESP_NACK => ResponseType::Nack, + virtio_mem::VIRTIO_MEM_RESP_BUSY => ResponseType::Busy, + virtio_mem::VIRTIO_MEM_RESP_ERROR => ResponseType::Error, + t => panic!("Invalid response type: {:?}", t), + }, + // There is no way to know whether this is present or not as it depends on the + // request types. Callers should ignore this value if the request wasn't STATE + // SAFETY: test code only. Uninitialized values are 0 and recognized as PLUGGED. + state: Some(unsafe { + match resp.u.state.state.into() { + virtio_mem::VIRTIO_MEM_STATE_PLUGGED => BlockRangeState::Plugged, + virtio_mem::VIRTIO_MEM_STATE_UNPLUGGED => BlockRangeState::Unplugged, + virtio_mem::VIRTIO_MEM_STATE_MIXED => BlockRangeState::Mixed, + t => panic!("Invalid state: {:?}", t), + } + }), + } + } + } +} diff --git a/src/vmm/src/devices/virtio/test_utils.rs b/src/vmm/src/devices/virtio/test_utils.rs index 6f1489dd380..0c7978504e7 100644 --- a/src/vmm/src/devices/virtio/test_utils.rs +++ b/src/vmm/src/devices/virtio/test_utils.rs @@ -442,6 +442,11 @@ pub(crate) mod test { self.virtqueues.last().unwrap().end().raw_value() } + /// Get the address of a descriptor + pub fn desc_address(&self, queue: usize, index: usize) -> GuestAddress { + GuestAddress(self.virtqueues[queue].dtable[index].addr.get()) + } + /// Add a new Descriptor in one of the device's queues /// /// This function adds in one of the queues of the device a DescriptorChain at some offset diff --git a/src/vmm/src/lib.rs b/src/vmm/src/lib.rs index 0b6fee2e0a0..79e26c706a1 100644 --- a/src/vmm/src/lib.rs +++ b/src/vmm/src/lib.rs @@ -609,6 +609,15 @@ impl Vmm { .map_err(VmmError::FindDeviceError) } + /// Updates the requested size of the memory hotplug device. 
+ pub fn update_memory_hotplug_size(&self, requested_size_mib: usize) -> Result<(), VmmError> { + self.device_manager + .try_with_virtio_device_with_id(VIRTIO_MEM_DEV_ID, |dev: &mut VirtioMem| { + dev.update_requested_size(requested_size_mib) + }) + .map_err(VmmError::FindDeviceError) + } + /// Signals Vmm to stop and exit. pub fn stop(&mut self, exit_code: FcExitCode) { // To avoid cycles, all teardown paths take the following route: diff --git a/src/vmm/src/logger/metrics.rs b/src/vmm/src/logger/metrics.rs index c983a5a9f16..060a751562a 100644 --- a/src/vmm/src/logger/metrics.rs +++ b/src/vmm/src/logger/metrics.rs @@ -479,6 +479,10 @@ pub struct PatchRequestsMetrics { pub mmds_count: SharedIncMetric, /// Number of failures in PATCHing an mmds. pub mmds_fails: SharedIncMetric, + /// Number of PATCHes to /hotplug/memory + pub hotplug_memory_count: SharedIncMetric, + /// Number of failed PATCHes to /hotplug/memory + pub hotplug_memory_fails: SharedIncMetric, } impl PatchRequestsMetrics { /// Const default construction. @@ -492,6 +496,8 @@ impl PatchRequestsMetrics { machine_cfg_fails: SharedIncMetric::new(), mmds_count: SharedIncMetric::new(), mmds_fails: SharedIncMetric::new(), + hotplug_memory_count: SharedIncMetric::new(), + hotplug_memory_fails: SharedIncMetric::new(), } } } diff --git a/src/vmm/src/resources.rs b/src/vmm/src/resources.rs index 03d9f9b0c77..25fbc14e5db 100644 --- a/src/vmm/src/resources.rs +++ b/src/vmm/src/resources.rs @@ -536,6 +536,18 @@ impl VmResources { crate::arch::arch_memory_regions(mib_to_bytes(self.machine_config.mem_size_mib)); self.allocate_memory_regions(&regions) } + + /// Allocates a single guest memory region. + pub fn allocate_memory_region( + &self, + start: GuestAddress, + size: usize, + ) -> Result<GuestRegionMmap, MemoryError> { + Ok(self + .allocate_memory_regions(&[(start, size)])? + .pop() + .unwrap()) + } } impl From<&VmResources> for VmmConfig { diff --git a/src/vmm/src/rpc_interface.rs b/src/vmm/src/rpc_interface.rs index fe3e9c296e7..d25ddb735a8 100644 --- a/src/vmm/src/rpc_interface.rs +++ b/src/vmm/src/rpc_interface.rs @@ -29,7 +29,9 @@ use crate::vmm_config::drive::{BlockDeviceConfig, BlockDeviceUpdateConfig, Drive use crate::vmm_config::entropy::{EntropyDeviceConfig, EntropyDeviceError}; use crate::vmm_config::instance_info::InstanceInfo; use crate::vmm_config::machine_config::{MachineConfig, MachineConfigError, MachineConfigUpdate}; -use crate::vmm_config::memory_hotplug::{MemoryHotplugConfig, MemoryHotplugConfigError}; +use crate::vmm_config::memory_hotplug::{ + MemoryHotplugConfig, MemoryHotplugConfigError, MemoryHotplugSizeUpdate, +}; use crate::vmm_config::metrics::{MetricsConfig, MetricsConfigError}; use crate::vmm_config::mmds::{MmdsConfig, MmdsConfigError}; use crate::vmm_config::net::{ @@ -113,6 +115,9 @@ pub enum VmmAction { /// Set the memory hotplug device using `MemoryHotplugConfig` as input. This action can only be /// called before the microVM has booted. SetMemoryHotplugDevice(MemoryHotplugConfig), + /// Updates the memory hotplug device using `MemoryHotplugSizeUpdate` as input. This action + /// can only be called after the microVM has booted. + UpdateMemoryHotplugSize(MemoryHotplugSizeUpdate), /// Launch the microVM. This action can only be called before the microVM has booted. StartMicroVm, /// Send CTRL+ALT+DEL to the microVM, using the i8042 keyboard function. 
@@ -152,6 +157,8 @@ pub enum VmmActionError {
     EntropyDevice(#[from] EntropyDeviceError),
     /// Memory hotplug config error: {0}
     MemoryHotplugConfig(#[from] MemoryHotplugConfigError),
+    /// Memory hotplug update error: {0}
+    MemoryHotplugUpdate(VmmError),
     /// Internal VMM error: {0}
     InternalVmm(#[from] VmmError),
     /// Load snapshot error: {0}
@@ -469,6 +476,7 @@ impl<'a> PrebootApiController<'a> {
             | UpdateBalloon(_)
             | UpdateBalloonStatistics(_)
             | UpdateBlockDevice(_)
+            | UpdateMemoryHotplugSize(_)
             | UpdateNetworkInterface(_) => Err(VmmActionError::OperationNotSupportedPreBoot),
             #[cfg(target_arch = "x86_64")]
             SendCtrlAltDel => Err(VmmActionError::OperationNotSupportedPreBoot),
@@ -709,7 +717,13 @@ impl RuntimeApiController {
                 .map_err(VmmActionError::BalloonUpdate),
             UpdateBlockDevice(new_cfg) => self.update_block_device(new_cfg),
             UpdateNetworkInterface(netif_update) => self.update_net_rate_limiters(netif_update),
-
+            UpdateMemoryHotplugSize(cfg) => self
+                .vmm
+                .lock()
+                .expect("Poisoned lock")
+                .update_memory_hotplug_size(cfg.requested_size_mib)
+                .map(|_| VmmData::Empty)
+                .map_err(VmmActionError::MemoryHotplugUpdate),
             // Operations not allowed post-boot.
             ConfigureBootSource(_)
             | ConfigureLogger(_)
@@ -1181,6 +1195,11 @@ mod tests {
         )));
         #[cfg(target_arch = "x86_64")]
         check_unsupported(preboot_request(VmmAction::SendCtrlAltDel));
+        check_unsupported(preboot_request(VmmAction::UpdateMemoryHotplugSize(
+            MemoryHotplugSizeUpdate {
+                requested_size_mib: 0,
+            },
+        )));
     }
 
     fn runtime_request(request: VmmAction) -> Result<VmmData, VmmActionError> {
diff --git a/src/vmm/src/test_utils/mod.rs b/src/vmm/src/test_utils/mod.rs
index 6fe66cdbadb..887acc54d38 100644
--- a/src/vmm/src/test_utils/mod.rs
+++ b/src/vmm/src/test_utils/mod.rs
@@ -16,6 +16,7 @@ use crate::vm_memory_vendored::GuestRegionCollection;
 use crate::vmm_config::boot_source::BootSourceConfig;
 use crate::vmm_config::instance_info::InstanceInfo;
 use crate::vmm_config::machine_config::HugePageConfig;
+use crate::vmm_config::memory_hotplug::MemoryHotplugConfig;
 use crate::vstate::memory::{self, GuestMemoryMmap, GuestRegionMmap, GuestRegionMmapExt};
 use crate::{EventManager, Vmm};
 
@@ -73,6 +74,7 @@ pub fn create_vmm(
     is_diff: bool,
     boot_microvm: bool,
     pci_enabled: bool,
+    memory_hotplug_enabled: bool,
 ) -> (Arc<Mutex<Vmm>>, EventManager) {
     let mut event_manager = EventManager::new().unwrap();
     let empty_seccomp_filters = get_empty_filters();
@@ -96,6 +98,14 @@ pub fn create_vmm(
 
     resources.pci_enabled = pci_enabled;
 
+    if memory_hotplug_enabled {
+        resources.memory_hotplug = Some(MemoryHotplugConfig {
+            total_size_mib: 1024,
+            block_size_mib: 2,
+            slot_size_mib: 128,
+        });
+    }
+
     let vmm = build_microvm_for_boot(
         &InstanceInfo::default(),
         &resources,
@@ -112,23 +122,15 @@ pub fn create_vmm(
 }
 
 pub fn default_vmm(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, true, false)
+    create_vmm(kernel_image, false, true, false, false)
 }
 
 pub fn default_vmm_no_boot(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, false, false)
-}
-
-pub fn default_vmm_pci_no_boot(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, false, true)
+    create_vmm(kernel_image, false, false, false, false)
 }
 
 pub fn dirty_tracking_vmm(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, true, true, false)
-}
-
-pub fn default_vmm_pci(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, true, false)
+    create_vmm(kernel_image, true, true, false, false)
 }
 
 #[allow(clippy::undocumented_unsafe_blocks)]
diff --git a/src/vmm/src/vmm_config/memory_hotplug.rs b/src/vmm/src/vmm_config/memory_hotplug.rs
index d09141c1b66..85cf45ee5e8 100644
--- a/src/vmm/src/vmm_config/memory_hotplug.rs
+++ b/src/vmm/src/vmm_config/memory_hotplug.rs
@@ -86,6 +86,14 @@ impl MemoryHotplugConfig {
     }
 }
 
+/// Update to the requested size of the memory hotplug device.
+#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
+#[serde(deny_unknown_fields)]
+pub struct MemoryHotplugSizeUpdate {
+    /// Requested size in MiB to resize the hotpluggable memory to.
+    pub requested_size_mib: usize,
+}
+
 #[cfg(test)]
 mod tests {
     use serde_json;
diff --git a/src/vmm/src/vstate/memory.rs b/src/vmm/src/vstate/memory.rs
index 1bf7cda6342..7fa349f6f4a 100644
--- a/src/vmm/src/vstate/memory.rs
+++ b/src/vmm/src/vstate/memory.rs
@@ -57,6 +57,8 @@ pub enum MemoryError {
 pub enum GuestRegionType {
     /// Guest DRAM
     Dram,
+    /// Hotpluggable memory
+    Hotpluggable,
 }
 
 /// An extension to GuestMemoryRegion that stores the type of region, and the KVM slot
@@ -80,6 +82,14 @@ impl GuestRegionMmapExt {
         }
     }
 
+    pub(crate) fn hotpluggable_from_mmap_region(region: GuestRegionMmap, slot: u32) -> Self {
+        GuestRegionMmapExt {
+            inner: region,
+            region_type: GuestRegionType::Hotpluggable,
+            slot,
+        }
+    }
+
     pub(crate) fn from_state(
         region: GuestRegionMmap,
         state: &GuestMemoryRegionState,
diff --git a/src/vmm/src/vstate/vm.rs b/src/vmm/src/vstate/vm.rs
index cc6afb722a2..67771473355 100644
--- a/src/vmm/src/vstate/vm.rs
+++ b/src/vmm/src/vstate/vm.rs
@@ -222,6 +222,19 @@ impl Vm {
         Ok(())
     }
 
+    /// Register a new hotpluggable region to this [`Vm`].
+    pub fn register_hotpluggable_memory_region(
+        &mut self,
+        region: GuestRegionMmap,
+    ) -> Result<(), VmError> {
+        let arcd_region = Arc::new(GuestRegionMmapExt::hotpluggable_from_mmap_region(
+            region,
+            self.allocate_slot_ids(1)?,
+        ));
+
+        self._register_memory_region(arcd_region)
+    }
+
     /// Register a list of new memory regions to this [`Vm`].
     ///
     /// Note: regions and state.regions need to be in the same order.
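Editor's note: taken together, the VMM-side pieces above make hot(un)plug an asynchronous runtime operation. The PATCH only records a new requested size on the device; the guest driver then plugs or unplugs blocks on its own time. A minimal sketch of how a client is expected to drive this through the PR's Python test wrapper (the helper name and the 512 MiB figure are illustrative; the poll loop mirrors `Microvm.hotplug_memory` added further down):

import time

def resize_hotplug_memory(uvm, requested_size_mib=512, timeout_s=60):
    """Request a new hotplug size, then poll until the guest driver catches up."""
    uvm.api.memory_hotplug.patch(requested_size_mib=requested_size_mib)
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        status = uvm.api.memory_hotplug.get().json()
        if status["plugged_size_mib"] == requested_size_mib:
            return status
        time.sleep(0.1)
    raise TimeoutError("guest driver did not reach the requested size")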
diff --git a/src/vmm/tests/integration_tests.rs b/src/vmm/tests/integration_tests.rs
index 4abbedc4530..6a5e6a08a14 100644
--- a/src/vmm/tests/integration_tests.rs
+++ b/src/vmm/tests/integration_tests.rs
@@ -18,9 +18,7 @@ use vmm::rpc_interface::{
 use vmm::seccomp::get_empty_filters;
 use vmm::snapshot::Snapshot;
 use vmm::test_utils::mock_resources::{MockVmResources, NOISY_KERNEL_IMAGE};
-use vmm::test_utils::{
-    create_vmm, default_vmm, default_vmm_no_boot, default_vmm_pci, default_vmm_pci_no_boot,
-};
+use vmm::test_utils::{create_vmm, default_vmm, default_vmm_no_boot};
 use vmm::vmm_config::balloon::BalloonDeviceConfig;
 use vmm::vmm_config::boot_source::BootSourceConfig;
 use vmm::vmm_config::drive::BlockDeviceConfig;
@@ -66,13 +64,12 @@ fn test_build_and_boot_microvm() {
         assert_eq!(format!("{:?}", vmm_ret.err()), "Some(MissingKernelConfig)");
     }
 
-    // Success case.
-    let (vmm, evmgr) = default_vmm(None);
-    check_booted_microvm(vmm, evmgr);
-
-    // microVM with PCI
-    let (vmm, evmgr) = default_vmm_pci(None);
-    check_booted_microvm(vmm, evmgr);
+    for pci_enabled in [false, true] {
+        for memory_hotplug in [false, true] {
+            let (vmm, evmgr) = create_vmm(None, false, true, pci_enabled, memory_hotplug);
+            check_booted_microvm(vmm, evmgr);
+        }
+    }
 }
 
 #[allow(unused_mut, unused_variables)]
@@ -96,10 +93,12 @@ fn check_build_microvm(vmm: Arc<Mutex<Vmm>>, mut evmgr: EventManager) {
 
 #[test]
 fn test_build_microvm() {
-    let (vmm, evtmgr) = default_vmm_no_boot(None);
-    check_build_microvm(vmm, evtmgr);
-    let (vmm, evtmgr) = default_vmm_pci_no_boot(None);
-    check_build_microvm(vmm, evtmgr);
+    for pci_enabled in [false, true] {
+        for memory_hotplug in [false, true] {
+            let (vmm, evmgr) = create_vmm(None, false, false, pci_enabled, memory_hotplug);
+            check_build_microvm(vmm, evmgr);
+        }
+    }
 }
 
 fn pause_resume_microvm(vmm: Arc<Mutex<Vmm>>) {
@@ -118,13 +117,14 @@ fn pause_resume_microvm(vmm: Arc<Mutex<Vmm>>) {
 
 #[test]
 fn test_pause_resume_microvm() {
-    // Tests that pausing and resuming a microVM work as expected.
-    let (vmm, _) = default_vmm(None);
+    for pci_enabled in [false, true] {
+        for memory_hotplug in [false, true] {
+            // Tests that pausing and resuming a microVM work as expected.
+            let (vmm, _) = create_vmm(None, false, true, pci_enabled, memory_hotplug);
 
-    pause_resume_microvm(vmm);
-
-    let (vmm, _) = default_vmm_pci(None);
-    pause_resume_microvm(vmm);
+            pause_resume_microvm(vmm);
+        }
+    }
 }
 
 #[test]
@@ -195,11 +195,21 @@ fn test_disallow_dump_cpu_config_without_pausing() {
     vmm.lock().unwrap().stop(FcExitCode::Ok);
 }
 
-fn verify_create_snapshot(is_diff: bool, pci_enabled: bool) -> (TempFile, TempFile) {
+fn verify_create_snapshot(
+    is_diff: bool,
+    pci_enabled: bool,
+    memory_hotplug: bool,
+) -> (TempFile, TempFile) {
     let snapshot_file = TempFile::new().unwrap();
     let memory_file = TempFile::new().unwrap();
 
-    let (vmm, _) = create_vmm(Some(NOISY_KERNEL_IMAGE), is_diff, true, pci_enabled);
+    let (vmm, _) = create_vmm(
+        Some(NOISY_KERNEL_IMAGE),
+        is_diff,
+        true,
+        pci_enabled,
+        memory_hotplug,
+    );
 
     let resources = VmResources {
         machine_config: MachineConfig {
             mem_size_mib: 1,
@@ -303,14 +313,19 @@ fn verify_load_snapshot(snapshot_file: TempFile, memory_file: TempFile) {
 
 #[test]
 fn test_create_and_load_snapshot() {
-    for (diff_snap, pci_enabled) in [(false, false), (false, true), (true, false), (true, true)] {
-        // Create snapshot.
-        let (snapshot_file, memory_file) = verify_create_snapshot(diff_snap, pci_enabled);
-        // Create a new microVm from snapshot. This only tests code-level logic; it verifies
-        // that a microVM can be built with no errors from given snapshot.
-        // It does _not_ verify that the guest is actually restored properly. We're using
-        // python integration tests for that.
-        verify_load_snapshot(snapshot_file, memory_file);
+    for diff_snap in [false, true] {
+        for pci_enabled in [false, true] {
+            for memory_hotplug in [false, true] {
+                // Create snapshot.
+                let (snapshot_file, memory_file) =
+                    verify_create_snapshot(diff_snap, pci_enabled, memory_hotplug);
+                // Create a new microVm from snapshot. This only tests code-level logic; it verifies
+                // that a microVM can be built with no errors from given snapshot.
+                // It does _not_ verify that the guest is actually restored properly. We're using
+                // python integration tests for that.
+                verify_load_snapshot(snapshot_file, memory_file);
+            }
+        }
     }
 }
 
@@ -338,7 +353,7 @@ fn check_snapshot(mut microvm_state: MicrovmState) {
 
 fn get_microvm_state_from_snapshot(pci_enabled: bool) -> MicrovmState {
     // Create a diff snapshot
-    let (snapshot_file, _) = verify_create_snapshot(true, pci_enabled);
+    let (snapshot_file, _) = verify_create_snapshot(true, pci_enabled, false);
 
     // Deserialize the microVM state.
     snapshot_file.as_file().seek(SeekFrom::Start(0)).unwrap();
@@ -346,7 +361,7 @@ fn get_microvm_state_from_snapshot(pci_enabled: bool) -> MicrovmState {
 
 fn verify_load_snap_disallowed_after_boot_resources(res: VmmAction, res_name: &str) {
-    let (snapshot_file, memory_file) = verify_create_snapshot(false, false);
+    let (snapshot_file, memory_file) = verify_create_snapshot(false, false, false);
 
     let mut event_manager = EventManager::new().unwrap();
     let empty_seccomp_filters = get_empty_filters();
diff --git a/tests/framework/guest_stats.py b/tests/framework/guest_stats.py
new file mode 100644
index 00000000000..570b9a4ea63
--- /dev/null
+++ b/tests/framework/guest_stats.py
@@ -0,0 +1,79 @@
+# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Classes for querying guest stats inside microVMs.
+"""
+
+
+class ByteUnit:
+    """Represents a byte unit that can be converted to other units."""
+
+    value_bytes: int
+
+    def __init__(self, value_bytes: int):
+        self.value_bytes = value_bytes
+
+    @classmethod
+    def from_kib(cls, value_kib: int):
+        """Creates a ByteUnit from a value in KiB."""
+        if value_kib < 0:
+            raise ValueError("value_kib must be non-negative")
+        return ByteUnit(value_kib * 1024)
+
+    def bytes(self) -> int:
+        """Returns the value in bytes."""
+        return self.value_bytes
+
+    def kib(self) -> float:
+        """Returns the value in KiB."""
+        return self.value_bytes / 1024
+
+    def mib(self) -> float:
+        """Returns the value in MiB."""
+        return self.value_bytes / (1 << 20)
+
+    def gib(self) -> float:
+        """Returns the value in GiB."""
+        return self.value_bytes / (1 << 30)
+
+
+class Meminfo:
+    """Represents the contents of /proc/meminfo inside the guest"""
+
+    mem_total: ByteUnit
+    mem_free: ByteUnit
+    mem_available: ByteUnit
+    buffers: ByteUnit
+    cached: ByteUnit
+
+    def __init__(self):
+        self.mem_total = ByteUnit(0)
+        self.mem_free = ByteUnit(0)
+        self.mem_available = ByteUnit(0)
+        self.buffers = ByteUnit(0)
+        self.cached = ByteUnit(0)
+
+
+class MeminfoGuest:
+    """Queries /proc/meminfo inside the guest"""
+
+    def __init__(self, vm):
+        self.vm = vm
+
+    def get(self) -> Meminfo:
+        """Returns the contents of /proc/meminfo inside the guest"""
+        meminfo = Meminfo()
+        for line in self.vm.ssh.check_output("cat /proc/meminfo").stdout.splitlines():
+            parts = line.split()
+            if parts[0] == "MemTotal:":
+                meminfo.mem_total = ByteUnit.from_kib(int(parts[1]))
+            elif parts[0] == "MemFree:":
+                meminfo.mem_free = ByteUnit.from_kib(int(parts[1]))
+            elif parts[0] == "MemAvailable:":
+                meminfo.mem_available = ByteUnit.from_kib(int(parts[1]))
+            elif parts[0] == "Buffers:":
+                meminfo.buffers = ByteUnit.from_kib(int(parts[1]))
+            elif parts[0] == "Cached:":
+                meminfo.cached = ByteUnit.from_kib(int(parts[1]))
+
+        return meminfo
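Editor's note: the guest_stats helpers keep raw byte counts internally and convert only on read, so repeated unit conversions never accumulate rounding error. A short usage sketch (the `uvm` handle is assumed to be a booted microVM from this framework; the printed figures are illustrative):

from framework.guest_stats import ByteUnit, MeminfoGuest

assert ByteUnit.from_kib(2).bytes() == 2048
assert ByteUnit(1 << 30).gib() == 1.0

# One SSH round-trip that parses /proc/meminfo into ByteUnit fields
meminfo = MeminfoGuest(uvm).get()
print(
    f"guest sees {meminfo.mem_total.mib():.0f} MiB total, "
    f"{meminfo.mem_available.mib():.0f} MiB available"
)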
diff --git a/tests/framework/microvm.py b/tests/framework/microvm.py
index fa9dea79b82..853de6fc4ef 100644
--- a/tests/framework/microvm.py
+++ b/tests/framework/microvm.py
@@ -23,10 +23,11 @@
 from collections import namedtuple
 from dataclasses import dataclass
 from enum import Enum, auto
-from functools import lru_cache
+from functools import cached_property, lru_cache
 from pathlib import Path
 from typing import Optional
 
+import psutil
 from tenacity import Retrying, retry, stop_after_attempt, wait_fixed
 
 import host_tools.cargo_build as build_tools
@@ -472,7 +473,7 @@ def state(self):
         """Get the InstanceInfo property and return the state field."""
         return self.api.describe.get().json()["state"]
 
-    @property
+    @cached_property
     def firecracker_pid(self):
         """Return Firecracker's PID
 
@@ -491,6 +492,11 @@ def firecracker_pid(self):
             with attempt:
                 return int(self.jailer.pid_file.read_text(encoding="ascii"))
 
+    @cached_property
+    def ps(self):
+        """Returns a handle to the psutil.Process for this VM"""
+        return psutil.Process(self.firecracker_pid)
+
     @property
     def dimensions(self):
         """Gets a default set of cloudwatch dimensions describing the configuration of this microvm"""
@@ -1180,6 +1186,22 @@ def wait_for_ssh_up(self):
         # run commands. The actual connection retry loop happens in SSHConnection._init_connection
         _ = self.ssh_iface(0)
 
+    def hotplug_memory(
+        self, requested_size_mib: int, timeout: int = 60, poll: float = 0.1
+    ):
+        """Send a hot(un)plug request and wait up to `timeout` seconds for completion, polling every `poll` seconds"""
+        self.api.memory_hotplug.patch(requested_size_mib=requested_size_mib)
+        # Wait for the hotplug to complete
+        deadline = time.time() + timeout
+        while time.time() < deadline:
+            if (
+                self.api.memory_hotplug.get().json()["plugged_size_mib"]
+                == requested_size_mib
+            ):
+                return
+            time.sleep(poll)
+        raise TimeoutError(f"Hotplug did not complete within {timeout} seconds")
+
 
 class MicroVMFactory:
     """MicroVM factory"""
@@ -1294,6 +1316,18 @@ def build_n_from_snapshot(
             last_snapshot.delete()
             current_snapshot.delete()
 
+    def clone_uvm(self, uvm, uffd_handler_name=None):
+        """
+        Clone the given VM and start it.
+        """
+        snapshot = uvm.snapshot_full()
+        restored_vm = self.build()
+        restored_vm.spawn()
+        restored_vm.restore_from_snapshot(
+            snapshot, resume=True, uffd_handler_name=uffd_handler_name
+        )
+        return restored_vm
+
     def kill(self):
         """Clean up all built VMs"""
         for vm in self.vms:
diff --git a/tests/framework/utils.py b/tests/framework/utils.py
index 64bc9526e5c..448b351fd86 100644
--- a/tests/framework/utils.py
+++ b/tests/framework/utils.py
@@ -14,6 +14,7 @@ import typing
 from collections import defaultdict, namedtuple
 from contextlib import contextmanager
+from pathlib import Path
 from typing import Dict
 
 import psutil
@@ -129,6 +130,19 @@ def track_cpu_utilization(
     return cpu_utilization
 
 
+def get_resident_memory(process: psutil.Process):
+    """Returns current memory utilization in KiB, including used HugeTLBFS"""
+
+    proc_status = Path("/proc", str(process.pid), "status").read_text("utf-8")
+    for line in proc_status.splitlines():
+        if line.startswith("HugetlbPages:"):  # entry is in KiB
+            hugetlbfs_usage = int(line.split()[1])
+            break
+    else:
+        assert False, f"HugetlbPages not found in {str(proc_status)}"
+    return hugetlbfs_usage + process.memory_info().rss // 1024
+
+
 @contextmanager
 def chroot(path):
     """
@@ -240,25 +254,6 @@ def search_output_from_cmd(cmd: str, find_regex: typing.Pattern) -> typing.Match
     )
 
 
-def get_free_mem_ssh(ssh_connection):
-    """
-    Get how much free memory in kB a guest sees, over ssh.
- - :param ssh_connection: connection to the guest - :return: available mem column output of 'free' - """ - _, stdout, stderr = ssh_connection.run("cat /proc/meminfo | grep MemAvailable") - assert stderr == "" - - # Split "MemAvailable: 123456 kB" and validate it - meminfo_data = stdout.split() - if len(meminfo_data) == 3: - # Return the middle element in the array - return int(meminfo_data[1]) - - raise Exception("Available memory not found in `/proc/meminfo") - - def _format_output_message(proc, stdout, stderr): output_message = f"\n[{proc.pid}] Command:\n{proc.args}" # Append stdout/stderr to the output message diff --git a/tests/host_tools/fcmetrics.py b/tests/host_tools/fcmetrics.py index 0dcff5eed00..3b65901b5aa 100644 --- a/tests/host_tools/fcmetrics.py +++ b/tests/host_tools/fcmetrics.py @@ -202,6 +202,8 @@ def validate_fc_metrics(metrics): "machine_cfg_fails", "mmds_count", "mmds_fails", + "hotplug_memory_count", + "hotplug_memory_fails", ], "put_api_requests": [ "actions_count", @@ -300,6 +302,21 @@ def validate_fc_metrics(metrics): "activate_fails", "queue_event_fails", "queue_event_count", + "plug_count", + "plug_bytes", + "plug_fails", + {"plug_agg": latency_agg_metrics_fields}, + "unplug_count", + "unplug_bytes", + "unplug_fails", + "unplug_discard_fails", + {"unplug_agg": latency_agg_metrics_fields}, + "state_count", + "state_fails", + {"state_agg": latency_agg_metrics_fields}, + "unplug_all_count", + "unplug_all_fails", + {"unplug_all_agg": latency_agg_metrics_fields}, ], } diff --git a/tests/host_tools/memory.py b/tests/host_tools/memory.py index 134147724cd..c09dae0d206 100644 --- a/tests/host_tools/memory.py +++ b/tests/host_tools/memory.py @@ -99,7 +99,9 @@ def is_guest_mem_x86(self, size, guest_mem_bytes): Checks if a region is a guest memory region based on x86_64 physical memory layout """ - return size in ( + # it could be bigger if hotplugging is enabled + # if it's bigger, it's likely not from FC because we don't have big allocations + return size >= guest_mem_bytes or size in ( # memory fits before the first gap guest_mem_bytes, # guest memory spans at least two regions & memory fits before the second gap @@ -121,7 +123,9 @@ def is_guest_mem_arch64(self, size, guest_mem_bytes): Checks if a region is a guest memory region based on ARM64 physical memory layout """ - return size in ( + # it could be bigger if hotplugging is enabled + # if it's bigger, it's likely not from FC because we don't have big allocations + return size >= guest_mem_bytes or size in ( # guest memory fits before the gap guest_mem_bytes, # guest memory fills the space before the gap diff --git a/tests/integration_tests/functional/test_api.py b/tests/integration_tests/functional/test_api.py index 81454559990..ac929941dca 100644 --- a/tests/integration_tests/functional/test_api.py +++ b/tests/integration_tests/functional/test_api.py @@ -981,13 +981,14 @@ def test_api_entropy(uvm_plain): test_microvm.api.entropy.put() -def test_api_memory_hotplug(uvm_plain): +def test_api_memory_hotplug(uvm_plain_6_1): """ Test hotplug related API commands. """ - test_microvm = uvm_plain + test_microvm = uvm_plain_6_1 test_microvm.spawn() test_microvm.basic_config() + test_microvm.add_net_iface() # Adding hotplug memory region should be OK. 
     test_microvm.api.memory_hotplug.put(
@@ -1002,6 +1003,10 @@
     with pytest.raises(AssertionError):
         test_microvm.api.memory_hotplug.get()
 
+    # Patch API should be rejected before boot
+    with pytest.raises(RuntimeError, match=NOT_SUPPORTED_BEFORE_START):
+        test_microvm.api.memory_hotplug.patch(requested_size_mib=512)
+
     # Start the microvm
     test_microvm.start()
 
@@ -1013,6 +1018,11 @@
     status = test_microvm.api.memory_hotplug.get().json()
     assert status["total_size_mib"] == 1024
 
+    # Patch API should work after boot
+    test_microvm.api.memory_hotplug.patch(requested_size_mib=512)
+    status = test_microvm.api.memory_hotplug.get().json()
+    assert status["requested_size_mib"] == 512
+
 
 def test_api_balloon(uvm_nano):
     """
diff --git a/tests/integration_tests/functional/test_balloon.py b/tests/integration_tests/functional/test_balloon.py
index f8960bedb6d..19b1651c72a 100644
--- a/tests/integration_tests/functional/test_balloon.py
+++ b/tests/integration_tests/functional/test_balloon.py
@@ -9,12 +9,13 @@
 import pytest
 import requests
 
-from framework.utils import check_output, get_free_mem_ssh
+from framework.guest_stats import MeminfoGuest
+from framework.utils import get_resident_memory
 
 STATS_POLLING_INTERVAL_S = 1
 
 
-def get_stable_rss_mem_by_pid(pid, percentage_delta=1):
+def get_stable_rss_mem(uvm, percentage_delta=1):
     """
-    Get the RSS memory that a guest uses, given the pid of the guest.
+    Get the RSS memory that a guest uses, given its microVM handle.
 
@@ -22,22 +23,16 @@
     Or print a warning if this does not happen.
     """
 
-    # All values are reported as KiB
-
-    def get_rss_from_pmap():
-        _, output, _ = check_output("pmap -X {}".format(pid))
-        return int(output.split("\n")[-2].split()[1], 10)
-
     first_rss = 0
     second_rss = 0
     for _ in range(5):
-        first_rss = get_rss_from_pmap()
+        first_rss = get_resident_memory(uvm.ps)
         time.sleep(1)
-        second_rss = get_rss_from_pmap()
+        second_rss = get_resident_memory(uvm.ps)
         abs_diff = abs(first_rss - second_rss)
         abs_delta = abs_diff / first_rss * 100
         print(
-            f"RSS readings: old: {first_rss} new: {second_rss} abs_diff: {abs_diff} abs_delta: {abs_delta}"
+            f"RSS readings (KiB): old: {first_rss} new: {second_rss} abs_diff: {abs_diff} abs_delta: {abs_delta}"
         )
         if abs_delta < percentage_delta:
             return second_rss
@@ -87,25 +82,24 @@ def make_guest_dirty_memory(ssh_connection, amount_mib=32):
 
 def _test_rss_memory_lower(test_microvm):
     """Check inflating the balloon makes guest use less rss memory."""
-    # Get the firecracker pid, and open an ssh connection.
-    firecracker_pid = test_microvm.firecracker_pid
+    # Open an ssh connection.
     ssh_connection = test_microvm.ssh
 
     # Using deflate_on_oom, get the RSS as low as possible
     test_microvm.api.balloon.patch(amount_mib=200)
 
     # Get initial rss consumption.
-    init_rss = get_stable_rss_mem_by_pid(firecracker_pid)
+    init_rss = get_stable_rss_mem(test_microvm)
 
     # Get the balloon back to 0.
     test_microvm.api.balloon.patch(amount_mib=0)
     # This call will internally wait for rss to become stable.
-    _ = get_stable_rss_mem_by_pid(firecracker_pid)
+    _ = get_stable_rss_mem(test_microvm)
 
     # Dirty memory, then inflate balloon and get ballooned rss consumption.
     make_guest_dirty_memory(ssh_connection, amount_mib=32)
 
     test_microvm.api.balloon.patch(amount_mib=200)
-    balloon_rss = get_stable_rss_mem_by_pid(firecracker_pid)
+    balloon_rss = get_stable_rss_mem(test_microvm)
 
     # Check that the ballooning reclaimed the memory.
assert balloon_rss - init_rss <= 15000 @@ -149,18 +143,18 @@ def test_inflate_reduces_free(uvm_plain_any): # Start the microvm test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid + meminfo = MeminfoGuest(test_microvm) # Get the free memory before ballooning. - available_mem_deflated = get_free_mem_ssh(test_microvm.ssh) + available_mem_deflated = meminfo.get().mem_free.kib() # Inflate 64 MB == 16384 page balloon. test_microvm.api.balloon.patch(amount_mib=64) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get the free memory after ballooning. - available_mem_inflated = get_free_mem_ssh(test_microvm.ssh) + available_mem_inflated = meminfo.get().mem_free.kib() # Assert that ballooning reclaimed about 64 MB of memory. assert available_mem_inflated <= available_mem_deflated - 85 * 64000 / 100 @@ -195,19 +189,18 @@ def test_deflate_on_oom(uvm_plain_any, deflate_on_oom): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # We get an initial reading of the RSS, then calculate the amount # we need to inflate the balloon with by subtracting it from the # VM size and adding an offset of 50 MiB in order to make sure we # get a lower reading than the initial one. - initial_rss = get_stable_rss_mem_by_pid(firecracker_pid) + initial_rss = get_stable_rss_mem(test_microvm) inflate_size = 256 - (int(initial_rss / 1024) + 50) # Inflate the balloon test_microvm.api.balloon.patch(amount_mib=inflate_size) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Check that using memory leads to the balloon device automatically # deflate (or not). @@ -250,39 +243,38 @@ def test_reinflate_balloon(uvm_plain_any): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # First inflate the balloon to free up the uncertain amount of memory # used by the kernel at boot and establish a baseline, then give back # the memory. test_microvm.api.balloon.patch(amount_mib=200) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) test_microvm.api.balloon.patch(amount_mib=0) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get the guest to dirty memory. make_guest_dirty_memory(test_microvm.ssh, amount_mib=32) - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(test_microvm) # Now inflate the balloon. test_microvm.api.balloon.patch(amount_mib=200) - second_reading = get_stable_rss_mem_by_pid(firecracker_pid) + second_reading = get_stable_rss_mem(test_microvm) # Now deflate the balloon. test_microvm.api.balloon.patch(amount_mib=0) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Now have the guest dirty memory again. make_guest_dirty_memory(test_microvm.ssh, amount_mib=32) - third_reading = get_stable_rss_mem_by_pid(firecracker_pid) + third_reading = get_stable_rss_mem(test_microvm) # Now inflate the balloon again. 
test_microvm.api.balloon.patch(amount_mib=200) - fourth_reading = get_stable_rss_mem_by_pid(firecracker_pid) + fourth_reading = get_stable_rss_mem(test_microvm) # Check that the memory used is the same after regardless of the previous # inflate history of the balloon (with the third reading being allowed @@ -309,10 +301,9 @@ def test_size_reduction(uvm_plain_any): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # Check memory usage. - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(test_microvm) # Have the guest drop its caches. test_microvm.ssh.run("sync; echo 3 > /proc/sys/vm/drop_caches") @@ -328,7 +319,7 @@ def test_size_reduction(uvm_plain_any): test_microvm.api.balloon.patch(amount_mib=inflate_size) # Check memory usage again. - second_reading = get_stable_rss_mem_by_pid(firecracker_pid) + second_reading = get_stable_rss_mem(test_microvm) # There should be a reduction of at least 10MB. assert first_reading - second_reading >= 10000 @@ -353,7 +344,6 @@ def test_stats(uvm_plain_any): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # Give Firecracker enough time to poll the stats at least once post-boot time.sleep(STATS_POLLING_INTERVAL_S * 2) @@ -371,7 +361,7 @@ def test_stats(uvm_plain_any): make_guest_dirty_memory(test_microvm.ssh, amount_mib=10) time.sleep(1) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Make sure that the stats catch the page faults. after_workload_stats = test_microvm.api.balloon_stats.get().json() @@ -380,7 +370,7 @@ def test_stats(uvm_plain_any): # Now inflate the balloon with 10MB of pages. test_microvm.api.balloon.patch(amount_mib=10) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get another reading of the stats after the polling interval has passed. inflated_stats = test_microvm.api.balloon_stats.get().json() @@ -393,7 +383,7 @@ def test_stats(uvm_plain_any): # available memory. test_microvm.api.balloon.patch(amount_mib=0) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get another reading of the stats after the polling interval has passed. deflated_stats = test_microvm.api.balloon_stats.get().json() @@ -421,13 +411,12 @@ def test_stats_update(uvm_plain_any): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # Dirty 30MB of pages. make_guest_dirty_memory(test_microvm.ssh, amount_mib=30) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get an initial reading of the stats. initial_stats = test_microvm.api.balloon_stats.get().json() @@ -477,17 +466,14 @@ def test_balloon_snapshot(uvm_plain_any, microvm_factory): make_guest_dirty_memory(vm.ssh, amount_mib=60) time.sleep(1) - # Get the firecracker pid, and open an ssh connection. - firecracker_pid = vm.firecracker_pid - # Check memory usage. - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(vm) # Now inflate the balloon with 20MB of pages. vm.api.balloon.patch(amount_mib=20) # Check memory usage again. 
-    second_reading = get_stable_rss_mem_by_pid(firecracker_pid)
+    second_reading = get_stable_rss_mem(vm)
 
     # There should be a reduction in RSS, but it's inconsistent.
     # We only test that the reduction happens.
@@ -496,28 +482,25 @@
     snapshot = vm.snapshot_full()
     microvm = microvm_factory.build_from_snapshot(snapshot)
 
-    # Get the firecracker from snapshot pid, and open an ssh connection.
-    firecracker_pid = microvm.firecracker_pid
-
     # Wait out the polling interval, then get the updated stats.
     time.sleep(STATS_POLLING_INTERVAL_S * 2)
     stats_after_snap = microvm.api.balloon_stats.get().json()
 
     # Check memory usage.
-    third_reading = get_stable_rss_mem_by_pid(firecracker_pid)
+    third_reading = get_stable_rss_mem(microvm)
 
     # Dirty 60MB of pages.
     make_guest_dirty_memory(microvm.ssh, amount_mib=60)
 
     # Check memory usage.
-    fourth_reading = get_stable_rss_mem_by_pid(firecracker_pid)
+    fourth_reading = get_stable_rss_mem(microvm)
 
     assert fourth_reading > third_reading
 
     # Inflate the balloon with another 20MB of pages.
     microvm.api.balloon.patch(amount_mib=40)
 
-    fifth_reading = get_stable_rss_mem_by_pid(firecracker_pid)
+    fifth_reading = get_stable_rss_mem(microvm)
 
     # There should be a reduction in RSS, but it's inconsistent.
     # We only test that the reduction happens.
@@ -557,15 +540,14 @@ def test_memory_scrub(uvm_plain_any):
     microvm.api.balloon.patch(amount_mib=60)
 
-    # Get the firecracker pid, and open an ssh connection.
-    firecracker_pid = microvm.firecracker_pid
 
     # Wait for the inflate to complete.
-    _ = get_stable_rss_mem_by_pid(firecracker_pid)
+    _ = get_stable_rss_mem(microvm)
 
     # Deflate the balloon completely.
     microvm.api.balloon.patch(amount_mib=0)
 
     # Wait for the deflate to complete.
-    _ = get_stable_rss_mem_by_pid(firecracker_pid)
+    _ = get_stable_rss_mem(microvm)
 
     microvm.ssh.check_output("/usr/local/bin/readmem {} {}".format(60, 1))
diff --git a/tests/integration_tests/functional/test_memory_hp.py b/tests/integration_tests/functional/test_memory_hp.py
deleted file mode 100644
index b2132d6c9ed..00000000000
--- a/tests/integration_tests/functional/test_memory_hp.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-"""Tests for verifying the virtio-mem is working correctly"""
-
-
-def test_virtio_mem_detected(uvm_plain_6_1):
-    """
-    Check that the guest kernel has enabled PV steal time.
- """ - uvm = uvm_plain_6_1 - uvm.spawn() - uvm.memory_monitor = None - uvm.basic_config( - boot_args="console=ttyS0 reboot=k panic=1 memhp_default_state=online_movable" - ) - uvm.add_net_iface() - uvm.api.memory_hotplug.put(total_size_mib=1024) - uvm.start() - - _, stdout, _ = uvm.ssh.check_output("dmesg | grep 'virtio_mem'") - for line in stdout.splitlines(): - _, key, value = line.strip().split(":") - key = key.strip() - value = int(value.strip(), base=0) - match key: - case "start address": - assert value == (512 << 30), "start address doesn't match" - case "region size": - assert value == 1024 << 20, "region size doesn't match" - case "device block size": - assert value == 2 << 20, "block size doesn't match" - case "plugged size": - assert value == 0, "plugged size doesn't match" - case "requested size": - assert value == 0, "requested size doesn't match" - case _: - continue diff --git a/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py b/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py index bfe5316d9e5..253502a2d1f 100644 --- a/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py +++ b/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py @@ -20,7 +20,7 @@ from framework.utils_cpu_templates import get_supported_cpu_templates from framework.utils_vsock import check_vsock_device from integration_tests.functional.test_balloon import ( - get_stable_rss_mem_by_pid, + get_stable_rss_mem, make_guest_dirty_memory, ) @@ -28,21 +28,18 @@ def _test_balloon(microvm): - # Get the firecracker pid. - firecracker_pid = microvm.firecracker_pid - # Check memory usage. - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(microvm) # Dirty 300MB of pages. make_guest_dirty_memory(microvm.ssh, amount_mib=300) # Check memory usage again. - second_reading = get_stable_rss_mem_by_pid(firecracker_pid) + second_reading = get_stable_rss_mem(microvm) assert second_reading > first_reading # Inflate the balloon. Get back 200MB. microvm.api.balloon.patch(amount_mib=200) - third_reading = get_stable_rss_mem_by_pid(firecracker_pid) + third_reading = get_stable_rss_mem(microvm) # Ensure that there is a reduction in RSS. assert second_reading > third_reading diff --git a/tests/integration_tests/performance/test_hotplug_memory.py b/tests/integration_tests/performance/test_hotplug_memory.py new file mode 100644 index 00000000000..6aa8649c5d1 --- /dev/null +++ b/tests/integration_tests/performance/test_hotplug_memory.py @@ -0,0 +1,292 @@ +# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Tests for verifying the virtio-mem is working correctly + +This file also contains functional tests for virtio-mem because they need to be +run on an ag=1 host due to the use of HugePages. 
+""" + +import pytest +from packaging import version +from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed + +from framework.guest_stats import MeminfoGuest +from framework.microvm import HugePagesConfig +from framework.utils import get_kernel_version, get_resident_memory + +MEMHP_BOOTARGS = "console=ttyS0 reboot=k panic=1 memhp_default_state=online_movable" +DEFAULT_CONFIG = {"total_size_mib": 1024, "slot_size_mib": 128, "block_size_mib": 2} + + +def uvm_booted_memhp( + uvm, rootfs, _microvm_factory, vhost_user, memhp_config, huge_pages, _uffd_handler +): + """Boots a VM with the given memory hotplugging config""" + + uvm.spawn() + uvm.memory_monitor = None + if vhost_user: + # We need to setup ssh keys manually because we did not specify rootfs + # in microvm_factory.build method + ssh_key = rootfs.with_suffix(".id_rsa") + uvm.ssh_key = ssh_key + uvm.basic_config( + boot_args=MEMHP_BOOTARGS, add_root_device=False, huge_pages=huge_pages + ) + uvm.add_vhost_user_drive( + "rootfs", rootfs, is_root_device=True, is_read_only=True + ) + else: + uvm.basic_config(boot_args=MEMHP_BOOTARGS, huge_pages=huge_pages) + + uvm.api.memory_hotplug.put(**memhp_config) + uvm.add_net_iface() + uvm.start() + return uvm + + +def uvm_resumed_memhp( + uvm_plain, + rootfs, + microvm_factory, + vhost_user, + memhp_config, + huge_pages, + uffd_handler, +): + """Restores a VM with the given memory hotplugging config after booting and snapshotting""" + if vhost_user: + pytest.skip("vhost-user doesn't support snapshot/restore") + if huge_pages and huge_pages != HugePagesConfig.NONE and not uffd_handler: + pytest.skip("Hugepages requires a UFFD handler") + uvm = uvm_booted_memhp( + uvm_plain, rootfs, microvm_factory, vhost_user, memhp_config, huge_pages, None + ) + return microvm_factory.clone_uvm(uvm, uffd_handler_name=uffd_handler) + + +@pytest.fixture( + params=[ + (uvm_booted_memhp, False, HugePagesConfig.NONE, None), + (uvm_booted_memhp, False, HugePagesConfig.HUGETLBFS_2MB, None), + (uvm_booted_memhp, True, HugePagesConfig.NONE, None), + (uvm_resumed_memhp, False, HugePagesConfig.NONE, None), + (uvm_resumed_memhp, False, HugePagesConfig.NONE, "on_demand"), + (uvm_resumed_memhp, False, HugePagesConfig.HUGETLBFS_2MB, "on_demand"), + ], + ids=[ + "booted", + "booted-huge-pages", + "booted-vhost-user", + "resumed", + "resumed-uffd", + "resumed-uffd-huge-pages", + ], +) +def uvm_any_memhp(request, uvm_plain_6_1, rootfs, microvm_factory): + """Fixture that yields a booted or resumed VM with memory hotplugging""" + ctor, vhost_user, huge_pages, uffd_handler = request.param + yield ctor( + uvm_plain_6_1, + rootfs, + microvm_factory, + vhost_user, + DEFAULT_CONFIG, + huge_pages, + uffd_handler, + ) + + +def supports_hugetlbfs_discard(): + """Returns True if the kernel supports hugetlbfs discard""" + return version.parse(get_kernel_version()) >= version.parse("5.18.0") + + +def validate_metrics(uvm): + """Validates that there are no fails in the metrics""" + metrics_to_check = ["plug_fails", "unplug_fails", "unplug_all_fails", "state_fails"] + if supports_hugetlbfs_discard(): + metrics_to_check.append("unplug_discard_fails") + uvm.flush_metrics() + for metrics in uvm.get_all_metrics(): + for k in metrics_to_check: + assert ( + metrics["memory_hotplug"][k] == 0 + ), f"{k}={metrics[k]} is greater than zero" + + +def check_device_detected(uvm): + """ + Check that the guest kernel has enabled virtio-mem. 
+ """ + hp_config = uvm.api.memory_hotplug.get().json() + _, stdout, _ = uvm.ssh.check_output("dmesg | grep 'virtio_mem'") + for line in stdout.splitlines(): + _, key, value = line.strip().split(":") + key = key.strip() + value = int(value.strip(), base=0) + match key: + case "start address": + assert value == (512 << 30), "start address doesn't match" + case "region size": + assert ( + value == hp_config["total_size_mib"] << 20 + ), "region size doesn't match" + case "device block size": + assert ( + value == hp_config["block_size_mib"] << 20 + ), "block size doesn't match" + case "plugged size": + assert value == 0, "plugged size doesn't match" + case "requested size": + assert value == 0, "requested size doesn't match" + case _: + continue + + +def check_memory_usable(uvm): + """Allocates memory to verify it's usable (5% margin to avoid OOM-kill)""" + mem_available = MeminfoGuest(uvm).get().mem_available.bytes() + # number of 64b ints to allocate as 95% of available memory + count = mem_available * 95 // 100 // 8 + + uvm.ssh.check_output( + f"python3 -c 'Q = 0x0123456789abcdef; a = [Q] * {count}; assert all(q == Q for q in a)'" + ) + + +def check_hotplug(uvm, requested_size_mib): + """Verifies memory can be hot(un)plugged""" + meminfo = MeminfoGuest(uvm) + mem_total_fixed = ( + meminfo.get().mem_total.mib() + - uvm.api.memory_hotplug.get().json()["plugged_size_mib"] + ) + uvm.hotplug_memory(requested_size_mib) + + # verify guest driver received the request + _, stdout, _ = uvm.ssh.check_output( + "dmesg | grep 'virtio_mem' | grep 'requested size' | tail -1" + ) + assert ( + int(stdout.strip().split(":")[-1].strip(), base=0) == requested_size_mib << 20 + ) + + for attempt in Retrying( + retry=retry_if_exception_type(AssertionError), + stop=stop_after_delay(5), + wait=wait_fixed(1), + reraise=True, + ): + with attempt: + # verify guest driver executed the request + mem_total_after = meminfo.get().mem_total.mib() + assert mem_total_after == mem_total_fixed + requested_size_mib + + +def check_hotunplug(uvm, requested_size_mib): + """Verifies memory can be hotunplugged and gets released""" + + rss_before = get_resident_memory(uvm.ps) + + check_hotplug(uvm, requested_size_mib) + + rss_after = get_resident_memory(uvm.ps) + + print(f"RSS before: {rss_before}, after: {rss_after}") + + huge_pages = HugePagesConfig(uvm.api.machine_config.get().json()["huge_pages"]) + if huge_pages == HugePagesConfig.HUGETLBFS_2MB and supports_hugetlbfs_discard(): + assert rss_after < rss_before, "RSS didn't decrease" + + +def test_virtio_mem_hotplug_hotunplug(uvm_any_memhp): + """ + Check that memory can be hotplugged into the VM. 
+ """ + uvm = uvm_any_memhp + check_device_detected(uvm) + + check_hotplug(uvm, 1024) + check_memory_usable(uvm) + + check_hotunplug(uvm, 0) + + # Check it works again + check_hotplug(uvm, 1024) + check_memory_usable(uvm) + + validate_metrics(uvm) + + +@pytest.mark.parametrize( + "memhp_config", + [ + {"total_size_mib": 256, "slot_size_mib": 128, "block_size_mib": 64}, + {"total_size_mib": 256, "slot_size_mib": 128, "block_size_mib": 128}, + {"total_size_mib": 256, "slot_size_mib": 256, "block_size_mib": 64}, + {"total_size_mib": 256, "slot_size_mib": 256, "block_size_mib": 256}, + ], + ids=["all_different", "slot_sized_block", "single_slot", "single_block"], +) +def test_virtio_mem_configs(uvm_plain_6_1, memhp_config): + """ + Check that the virtio mem device is working as expected for different configs + """ + uvm = uvm_booted_memhp(uvm_plain_6_1, None, None, False, memhp_config, None, None) + if not uvm.pci_enabled: + pytest.skip( + "Skip tests on MMIO transport to save time as we don't expect any difference." + ) + + check_device_detected(uvm) + + for size in range( + 0, memhp_config["total_size_mib"] + 1, memhp_config["block_size_mib"] + ): + check_hotplug(uvm, size) + + check_memory_usable(uvm) + + for size in range( + memhp_config["total_size_mib"] - memhp_config["block_size_mib"], + -1, + -memhp_config["block_size_mib"], + ): + check_hotunplug(uvm, size) + + validate_metrics(uvm) + + +def test_snapshot_restore_persistence(uvm_plain_6_1, microvm_factory): + """ + Check that hptplugged memory is persisted across snapshot/restore. + """ + if not uvm_plain_6_1.pci_enabled: + pytest.skip( + "Skip tests on MMIO transport to save time as we don't expect any difference." + ) + uvm = uvm_booted_memhp( + uvm_plain_6_1, None, microvm_factory, False, DEFAULT_CONFIG, None, None + ) + + uvm.hotplug_memory(1024) + + # Increase /dev/shm size as it defaults to half of the boot memory + uvm.ssh.check_output("mount -o remount,size=1024M -t tmpfs tmpfs /dev/shm") + + uvm.ssh.check_output("dd if=/dev/urandom of=/dev/shm/mem_hp_test bs=1M count=1024") + + _, checksum_before, _ = uvm.ssh.check_output("sha256sum /dev/shm/mem_hp_test") + + restored_vm = microvm_factory.clone_uvm(uvm) + + _, checksum_after, _ = restored_vm.ssh.check_output( + "sha256sum /dev/shm/mem_hp_test" + ) + + assert checksum_before == checksum_after, "Checksums didn't match" + + validate_metrics(restored_vm)