From 5377a55ac93abbdbe433c39653262881126e5e17 Mon Sep 17 00:00:00 2001
From: Riccardo Mancini
Date: Wed, 24 Sep 2025 15:43:23 +0100
Subject: [PATCH 01/14] feat(virtio-mem): add static allocation of hotpluggable memory

Allocate the memory that will be used for hotplugging. Initially, this
memory will be registered with KVM, but that will change later when we
add dynamic slot support.

Signed-off-by: Riccardo Mancini
---
 src/vmm/src/builder.rs       | 15 ++++++++++++++-
 src/vmm/src/resources.rs     | 12 ++++++++++++
 src/vmm/src/vstate/memory.rs | 10 ++++++++++
 src/vmm/src/vstate/vm.rs     | 13 +++++++++++++
 4 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/src/vmm/src/builder.rs b/src/vmm/src/builder.rs
index cb901a78c63..5556c62e44f 100644
--- a/src/vmm/src/builder.rs
+++ b/src/vmm/src/builder.rs
@@ -32,7 +32,7 @@ use crate::device_manager::{
 use crate::devices::acpi::vmgenid::VmGenIdError;
 use crate::devices::virtio::balloon::Balloon;
 use crate::devices::virtio::block::device::Block;
-use crate::devices::virtio::mem::VirtioMem;
+use crate::devices::virtio::mem::{VIRTIO_MEM_GUEST_ADDRESS, VirtioMem};
 use crate::devices::virtio::net::Net;
 use crate::devices::virtio::rng::Entropy;
 use crate::devices::virtio::vsock::{Vsock, VsockUnixBackend};
@@ -44,6 +44,7 @@ use crate::persist::{MicrovmState, MicrovmStateError};
 use crate::resources::VmResources;
 use crate::seccomp::BpfThreadMap;
 use crate::snapshot::Persist;
+use crate::utils::mib_to_bytes;
 use crate::vmm_config::instance_info::InstanceInfo;
 use crate::vmm_config::machine_config::MachineConfigError;
 use crate::vmm_config::memory_hotplug::MemoryHotplugConfig;
@@ -172,6 +173,18 @@ pub fn build_microvm_for_boot(
     let (mut vcpus, vcpus_exit_evt) = vm.create_vcpus(vm_resources.machine_config.vcpu_count)?;
     vm.register_dram_memory_regions(guest_memory)?;
 
+    // Allocate memory as soon as possible to make hotpluggable memory available to all consumers,
+    // before they clone the GuestMemoryMmap object.
+    if let Some(memory_hotplug) = &vm_resources.memory_hotplug {
+        let hotplug_memory_region = vm_resources
+            .allocate_memory_region(
+                VIRTIO_MEM_GUEST_ADDRESS,
+                mib_to_bytes(memory_hotplug.total_size_mib),
+            )
+            .map_err(StartMicrovmError::GuestMemory)?;
+        vm.register_hotpluggable_memory_region(hotplug_memory_region)?;
+    }
+
     let mut device_manager = DeviceManager::new(
         event_manager,
         &vcpus_exit_evt,
diff --git a/src/vmm/src/resources.rs b/src/vmm/src/resources.rs
index 03d9f9b0c77..25fbc14e5db 100644
--- a/src/vmm/src/resources.rs
+++ b/src/vmm/src/resources.rs
@@ -536,6 +536,18 @@ impl VmResources {
             crate::arch::arch_memory_regions(mib_to_bytes(self.machine_config.mem_size_mib));
         self.allocate_memory_regions(&regions)
     }
+
+    /// Allocates a single guest memory region.
+    pub fn allocate_memory_region(
+        &self,
+        start: GuestAddress,
+        size: usize,
+    ) -> Result<GuestRegionMmap, MemoryError> {
+        Ok(self
+            .allocate_memory_regions(&[(start, size)])?
+ .pop() + .unwrap()) + } } impl From<&VmResources> for VmmConfig { diff --git a/src/vmm/src/vstate/memory.rs b/src/vmm/src/vstate/memory.rs index 1bf7cda6342..7fa349f6f4a 100644 --- a/src/vmm/src/vstate/memory.rs +++ b/src/vmm/src/vstate/memory.rs @@ -57,6 +57,8 @@ pub enum MemoryError { pub enum GuestRegionType { /// Guest DRAM Dram, + /// Hotpluggable memory + Hotpluggable, } /// An extension to GuestMemoryRegion that stores the type of region, and the KVM slot @@ -80,6 +82,14 @@ impl GuestRegionMmapExt { } } + pub(crate) fn hotpluggable_from_mmap_region(region: GuestRegionMmap, slot: u32) -> Self { + GuestRegionMmapExt { + inner: region, + region_type: GuestRegionType::Hotpluggable, + slot, + } + } + pub(crate) fn from_state( region: GuestRegionMmap, state: &GuestMemoryRegionState, diff --git a/src/vmm/src/vstate/vm.rs b/src/vmm/src/vstate/vm.rs index cc6afb722a2..67771473355 100644 --- a/src/vmm/src/vstate/vm.rs +++ b/src/vmm/src/vstate/vm.rs @@ -222,6 +222,19 @@ impl Vm { Ok(()) } + /// Register a new hotpluggable region to this [`Vm`]. + pub fn register_hotpluggable_memory_region( + &mut self, + region: GuestRegionMmap, + ) -> Result<(), VmError> { + let arcd_region = Arc::new(GuestRegionMmapExt::hotpluggable_from_mmap_region( + region, + self.allocate_slot_ids(1)?, + )); + + self._register_memory_region(arcd_region) + } + /// Register a list of new memory regions to this [`Vm`]. /// /// Note: regions and state.regions need to be in the same order. From c324adc77d3dbb506a9bb1462f17ac64b344be06 Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Wed, 24 Sep 2025 15:51:01 +0100 Subject: [PATCH 02/14] feat(virtio-mem): wire PATCH support Wire up PATCH requests with the virtio-mem device. All the validation is performed in the device, but the actual operation is not yet implemented. 
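For illustration, once this change is applied the new endpoint can be
exercised as below. This is only a usage sketch: the API socket path is
an example and not part of this change.

    curl --unix-socket /tmp/firecracker.sock \
        -X PATCH "http://localhost/hotplug/memory" \
        -H "Content-Type: application/json" \
        -d '{"requested_size_mib": 2048}'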
Signed-off-by: Riccardo Mancini
---
 .../src/api_server/parsed_request.rs          |  5 ++-
 .../src/api_server/request/hotplug/memory.rs  | 37 +++++++++++++++-
 src/vmm/src/devices/virtio/mem/device.rs      | 43 +++++++++++++++++++
 src/vmm/src/lib.rs                            |  9 ++++
 src/vmm/src/logger/metrics.rs                 |  6 +++
 src/vmm/src/rpc_interface.rs                  | 23 +++++++++-
 src/vmm/src/vmm_config/memory_hotplug.rs      |  8 ++++
 7 files changed, 127 insertions(+), 4 deletions(-)

diff --git a/src/firecracker/src/api_server/parsed_request.rs b/src/firecracker/src/api_server/parsed_request.rs
index 287742ede41..3d21695ce3e 100644
--- a/src/firecracker/src/api_server/parsed_request.rs
+++ b/src/firecracker/src/api_server/parsed_request.rs
@@ -28,7 +28,7 @@ use super::request::snapshot::{parse_patch_vm_state, parse_put_snapshot};
 use super::request::version::parse_get_version;
 use super::request::vsock::parse_put_vsock;
 use crate::api_server::request::hotplug::memory::{
-    parse_get_memory_hotplug, parse_put_memory_hotplug,
+    parse_get_memory_hotplug, parse_patch_memory_hotplug, parse_put_memory_hotplug,
 };
 use crate::api_server::request::serial::parse_put_serial;
@@ -119,6 +119,9 @@ impl TryFrom<&Request> for ParsedRequest {
                 parse_patch_net(body, path_tokens.next())
             }
             (Method::Patch, "vm", Some(body)) => parse_patch_vm_state(body),
+            (Method::Patch, "hotplug", Some(body)) if path_tokens.next() == Some("memory") => {
+                parse_patch_memory_hotplug(body)
+            }
             (Method::Patch, _, None) => method_to_error(Method::Patch),
             (method, unknown_uri, _) => Err(RequestError::InvalidPathMethod(
                 unknown_uri.to_string(),
diff --git a/src/firecracker/src/api_server/request/hotplug/memory.rs b/src/firecracker/src/api_server/request/hotplug/memory.rs
index 4bdeec73a6d..5ec514ca964 100644
--- a/src/firecracker/src/api_server/request/hotplug/memory.rs
+++ b/src/firecracker/src/api_server/request/hotplug/memory.rs
@@ -4,7 +4,7 @@
 use micro_http::Body;
 use vmm::logger::{IncMetric, METRICS};
 use vmm::rpc_interface::VmmAction;
-use vmm::vmm_config::memory_hotplug::MemoryHotplugConfig;
+use vmm::vmm_config::memory_hotplug::{MemoryHotplugConfig, MemoryHotplugSizeUpdate};
 
 use crate::api_server::parsed_request::{ParsedRequest, RequestError};
 
@@ -23,11 +23,23 @@ pub(crate) fn parse_get_memory_hotplug() -> Result<ParsedRequest, RequestError>
     Ok(ParsedRequest::new_sync(VmmAction::GetMemoryHotplugStatus))
 }
 
+pub(crate) fn parse_patch_memory_hotplug(body: &Body) -> Result<ParsedRequest, RequestError> {
+    METRICS.patch_api_requests.hotplug_memory_count.inc();
+    let config =
+        serde_json::from_slice::<MemoryHotplugSizeUpdate>(body.raw()).inspect_err(|_| {
+            METRICS.patch_api_requests.hotplug_memory_fails.inc();
+        })?;
+    Ok(ParsedRequest::new_sync(VmmAction::UpdateMemoryHotplugSize(
+        config,
+    )))
+}
+
 #[cfg(test)]
 mod tests {
     use vmm::devices::virtio::mem::{
         VIRTIO_MEM_DEFAULT_BLOCK_SIZE_MIB, VIRTIO_MEM_DEFAULT_SLOT_SIZE_MIB,
     };
+    use vmm::vmm_config::memory_hotplug::MemoryHotplugSizeUpdate;
 
     use super::*;
     use crate::api_server::parsed_request::tests::vmm_action_from_request;
@@ -80,4 +92,27 @@ mod tests {
             VmmAction::GetMemoryHotplugStatus
         );
     }
+
+    #[test]
+    fn test_parse_patch_memory_hotplug_request() {
+        parse_patch_memory_hotplug(&Body::new("invalid_payload")).unwrap_err();
+
+        // PATCH with invalid fields.
+        let body = r#"{
+            "requested_size_mib": "bar"
+        }"#;
+        parse_patch_memory_hotplug(&Body::new(body)).unwrap_err();
+
+        // PATCH with valid input fields.
+        let body = r#"{
+            "requested_size_mib": 2048
+        }"#;
+        let expected_config = MemoryHotplugSizeUpdate {
+            requested_size_mib: 2048,
+        };
+        assert_eq!(
+            vmm_action_from_request(parse_patch_memory_hotplug(&Body::new(body)).unwrap()),
+            VmmAction::UpdateMemoryHotplugSize(expected_config)
+        );
+    }
 }
diff --git a/src/vmm/src/devices/virtio/mem/device.rs b/src/vmm/src/devices/virtio/mem/device.rs
index c8bcb6cbf53..8fd938ec9fe 100644
--- a/src/vmm/src/devices/virtio/mem/device.rs
+++ b/src/vmm/src/devices/virtio/mem/device.rs
@@ -43,6 +43,10 @@ pub enum VirtioMemError {
     EventFd(#[from] io::Error),
     /// Received error while sending an interrupt: {0}
     InterruptError(#[from] InterruptError),
+    /// Size {0} is invalid: it must be a multiple of the block size and no larger than the total size
+    InvalidSize(u64),
+    /// Device is not active
+    DeviceNotActive,
 }
 
 #[derive(Debug)]
@@ -200,6 +204,45 @@ impl VirtioMem {
     pub(crate) fn activate_event(&self) -> &EventFd {
         &self.activate_event
     }
+
+    /// Updates the requested size of the virtio-mem device.
+    pub fn update_requested_size(
+        &mut self,
+        requested_size_mib: usize,
+    ) -> Result<(), VirtioMemError> {
+        let requested_size = usize_to_u64(mib_to_bytes(requested_size_mib));
+        if !self.is_activated() {
+            return Err(VirtioMemError::DeviceNotActive);
+        }
+
+        if requested_size % self.config.block_size != 0 {
+            return Err(VirtioMemError::InvalidSize(requested_size));
+        }
+        if requested_size > self.config.region_size {
+            return Err(VirtioMemError::InvalidSize(requested_size));
+        }
+
+        // usable_region_size can only be increased
+        if self.config.usable_region_size < requested_size {
+            self.config.usable_region_size =
+                requested_size.next_multiple_of(usize_to_u64(self.slot_size));
+            debug!(
+                "virtio-mem: Updated usable size to {} bytes",
+                self.config.usable_region_size
+            );
+        }
+
+        self.config.requested_size = requested_size;
+        debug!(
+            "virtio-mem: Updated requested size to {} bytes",
+            requested_size
+        );
+        // TODO(virtio-mem): trigger interrupt once we add handling for the requests
+        // self.interrupt_trigger()
+        //     .trigger(VirtioInterruptType::Config)
+        //     .map_err(VirtioMemError::InterruptError)
+        Ok(())
+    }
 }
 
 impl VirtioDevice for VirtioMem {
diff --git a/src/vmm/src/lib.rs b/src/vmm/src/lib.rs
index 0b6fee2e0a0..79e26c706a1 100644
--- a/src/vmm/src/lib.rs
+++ b/src/vmm/src/lib.rs
@@ -609,6 +609,15 @@ impl Vmm {
             .map_err(VmmError::FindDeviceError)
     }
 
+    /// Updates the requested size of the memory hotplug device.
+    pub fn update_memory_hotplug_size(&self, requested_size_mib: usize) -> Result<(), VmmError> {
+        self.device_manager
+            .try_with_virtio_device_with_id(VIRTIO_MEM_DEV_ID, |dev: &mut VirtioMem| {
+                dev.update_requested_size(requested_size_mib)
+            })
+            .map_err(VmmError::FindDeviceError)
+    }
+
     /// Signals Vmm to stop and exit.
     pub fn stop(&mut self, exit_code: FcExitCode) {
         // To avoid cycles, all teardown paths take the following route:
diff --git a/src/vmm/src/logger/metrics.rs b/src/vmm/src/logger/metrics.rs
index c983a5a9f16..060a751562a 100644
--- a/src/vmm/src/logger/metrics.rs
+++ b/src/vmm/src/logger/metrics.rs
@@ -479,6 +479,10 @@ pub struct PatchRequestsMetrics {
     pub mmds_count: SharedIncMetric,
     /// Number of failures in PATCHing an mmds.
     pub mmds_fails: SharedIncMetric,
+    /// Number of PATCHes to /hotplug/memory
+    pub hotplug_memory_count: SharedIncMetric,
+    /// Number of failed PATCHes to /hotplug/memory
+    pub hotplug_memory_fails: SharedIncMetric,
 }
 impl PatchRequestsMetrics {
     /// Const default construction.
@@ -492,6 +496,8 @@ impl PatchRequestsMetrics {
             machine_cfg_fails: SharedIncMetric::new(),
             mmds_count: SharedIncMetric::new(),
             mmds_fails: SharedIncMetric::new(),
+            hotplug_memory_count: SharedIncMetric::new(),
+            hotplug_memory_fails: SharedIncMetric::new(),
         }
     }
 }
diff --git a/src/vmm/src/rpc_interface.rs b/src/vmm/src/rpc_interface.rs
index fe3e9c296e7..d25ddb735a8 100644
--- a/src/vmm/src/rpc_interface.rs
+++ b/src/vmm/src/rpc_interface.rs
@@ -29,7 +29,9 @@ use crate::vmm_config::drive::{BlockDeviceConfig, BlockDeviceUpdateConfig, Drive
 use crate::vmm_config::entropy::{EntropyDeviceConfig, EntropyDeviceError};
 use crate::vmm_config::instance_info::InstanceInfo;
 use crate::vmm_config::machine_config::{MachineConfig, MachineConfigError, MachineConfigUpdate};
-use crate::vmm_config::memory_hotplug::{MemoryHotplugConfig, MemoryHotplugConfigError};
+use crate::vmm_config::memory_hotplug::{
+    MemoryHotplugConfig, MemoryHotplugConfigError, MemoryHotplugSizeUpdate,
+};
 use crate::vmm_config::metrics::{MetricsConfig, MetricsConfigError};
 use crate::vmm_config::mmds::{MmdsConfig, MmdsConfigError};
 use crate::vmm_config::net::{
@@ -113,6 +115,9 @@ pub enum VmmAction {
     /// Set the memory hotplug device using `MemoryHotplugConfig` as input. This action can only be
     /// called before the microVM has booted.
     SetMemoryHotplugDevice(MemoryHotplugConfig),
+    /// Updates the memory hotplug device using `MemoryHotplugSizeUpdate` as input. This action
+    /// can only be called after the microVM has booted.
+    UpdateMemoryHotplugSize(MemoryHotplugSizeUpdate),
     /// Launch the microVM. This action can only be called before the microVM has booted.
     StartMicroVm,
     /// Send CTRL+ALT+DEL to the microVM, using the i8042 keyboard function. If an AT-keyboard
@@ -152,6 +157,8 @@ pub enum VmmActionError {
     EntropyDevice(#[from] EntropyDeviceError),
     /// Memory hotplug config error: {0}
     MemoryHotplugConfig(#[from] MemoryHotplugConfigError),
+    /// Memory hotplug update error: {0}
+    MemoryHotplugUpdate(VmmError),
     /// Internal VMM error: {0}
     InternalVmm(#[from] VmmError),
     /// Load snapshot error: {0}
@@ -469,6 +476,7 @@ impl<'a> PrebootApiController<'a> {
             | UpdateBalloon(_)
             | UpdateBalloonStatistics(_)
             | UpdateBlockDevice(_)
+            | UpdateMemoryHotplugSize(_)
             | UpdateNetworkInterface(_) => Err(VmmActionError::OperationNotSupportedPreBoot),
             #[cfg(target_arch = "x86_64")]
             SendCtrlAltDel => Err(VmmActionError::OperationNotSupportedPreBoot),
@@ -709,7 +717,13 @@ impl RuntimeApiController {
                 .map_err(VmmActionError::BalloonUpdate),
             UpdateBlockDevice(new_cfg) => self.update_block_device(new_cfg),
             UpdateNetworkInterface(netif_update) => self.update_net_rate_limiters(netif_update),
-
+            UpdateMemoryHotplugSize(cfg) => self
+                .vmm
+                .lock()
+                .expect("Poisoned lock")
+                .update_memory_hotplug_size(cfg.requested_size_mib)
+                .map(|_| VmmData::Empty)
+                .map_err(VmmActionError::MemoryHotplugUpdate),
             // Operations not allowed post-boot.
             ConfigureBootSource(_)
             | ConfigureLogger(_)
@@ -1181,6 +1195,11 @@ mod tests {
         )));
         #[cfg(target_arch = "x86_64")]
         check_unsupported(preboot_request(VmmAction::SendCtrlAltDel));
+        check_unsupported(preboot_request(VmmAction::UpdateMemoryHotplugSize(
+            MemoryHotplugSizeUpdate {
+                requested_size_mib: 0,
+            },
+        )));
     }
 
     fn runtime_request(request: VmmAction) -> Result<VmmData, VmmActionError> {
diff --git a/src/vmm/src/vmm_config/memory_hotplug.rs b/src/vmm/src/vmm_config/memory_hotplug.rs
index d09141c1b66..85cf45ee5e8 100644
--- a/src/vmm/src/vmm_config/memory_hotplug.rs
+++ b/src/vmm/src/vmm_config/memory_hotplug.rs
@@ -86,6 +86,14 @@ impl MemoryHotplugConfig {
     }
 }
 
+/// Update to the requested size of the memory hotplug device.
+#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
+#[serde(deny_unknown_fields)]
+pub struct MemoryHotplugSizeUpdate {
+    /// Requested size in MiB to resize the hotpluggable memory to.
+    pub requested_size_mib: usize,
+}
+
 #[cfg(test)]
 mod tests {
     use serde_json;

From b2f0cfc3d991a122c6ff25bc2681694dd430b2fb Mon Sep 17 00:00:00 2001
From: Riccardo Mancini
Date: Wed, 24 Sep 2025 16:00:07 +0100
Subject: [PATCH 03/14] doc(virtio-mem): document PATCH API in swagger and docs

Add an entry for the PATCH API in Swagger and in the docs.

Signed-off-by: Riccardo Mancini
---
 docs/device-api.md                       |  1 +
 src/firecracker/swagger/firecracker.yaml | 29 ++++++++++++++++++++++++
 2 files changed, 30 insertions(+)

diff --git a/docs/device-api.md b/docs/device-api.md
index 01e470f1d64..0b8035651f8 100644
--- a/docs/device-api.md
+++ b/docs/device-api.md
@@ -104,6 +104,7 @@ specification:
 | `MemoryHotplugConfig`     | total_size_mib     | O | O | O | O | O | O | O | **R** |
 |                           | slot_size_mib      | O | O | O | O | O | O | O | **R** |
 |                           | block_size_mib     | O | O | O | O | O | O | O | **R** |
+| `MemoryHotplugSizeUpdate` | requested_size_mib | O | O | O | O | O | O | O | **R** |
 
 \* `Drive`'s `drive_id`, `is_root_device` and `partuuid` can be configured
 by either virtio-block or vhost-user-block devices.
diff --git a/src/firecracker/swagger/firecracker.yaml b/src/firecracker/swagger/firecracker.yaml
index c5011a79fd7..bddc4942c2a 100644
--- a/src/firecracker/swagger/firecracker.yaml
+++ b/src/firecracker/swagger/firecracker.yaml
@@ -549,6 +549,26 @@ paths:
           description: Internal server error
           schema:
             $ref: "#/definitions/Error"
+    patch:
+      summary: Updates the size of the hotpluggable memory region
+      operationId: patchMemoryHotplug
+      description:
+        Updates the size of the hotpluggable memory region. The guest will plug and unplug
+        memory blocks to reach the requested size.
+      parameters:
+        - name: body
+          in: body
+          description: Hotpluggable memory size update
+          required: true
+          schema:
+            $ref: "#/definitions/MemoryHotplugSizeUpdate"
+      responses:
+        204:
+          description: Hotpluggable memory size updated
+        default:
+          description: Internal server error
+          schema:
+            $ref: "#/definitions/Error"
     get:
       summary: Retrieves the status of the hotpluggable memory
       operationId: getMemoryHotplug
@@ -1422,6 +1442,15 @@ definitions:
       description: (Logical) Block size for the hotpluggable memory in MiB.
         This will determine the logical granularity of hot-plug memory for the
         guest. Refer to the device documentation on how to tune this value.
+  MemoryHotplugSizeUpdate:
+    type: object
+    description:
+      An update to the size of the hotpluggable memory region.
+    properties:
+      requested_size_mib:
+        type: integer
+        description: New target size of the hotpluggable memory region, in MiB.
+ MemoryHotplugStatus: type: object description: From 84c7db52cdacd7ca7584925ae346ff223939e31e Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Wed, 24 Sep 2025 18:03:00 +0100 Subject: [PATCH 04/14] test(virtio-mem): add API tests for PATCH Test that the new PATCH API behaves as expected. Also updates expected metrics and fixes memory monitor to account for hotplugging. Signed-off-by: Riccardo Mancini --- tests/host_tools/fcmetrics.py | 2 ++ tests/host_tools/memory.py | 8 ++++++-- tests/integration_tests/functional/test_api.py | 14 ++++++++++++-- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/host_tools/fcmetrics.py b/tests/host_tools/fcmetrics.py index 0dcff5eed00..6b110a70f96 100644 --- a/tests/host_tools/fcmetrics.py +++ b/tests/host_tools/fcmetrics.py @@ -202,6 +202,8 @@ def validate_fc_metrics(metrics): "machine_cfg_fails", "mmds_count", "mmds_fails", + "hotplug_memory_count", + "hotplug_memory_fails", ], "put_api_requests": [ "actions_count", diff --git a/tests/host_tools/memory.py b/tests/host_tools/memory.py index 134147724cd..c09dae0d206 100644 --- a/tests/host_tools/memory.py +++ b/tests/host_tools/memory.py @@ -99,7 +99,9 @@ def is_guest_mem_x86(self, size, guest_mem_bytes): Checks if a region is a guest memory region based on x86_64 physical memory layout """ - return size in ( + # it could be bigger if hotplugging is enabled + # if it's bigger, it's likely not from FC because we don't have big allocations + return size >= guest_mem_bytes or size in ( # memory fits before the first gap guest_mem_bytes, # guest memory spans at least two regions & memory fits before the second gap @@ -121,7 +123,9 @@ def is_guest_mem_arch64(self, size, guest_mem_bytes): Checks if a region is a guest memory region based on ARM64 physical memory layout """ - return size in ( + # it could be bigger if hotplugging is enabled + # if it's bigger, it's likely not from FC because we don't have big allocations + return size >= guest_mem_bytes or size in ( # guest memory fits before the gap guest_mem_bytes, # guest memory fills the space before the gap diff --git a/tests/integration_tests/functional/test_api.py b/tests/integration_tests/functional/test_api.py index 81454559990..ac929941dca 100644 --- a/tests/integration_tests/functional/test_api.py +++ b/tests/integration_tests/functional/test_api.py @@ -981,13 +981,14 @@ def test_api_entropy(uvm_plain): test_microvm.api.entropy.put() -def test_api_memory_hotplug(uvm_plain): +def test_api_memory_hotplug(uvm_plain_6_1): """ Test hotplug related API commands. """ - test_microvm = uvm_plain + test_microvm = uvm_plain_6_1 test_microvm.spawn() test_microvm.basic_config() + test_microvm.add_net_iface() # Adding hotplug memory region should be OK. 
    test_microvm.api.memory_hotplug.put(
@@ -1002,6 +1003,10 @@ def test_api_memory_hotplug(uvm_plain_6_1):
     with pytest.raises(AssertionError):
         test_microvm.api.memory_hotplug.get()
 
+    # Patch API should be rejected before boot
+    with pytest.raises(RuntimeError, match=NOT_SUPPORTED_BEFORE_START):
+        test_microvm.api.memory_hotplug.patch(requested_size_mib=512)
+
     # Start the microvm
     test_microvm.start()
 
@@ -1013,6 +1018,11 @@ def test_api_memory_hotplug(uvm_plain_6_1):
     status = test_microvm.api.memory_hotplug.get().json()
     assert status["total_size_mib"] == 1024
 
+    # Patch API should work after boot
+    test_microvm.api.memory_hotplug.patch(requested_size_mib=512)
+    status = test_microvm.api.memory_hotplug.get().json()
+    assert status["requested_size_mib"] == 512
+
 
 def test_api_balloon(uvm_nano):
     """

From 1425c7c4de4d44f617f214aa3b908ddff7ffb236 Mon Sep 17 00:00:00 2001
From: Riccardo Mancini
Date: Wed, 24 Sep 2025 18:22:27 +0100
Subject: [PATCH 05/14] feat(virtio-mem): add virtio request parsing and dummy response

Parse virtio requests over the queue and always ack them. Following
commits will add the state management inside the device.

Signed-off-by: Riccardo Mancini
---
 src/vmm/src/devices/virtio/mem/device.rs  | 168 ++++++++++++++++++++--
 src/vmm/src/devices/virtio/mem/metrics.rs |  24 ++++
 src/vmm/src/devices/virtio/mem/mod.rs     |   1 +
 src/vmm/src/devices/virtio/mem/request.rs | 142 ++++++++++++++++++
 4 files changed, 326 insertions(+), 9 deletions(-)
 create mode 100644 src/vmm/src/devices/virtio/mem/request.rs

diff --git a/src/vmm/src/devices/virtio/mem/device.rs b/src/vmm/src/devices/virtio/mem/device.rs
index 8fd938ec9fe..df8cb0df60e 100644
--- a/src/vmm/src/devices/virtio/mem/device.rs
+++ b/src/vmm/src/devices/virtio/mem/device.rs
@@ -9,7 +9,7 @@ use std::sync::atomic::AtomicU32;
 use log::info;
 use serde::{Deserialize, Serialize};
 use vm_memory::{
-    Address, GuestAddress, GuestMemory, GuestMemoryError, GuestMemoryRegion, GuestUsize,
+    Address, Bytes, GuestAddress, GuestMemory, GuestMemoryError, GuestMemoryRegion, GuestUsize,
 };
 use vmm_sys_util::eventfd::EventFd;
 
@@ -20,12 +20,15 @@ use crate::devices::virtio::device::{ActiveState, DeviceState, VirtioDevice};
 use crate::devices::virtio::generated::virtio_config::VIRTIO_F_VERSION_1;
 use crate::devices::virtio::generated::virtio_ids::VIRTIO_ID_MEM;
 use crate::devices::virtio::generated::virtio_mem::{
-    VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE, virtio_mem_config,
+    self, VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE, virtio_mem_config,
 };
 use crate::devices::virtio::iov_deque::IovDequeError;
 use crate::devices::virtio::mem::metrics::METRICS;
+use crate::devices::virtio::mem::request::{BlockRangeState, Request, RequestedRange, Response};
 use crate::devices::virtio::mem::{VIRTIO_MEM_DEV_ID, VIRTIO_MEM_GUEST_ADDRESS};
-use crate::devices::virtio::queue::{FIRECRACKER_MAX_QUEUE_SIZE, InvalidAvailIdx, Queue};
+use crate::devices::virtio::queue::{
+    DescriptorChain, FIRECRACKER_MAX_QUEUE_SIZE, InvalidAvailIdx, Queue, QueueError,
+};
 use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType};
 use crate::logger::{IncMetric, debug, error};
 use crate::utils::{bytes_to_mib, mib_to_bytes, u64_to_usize, usize_to_u64};
@@ -47,6 +50,24 @@ pub enum VirtioMemError {
     InvalidSize(u64),
     /// Device is not active
     DeviceNotActive,
+    /// Descriptor is write-only
+    UnexpectedWriteOnlyDescriptor,
+    /// Error writing virtio descriptor
+    DescriptorWriteFailed,
+    /// Error reading virtio descriptor
+    DescriptorReadFailed,
+    /// Unknown request type: {0:?}
+    UnknownRequestType(u32),
+    /// Descriptor chain is too short
+    DescriptorChainTooShort,
+    /// Descriptor is too small
+    DescriptorLengthTooSmall,
+    /// Descriptor is read-only
+    UnexpectedReadOnlyDescriptor,
+    /// Error popping from virtio queue: {0}
+    InvalidAvailIdx(#[from] InvalidAvailIdx),
+    /// Error adding used queue: {0}
+    QueueError(#[from] QueueError),
 }
 
 #[derive(Debug)]
@@ -170,8 +191,139 @@ impl VirtioMem {
             .map_err(VirtioMemError::InterruptError)
     }
 
+    fn guest_memory(&self) -> &GuestMemoryMmap {
+        &self.device_state.active_state().unwrap().mem
+    }
+
+    fn parse_request(
+        &self,
+        avail_desc: &DescriptorChain,
+    ) -> Result<(Request, GuestAddress, u16), VirtioMemError> {
+        // The head contains the request type which MUST be readable.
+        if avail_desc.is_write_only() {
+            return Err(VirtioMemError::UnexpectedWriteOnlyDescriptor);
+        }
+
+        if (avail_desc.len as usize) < size_of::<virtio_mem::virtio_mem_req>() {
+            return Err(VirtioMemError::DescriptorLengthTooSmall);
+        }
+
+        let request: virtio_mem::virtio_mem_req = self
+            .guest_memory()
+            .read_obj(avail_desc.addr)
+            .map_err(|_| VirtioMemError::DescriptorReadFailed)?;
+
+        let resp_desc = avail_desc
+            .next_descriptor()
+            .ok_or(VirtioMemError::DescriptorChainTooShort)?;
+
+        // The response MUST always be writable.
+        if !resp_desc.is_write_only() {
+            return Err(VirtioMemError::UnexpectedReadOnlyDescriptor);
+        }
+
+        if (resp_desc.len as usize) < std::mem::size_of::<virtio_mem::virtio_mem_resp>() {
+            return Err(VirtioMemError::DescriptorLengthTooSmall);
+        }
+
+        Ok((request.into(), resp_desc.addr, avail_desc.index))
+    }
+
+    fn write_response(
+        &mut self,
+        resp: Response,
+        resp_addr: GuestAddress,
+        used_idx: u16,
+    ) -> Result<(), VirtioMemError> {
+        debug!("virtio-mem: Response: {:?}", resp);
+        self.guest_memory()
+            .write_obj(virtio_mem::virtio_mem_resp::from(resp), resp_addr)
+            .map_err(|_| VirtioMemError::DescriptorWriteFailed)
+            .map(|_| size_of::<virtio_mem::virtio_mem_resp>())?;
+        self.queues[MEM_QUEUE]
+            .add_used(
+                used_idx,
+                u32::try_from(std::mem::size_of::<virtio_mem::virtio_mem_resp>()).unwrap(),
+            )
+            .map_err(VirtioMemError::QueueError)
+    }
+
+    fn handle_plug_request(
+        &mut self,
+        range: &RequestedRange,
+        resp_addr: GuestAddress,
+        used_idx: u16,
+    ) -> Result<(), VirtioMemError> {
+        METRICS.plug_count.inc();
+        let _metric = METRICS.plug_agg.record_latency_metrics();
+
+        // TODO: implement PLUG request
+        let response = Response::ack();
+        self.write_response(response, resp_addr, used_idx)
+    }
+
+    fn handle_unplug_request(
+        &mut self,
+        range: &RequestedRange,
+        resp_addr: GuestAddress,
+        used_idx: u16,
+    ) -> Result<(), VirtioMemError> {
+        METRICS.unplug_count.inc();
+        let _metric = METRICS.unplug_agg.record_latency_metrics();
+
+        // TODO: implement UNPLUG request
+        let response = Response::ack();
+        self.write_response(response, resp_addr, used_idx)
+    }
+
+    fn handle_unplug_all_request(
+        &mut self,
+        resp_addr: GuestAddress,
+        used_idx: u16,
+    ) -> Result<(), VirtioMemError> {
+        METRICS.unplug_all_count.inc();
+        let _metric = METRICS.unplug_all_agg.record_latency_metrics();
+
+        // TODO: implement UNPLUG ALL request
+        let response = Response::ack();
+        self.write_response(response, resp_addr, used_idx)
+    }
+
+    fn handle_state_request(
+        &mut self,
+        range: &RequestedRange,
+        resp_addr: GuestAddress,
+        used_idx: u16,
+    ) -> Result<(), VirtioMemError> {
+        METRICS.state_count.inc();
+        let _metric = METRICS.state_agg.record_latency_metrics();
+
+        // TODO: implement STATE request
+        let response = Response::ack_with_state(BlockRangeState::Mixed);
+        self.write_response(response, resp_addr, used_idx)
+    }
+
     fn process_mem_queue(&mut self) -> Result<(),
VirtioMemError> { - info!("TODO: Received mem queue event, but it's not implemented."); + while let Some(desc) = self.queues[MEM_QUEUE].pop()? { + let index = desc.index; + + let (req, resp_addr, used_idx) = self.parse_request(&desc)?; + debug!("virtio-mem: Request: {:?}", req); + // Handle request and write response + match req { + Request::State(ref range) => self.handle_state_request(range, resp_addr, used_idx), + Request::Plug(ref range) => self.handle_plug_request(range, resp_addr, used_idx), + Request::Unplug(ref range) => { + self.handle_unplug_request(range, resp_addr, used_idx) + } + Request::UnplugAll => self.handle_unplug_all_request(resp_addr, used_idx), + Request::Unsupported(t) => Err(VirtioMemError::UnknownRequestType(t)), + }?; + } + + self.queues[MEM_QUEUE].advance_used_ring_idx(); + self.signal_used_queue()?; + Ok(()) } @@ -237,11 +389,9 @@ impl VirtioMem { "virtio-mem: Updated requested size to {} bytes", requested_size ); - // TODO(virtio-mem): trigger interrupt once we add handling for the requests - // self.interrupt_trigger() - // .trigger(VirtioInterruptType::Config) - // .map_err(VirtioMemError::InterruptError) - Ok(()) + self.interrupt_trigger() + .trigger(VirtioInterruptType::Config) + .map_err(VirtioMemError::InterruptError) } } diff --git a/src/vmm/src/devices/virtio/mem/metrics.rs b/src/vmm/src/devices/virtio/mem/metrics.rs index 443e9a8b8f1..e9e97707782 100644 --- a/src/vmm/src/devices/virtio/mem/metrics.rs +++ b/src/vmm/src/devices/virtio/mem/metrics.rs @@ -45,6 +45,22 @@ pub(super) struct VirtioMemDeviceMetrics { pub queue_event_fails: SharedIncMetric, /// Number of queue events handled pub queue_event_count: SharedIncMetric, + /// Latency of Plug operations + pub plug_agg: LatencyAggregateMetrics, + /// Number of Plug operations + pub plug_count: SharedIncMetric, + /// Latency of Unplug operations + pub unplug_agg: LatencyAggregateMetrics, + /// Number of Unplug operations + pub unplug_count: SharedIncMetric, + /// Latency of UnplugAll operations + pub unplug_all_agg: LatencyAggregateMetrics, + /// Number of UnplugAll operations + pub unplug_all_count: SharedIncMetric, + /// Latency of State operations + pub state_agg: LatencyAggregateMetrics, + /// Number of State operations + pub state_count: SharedIncMetric, } impl VirtioMemDeviceMetrics { @@ -54,6 +70,14 @@ impl VirtioMemDeviceMetrics { activate_fails: SharedIncMetric::new(), queue_event_fails: SharedIncMetric::new(), queue_event_count: SharedIncMetric::new(), + plug_agg: LatencyAggregateMetrics::new(), + plug_count: SharedIncMetric::new(), + unplug_agg: LatencyAggregateMetrics::new(), + unplug_count: SharedIncMetric::new(), + unplug_all_agg: LatencyAggregateMetrics::new(), + unplug_all_count: SharedIncMetric::new(), + state_agg: LatencyAggregateMetrics::new(), + state_count: SharedIncMetric::new(), } } } diff --git a/src/vmm/src/devices/virtio/mem/mod.rs b/src/vmm/src/devices/virtio/mem/mod.rs index 1c9e98f98a6..5c76afc4c24 100644 --- a/src/vmm/src/devices/virtio/mem/mod.rs +++ b/src/vmm/src/devices/virtio/mem/mod.rs @@ -5,6 +5,7 @@ mod device; mod event_handler; pub mod metrics; pub mod persist; +mod request; use vm_memory::GuestAddress; diff --git a/src/vmm/src/devices/virtio/mem/request.rs b/src/vmm/src/devices/virtio/mem/request.rs new file mode 100644 index 00000000000..e3c620f10af --- /dev/null +++ b/src/vmm/src/devices/virtio/mem/request.rs @@ -0,0 +1,142 @@ +// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0
+
+use vm_memory::{ByteValued, GuestAddress};
+
+use crate::devices::virtio::generated::virtio_mem;
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct RequestedRange {
+    pub(crate) addr: GuestAddress,
+    pub(crate) nb_blocks: usize,
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub(crate) enum Request {
+    Plug(RequestedRange),
+    Unplug(RequestedRange),
+    UnplugAll,
+    State(RequestedRange),
+    Unsupported(u32),
+}
+
+// SAFETY: virtio_mem_req is a plain-old-data struct generated by bindgen, so
+// any byte pattern is a valid value for it.
+unsafe impl ByteValued for virtio_mem::virtio_mem_req {}
+
+impl From<virtio_mem::virtio_mem_req> for Request {
+    fn from(req: virtio_mem::virtio_mem_req) -> Self {
+        match req.type_.into() {
+            // SAFETY: union type is checked in the match
+            virtio_mem::VIRTIO_MEM_REQ_PLUG => unsafe {
+                Request::Plug(RequestedRange {
+                    addr: GuestAddress(req.u.plug.addr),
+                    nb_blocks: req.u.plug.nb_blocks.into(),
+                })
+            },
+            // SAFETY: union type is checked in the match
+            virtio_mem::VIRTIO_MEM_REQ_UNPLUG => unsafe {
+                Request::Unplug(RequestedRange {
+                    addr: GuestAddress(req.u.unplug.addr),
+                    nb_blocks: req.u.unplug.nb_blocks.into(),
+                })
+            },
+            virtio_mem::VIRTIO_MEM_REQ_UNPLUG_ALL => Request::UnplugAll,
+            // SAFETY: union type is checked in the match
+            virtio_mem::VIRTIO_MEM_REQ_STATE => unsafe {
+                Request::State(RequestedRange {
+                    addr: GuestAddress(req.u.state.addr),
+                    nb_blocks: req.u.state.nb_blocks.into(),
+                })
+            },
+            t => Request::Unsupported(t),
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub enum ResponseType {
+    Ack,
+    Nack,
+    Busy,
+    Error,
+}
+
+impl From<ResponseType> for u16 {
+    fn from(code: ResponseType) -> Self {
+        match code {
+            ResponseType::Ack => virtio_mem::VIRTIO_MEM_RESP_ACK,
+            ResponseType::Nack => virtio_mem::VIRTIO_MEM_RESP_NACK,
+            ResponseType::Busy => virtio_mem::VIRTIO_MEM_RESP_BUSY,
+            ResponseType::Error => virtio_mem::VIRTIO_MEM_RESP_ERROR,
+        }
+        .try_into()
+        .unwrap()
+    }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub enum BlockRangeState {
+    Plugged,
+    Unplugged,
+    Mixed,
+}
+
+impl From<BlockRangeState> for virtio_mem::virtio_mem_resp_state {
+    fn from(code: BlockRangeState) -> Self {
+        virtio_mem::virtio_mem_resp_state {
+            state: match code {
+                BlockRangeState::Plugged => virtio_mem::VIRTIO_MEM_STATE_PLUGGED,
+                BlockRangeState::Unplugged => virtio_mem::VIRTIO_MEM_STATE_UNPLUGGED,
+                BlockRangeState::Mixed => virtio_mem::VIRTIO_MEM_STATE_MIXED,
+            }
+            .try_into()
+            .unwrap(),
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Response {
+    pub resp_type: ResponseType,
+    // Only for State requests
+    pub state: Option<BlockRangeState>,
+}
+
+impl Response {
+    pub(crate) fn error() -> Self {
+        Response {
+            resp_type: ResponseType::Error,
+            state: None,
+        }
+    }
+
+    pub(crate) fn ack() -> Self {
+        Response {
+            resp_type: ResponseType::Ack,
+            state: None,
+        }
+    }
+
+    pub(crate) fn ack_with_state(state: BlockRangeState) -> Self {
+        Response {
+            resp_type: ResponseType::Ack,
+            state: Some(state),
+        }
+    }
+}
+
+// SAFETY: virtio_mem_resp is a plain-old-data struct, so any byte pattern is
+// a valid value for it.
+unsafe impl ByteValued for virtio_mem::virtio_mem_resp {}
+
+impl From<Response> for virtio_mem::virtio_mem_resp {
+    fn from(resp: Response) -> Self {
+        let mut out = virtio_mem::virtio_mem_resp {
+            type_: resp.resp_type.into(),
+            ..Default::default()
+        };
+        if let Some(state) = resp.state {
+            out.u.state = state.into();
+        }
+        out
+    }
+}

From a1ae2cae6627f9a7fe6849662f812f8c3ebfb1d3 Mon Sep 17 00:00:00 2001
From: Riccardo Mancini
Date: Wed, 24 Sep 2025 18:31:09 +0100
Subject: [PATCH 06/14] feat(virtio-mem): implement virtio requests

This commit adds block state management and implements the virtio
requests for the virtio-mem device.

Block state is tracked using a BitVec, each bit representing a single
block. Plug/Unplug requests are validated before being executed to
verify that the range is valid (aligned and within bounds), and that
all blocks in the range are unplugged/plugged, as per the virtio spec.
UnplugAll is the only request where usable_region_size can be lowered.

This commit does not yet include the dynamic KVM slot management, which
will be added later.

Signed-off-by: Riccardo Mancini
---
 Cargo.lock                                |  41 +++++
 src/vmm/Cargo.toml                        |   1 +
 src/vmm/src/devices/virtio/mem/device.rs  | 201 ++++++++++++++++++++--
 src/vmm/src/devices/virtio/mem/metrics.rs |  21 +++
 src/vmm/src/devices/virtio/mem/persist.rs |  19 +-
 5 files changed, 261 insertions(+), 22 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 0ea518f3a49..5a66c51b66e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -248,6 +248,19 @@ version = "2.9.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2261d10cca569e4643e526d8dc2e62e433cc8aba21ab764233731f8d369bf394"
 
+[[package]]
+name = "bitvec"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
+dependencies = [
+ "funty",
+ "radium",
+ "serde",
+ "tap",
+ "wyz",
+]
+
 [[package]]
 name = "bumpalo"
 version = "3.19.0"
@@ -648,6 +661,12 @@ version = "1.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
 
+[[package]]
+name = "funty"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
+
 [[package]]
 name = "gdbstub"
 version = "0.7.7"
@@ -1140,6 +1159,12 @@ version = "5.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
 
+[[package]]
+name = "radium"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
+
 [[package]]
 name = "rand"
 version = "0.9.2"
@@ -1414,6 +1439,12 @@ dependencies = [
  "unicode-ident",
 ]
 
+[[package]]
+name = "tap"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
+
 [[package]]
 name = "thiserror"
 version = "1.0.69"
@@ -1682,6 +1713,7 @@ dependencies = [
  "base64",
  "bincode",
  "bitflags 2.9.4",
+ "bitvec",
  "byteorder",
  "crc64",
  "criterion",
@@ -2008,6 +2040,15 @@ version = "0.46.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59"
 
+[[package]]
+name = "wyz"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
+dependencies = [
+ "tap",
+]
+
 [[package]]
 name = "zerocopy"
 version = "0.8.27"
diff --git a/src/vmm/Cargo.toml b/src/vmm/Cargo.toml
index d2601f4a305..aa6f9a2fc47 100644
--- a/src/vmm/Cargo.toml
+++ b/src/vmm/Cargo.toml
@@ -23,6 +23,7 @@ aws-lc-rs = { version = "1.14.0", features = ["bindgen"] }
 base64 = "0.22.1"
 bincode = { version = "2.0.1", features = ["serde"] }
 bitflags = "2.9.4"
+bitvec = { version = "1.0.1", features = ["atomic", "serde"] }
 byteorder = "1.5.0"
 crc64 = "2.0.0"
 derive_more = { version = "2.0.1", default-features = false, features =
[
diff --git a/src/vmm/src/devices/virtio/mem/device.rs b/src/vmm/src/devices/virtio/mem/device.rs
index df8cb0df60e..7cb7dd97bf7 100644
--- a/src/vmm/src/devices/virtio/mem/device.rs
+++ b/src/vmm/src/devices/virtio/mem/device.rs
@@ -2,10 +2,11 @@
 // SPDX-License-Identifier: Apache-2.0
 
 use std::io;
-use std::ops::Deref;
+use std::ops::{Deref, Range};
 use std::sync::Arc;
 use std::sync::atomic::AtomicU32;
 
+use bitvec::vec::BitVec;
 use log::info;
 use serde::{Deserialize, Serialize};
 use vm_memory::{
@@ -14,7 +15,6 @@ use vm_memory::{
 use vmm_sys_util::eventfd::EventFd;
 
 use super::{MEM_NUM_QUEUES, MEM_QUEUE};
-use crate::devices::DeviceError;
 use crate::devices::virtio::ActivateError;
 use crate::devices::virtio::device::{ActiveState, DeviceState, VirtioDevice};
 use crate::devices::virtio::generated::virtio_config::VIRTIO_F_VERSION_1;
@@ -33,7 +33,7 @@ use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType};
 use crate::logger::{IncMetric, debug, error};
 use crate::utils::{bytes_to_mib, mib_to_bytes, u64_to_usize, usize_to_u64};
 use crate::vstate::interrupts::InterruptError;
-use crate::vstate::memory::{ByteValued, GuestMemoryMmap, GuestRegionMmap};
+use crate::vstate::memory::{ByteValued, GuestMemoryExtension, GuestMemoryMmap, GuestRegionMmap};
 use crate::vstate::vm::VmError;
 use crate::{Vm, impl_device_type};
 
@@ -68,6 +68,14 @@ pub enum VirtioMemError {
     InvalidAvailIdx(#[from] InvalidAvailIdx),
     /// Error adding used queue: {0}
     QueueError(#[from] QueueError),
+    /// Invalid requested range: {0:?}.
+    InvalidRange(RequestedRange),
+    /// The requested range cannot be plugged because it's {0:?}.
+    PlugRequestBlockStateInvalid(BlockRangeState),
+    /// Plug request rejected as plugged_size would be greater than requested_size
+    PlugRequestIsTooBig,
+    /// The requested range cannot be unplugged because it's {0:?}.
+    UnplugRequestBlockStateInvalid(BlockRangeState),
 }
 
 #[derive(Debug)]
@@ -85,6 +93,8 @@ pub struct VirtioMem {
     // Device specific fields
     pub(crate) config: virtio_mem_config,
     pub(crate) slot_size: usize,
+    // Bitmap to track which blocks are plugged
+    pub(crate) plugged_blocks: BitVec,
     vm: Arc<Vm>,
 }
 
@@ -118,8 +128,15 @@ impl VirtioMem {
             block_size: mib_to_bytes(block_size_mib) as u64,
             ..Default::default()
         };
+        let plugged_blocks = BitVec::repeat(false, total_size_mib / block_size_mib);
 
-        Self::from_state(vm, queues, config, mib_to_bytes(slot_size_mib))
+        Self::from_state(
+            vm,
+            queues,
+            config,
+            mib_to_bytes(slot_size_mib),
+            plugged_blocks,
+        )
     }
 
     pub fn from_state(
@@ -127,6 +144,7 @@ impl VirtioMem {
         queues: Vec<Queue>,
         config: virtio_mem_config,
         slot_size: usize,
+        plugged_blocks: BitVec,
     ) -> Result<Self, VirtioMemError> {
         let activate_event = EventFd::new(libc::EFD_NONBLOCK)?;
         let queue_events = (0..MEM_NUM_QUEUES)
@@ -143,6 +161,7 @@ impl VirtioMem {
             config,
             vm,
             slot_size,
+            plugged_blocks,
         })
     }
 
@@ -195,6 +214,20 @@ impl VirtioMem {
         &self.device_state.active_state().unwrap().mem
     }
 
+    fn nb_blocks_to_len(&self, nb_blocks: usize) -> usize {
+        nb_blocks * u64_to_usize(self.config.block_size)
+    }
+
+    fn is_range_plugged(&self, range: &RequestedRange) -> BlockRangeState {
+        let plugged_count = self.plugged_blocks[self.unchecked_block_range(range)].count_ones();
+
+        match plugged_count {
+            nb_blocks if nb_blocks == range.nb_blocks => BlockRangeState::Plugged,
+            0 => BlockRangeState::Unplugged,
+            _ => BlockRangeState::Mixed,
+        }
+    }
+
     fn parse_request(
         &self,
         avail_desc: &DescriptorChain,
@@ -248,6 +281,57 @@ impl VirtioMem {
             .map_err(VirtioMemError::QueueError)
     }
 
+    fn validate_range(&self, range: &RequestedRange) -> Result<(), VirtioMemError> {
+        // Ensure the range is aligned
+        if !range
+            .addr
+            .raw_value()
+            .is_multiple_of(self.config.block_size)
+        {
+            return Err(VirtioMemError::InvalidRange(*range));
+        }
+
+        if range.nb_blocks == 0 {
+            return Err(VirtioMemError::InvalidRange(*range));
+        }
+
+        // Ensure the start addr is within the usable region
+        let start_off = range
+            .addr
+            .checked_offset_from(GuestAddress(self.config.addr))
+            .filter(|&off| off < self.config.usable_region_size)
+            .ok_or(VirtioMemError::InvalidRange(*range))?;
+
+        // Ensure the end offset (exclusive) is within the usable region
+        let end_off = start_off
+            .checked_add(usize_to_u64(self.nb_blocks_to_len(range.nb_blocks)))
+            .filter(|&end_off| end_off <= self.config.usable_region_size)
+            .ok_or(VirtioMemError::InvalidRange(*range))?;
+
+        Ok(())
+    }
+
+    fn unchecked_block_range(&self, range: &RequestedRange) -> Range<usize> {
+        let start_block = u64_to_usize((range.addr.0 - self.config.addr) / self.config.block_size);
+
+        start_block..(start_block + range.nb_blocks)
+    }
+
+    fn do_plug_request(&mut self, range: &RequestedRange) -> Result<(), VirtioMemError> {
+        self.validate_range(range)?;
+
+        if self.config.plugged_size + usize_to_u64(self.nb_blocks_to_len(range.nb_blocks))
+            > self.config.requested_size
+        {
+            return Err(VirtioMemError::PlugRequestIsTooBig);
+        }
+
+        match self.is_range_plugged(range) {
+            BlockRangeState::Unplugged => self.plug_range(range, true),
+            state => Err(VirtioMemError::PlugRequestBlockStateInvalid(state)),
+        }
+    }
+
     fn handle_plug_request(
         &mut self,
         range: &RequestedRange,
@@ -257,11 +341,31 @@ impl VirtioMem {
         METRICS.plug_count.inc();
         let _metric = METRICS.plug_agg.record_latency_metrics();
 
-        // TODO: implement PLUG request
-        let response = Response::ack();
+        let response = self.do_plug_request(range).map_or_else(
+            |err|
{ + METRICS.plug_fails.inc(); + error!("virtio-mem: Failed to plug range: {}", err); + Response::error() + }, + |_| { + METRICS + .plug_bytes + .add(usize_to_u64(self.nb_blocks_to_len(range.nb_blocks))); + Response::ack() + }, + ); self.write_response(response, resp_addr, used_idx) } + fn do_unplug_request(&mut self, range: &RequestedRange) -> Result<(), VirtioMemError> { + self.validate_range(range)?; + + match self.is_range_plugged(range) { + BlockRangeState::Plugged => self.plug_range(range, false), + state => Err(VirtioMemError::UnplugRequestBlockStateInvalid(state)), + } + } + fn handle_unplug_request( &mut self, range: &RequestedRange, @@ -270,9 +374,19 @@ impl VirtioMem { ) -> Result<(), VirtioMemError> { METRICS.unplug_count.inc(); let _metric = METRICS.unplug_agg.record_latency_metrics(); - - // TODO: implement UNPLUG request - let response = Response::ack(); + let response = self.do_unplug_request(range).map_or_else( + |err| { + METRICS.unplug_fails.inc(); + error!("virtio-mem: Failed to unplug range: {}", err); + Response::error() + }, + |_| { + METRICS + .unplug_bytes + .add(usize_to_u64(self.nb_blocks_to_len(range.nb_blocks))); + Response::ack() + }, + ); self.write_response(response, resp_addr, used_idx) } @@ -283,9 +397,21 @@ impl VirtioMem { ) -> Result<(), VirtioMemError> { METRICS.unplug_all_count.inc(); let _metric = METRICS.unplug_all_agg.record_latency_metrics(); - - // TODO: implement UNPLUG ALL request - let response = Response::ack(); + let range = RequestedRange { + addr: GuestAddress(self.config.addr), + nb_blocks: self.plugged_blocks.len(), + }; + let response = self.plug_range(&range, false).map_or_else( + |err| { + METRICS.unplug_all_fails.inc(); + error!("virtio-mem: Failed to unplug all: {}", err); + Response::error() + }, + |_| { + self.config.usable_region_size = 0; + Response::ack() + }, + ); self.write_response(response, resp_addr, used_idx) } @@ -297,9 +423,14 @@ impl VirtioMem { ) -> Result<(), VirtioMemError> { METRICS.state_count.inc(); let _metric = METRICS.state_agg.record_latency_metrics(); - - // TODO: implement STATE request - let response = Response::ack_with_state(BlockRangeState::Mixed); + let response = self.validate_range(range).map_or_else( + |err| { + METRICS.state_fails.inc(); + error!("virtio-mem: Failed to retrieve state of range: {}", err); + Response::error() + }, + |_| Response::ack_with_state(self.is_range_plugged(range)), + ); self.write_response(response, resp_addr, used_idx) } @@ -357,6 +488,33 @@ impl VirtioMem { &self.activate_event } + fn plug_range(&mut self, range: &RequestedRange, plug: bool) -> Result<(), VirtioMemError> { + // Update internal state + let block_range = self.unchecked_block_range(range); + let plugged_blocks_slice = &mut self.plugged_blocks[block_range]; + let plugged_before = plugged_blocks_slice.count_ones(); + plugged_blocks_slice.fill(plug); + let plugged_after = plugged_blocks_slice.count_ones(); + self.config.plugged_size -= usize_to_u64(self.nb_blocks_to_len(plugged_before)); + self.config.plugged_size += usize_to_u64(self.nb_blocks_to_len(plugged_after)); + + // If unplugging, discard the range + if !plug { + self.guest_memory() + .discard_range(range.addr, self.nb_blocks_to_len(range.nb_blocks)) + .inspect_err(|err| { + // Failure to discard is not fatal and is not reported to the driver. It only + // gets logged. 
+ METRICS.unplug_discard_fails.inc(); + error!("virtio-mem: Failed to discard memory range: {}", err); + }); + } + + // TODO: update KVM slots to plug/unplug them + + Ok(()) + } + /// Updates the requested size of the virtio-mem device. pub fn update_requested_size( &mut self, @@ -544,7 +702,18 @@ mod tests { usable_region_size, ..Default::default() }; - let mem = VirtioMem::from_state(vm, queues, config, mib_to_bytes(slot_size_mib)).unwrap(); + let plugged_blocks = BitVec::repeat( + false, + mib_to_bytes(region_size_mib) / mib_to_bytes(block_size_mib), + ); + let mem = VirtioMem::from_state( + vm, + queues, + config, + mib_to_bytes(slot_size_mib), + plugged_blocks, + ) + .unwrap(); assert_eq!(mem.total_size_mib(), region_size_mib); assert_eq!(mem.block_size_mib(), block_size_mib); assert_eq!(mem.slot_size_mib(), slot_size_mib); diff --git a/src/vmm/src/devices/virtio/mem/metrics.rs b/src/vmm/src/devices/virtio/mem/metrics.rs index e9e97707782..f2a6f58b92d 100644 --- a/src/vmm/src/devices/virtio/mem/metrics.rs +++ b/src/vmm/src/devices/virtio/mem/metrics.rs @@ -49,18 +49,32 @@ pub(super) struct VirtioMemDeviceMetrics { pub plug_agg: LatencyAggregateMetrics, /// Number of Plug operations pub plug_count: SharedIncMetric, + /// Number of plugged bytes + pub plug_bytes: SharedIncMetric, + /// Number of Plug operations failed + pub plug_fails: SharedIncMetric, /// Latency of Unplug operations pub unplug_agg: LatencyAggregateMetrics, /// Number of Unplug operations pub unplug_count: SharedIncMetric, + /// Number of unplugged bytes + pub unplug_bytes: SharedIncMetric, + /// Number of Unplug operations failed + pub unplug_fails: SharedIncMetric, + /// Number of discards failed for an Unplug or UnplugAll operation + pub unplug_discard_fails: SharedIncMetric, /// Latency of UnplugAll operations pub unplug_all_agg: LatencyAggregateMetrics, /// Number of UnplugAll operations pub unplug_all_count: SharedIncMetric, + /// Number of UnplugAll operations failed + pub unplug_all_fails: SharedIncMetric, /// Latency of State operations pub state_agg: LatencyAggregateMetrics, /// Number of State operations pub state_count: SharedIncMetric, + /// Number of State operations failed + pub state_fails: SharedIncMetric, } impl VirtioMemDeviceMetrics { @@ -72,12 +86,19 @@ impl VirtioMemDeviceMetrics { queue_event_count: SharedIncMetric::new(), plug_agg: LatencyAggregateMetrics::new(), plug_count: SharedIncMetric::new(), + plug_bytes: SharedIncMetric::new(), + plug_fails: SharedIncMetric::new(), unplug_agg: LatencyAggregateMetrics::new(), unplug_count: SharedIncMetric::new(), + unplug_bytes: SharedIncMetric::new(), + unplug_fails: SharedIncMetric::new(), + unplug_discard_fails: SharedIncMetric::new(), unplug_all_agg: LatencyAggregateMetrics::new(), unplug_all_count: SharedIncMetric::new(), + unplug_all_fails: SharedIncMetric::new(), state_agg: LatencyAggregateMetrics::new(), state_count: SharedIncMetric::new(), + state_fails: SharedIncMetric::new(), } } } diff --git a/src/vmm/src/devices/virtio/mem/persist.rs b/src/vmm/src/devices/virtio/mem/persist.rs index e48246de500..09f41680e32 100644 --- a/src/vmm/src/devices/virtio/mem/persist.rs +++ b/src/vmm/src/devices/virtio/mem/persist.rs @@ -5,6 +5,7 @@ use std::sync::Arc; +use bitvec::vec::BitVec; use serde::{Deserialize, Serialize}; use vm_memory::Address; @@ -17,6 +18,7 @@ use crate::devices::virtio::mem::{ use crate::devices::virtio::persist::{PersistError as VirtioStateError, VirtioDeviceState}; use crate::devices::virtio::queue::FIRECRACKER_MAX_QUEUE_SIZE; use 
crate::snapshot::Persist; +use crate::utils::usize_to_u64; use crate::vstate::memory::{GuestMemoryMmap, GuestRegionMmap}; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -25,9 +27,9 @@ pub struct VirtioMemState { region_size: u64, block_size: u64, usable_region_size: u64, - plugged_size: u64, requested_size: u64, slot_size: usize, + plugged_blocks: BitVec, } #[derive(Debug)] @@ -60,7 +62,7 @@ impl Persist<'_> for VirtioMem { region_size: self.config.region_size, block_size: self.config.block_size, usable_region_size: self.config.usable_region_size, - plugged_size: self.config.plugged_size, + plugged_blocks: self.plugged_blocks.clone(), requested_size: self.config.requested_size, slot_size: self.slot_size, } @@ -82,13 +84,18 @@ impl Persist<'_> for VirtioMem { region_size: state.region_size, block_size: state.block_size, usable_region_size: state.usable_region_size, - plugged_size: state.plugged_size, + plugged_size: usize_to_u64(state.plugged_blocks.count_ones()) * state.block_size, requested_size: state.requested_size, ..Default::default() }; - let mut virtio_mem = - VirtioMem::from_state(constructor_args.vm, queues, config, state.slot_size)?; + let mut virtio_mem = VirtioMem::from_state( + constructor_args.vm, + queues, + config, + state.slot_size, + state.plugged_blocks.clone(), + )?; virtio_mem.set_avail_features(state.virtio_state.avail_features); virtio_mem.set_acked_features(state.virtio_state.acked_features); @@ -111,7 +118,7 @@ mod tests { assert_eq!(state.region_size, dev.config.region_size); assert_eq!(state.block_size, dev.config.block_size); assert_eq!(state.usable_region_size, dev.config.usable_region_size); - assert_eq!(state.plugged_size, dev.config.plugged_size); + assert_eq!(state.plugged_blocks, dev.plugged_blocks); assert_eq!(state.requested_size, dev.config.requested_size); assert_eq!(state.slot_size, dev.slot_size); } From adcdddbaecbef17478260204146176bb9032a1ae Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Thu, 25 Sep 2025 16:35:17 +0100 Subject: [PATCH 07/14] test(virtio-mem): add unit tests for virtio queue request handling Adds unit tests using VirtioTestHelper to verify correct functioning of the new device. Signed-off-by: Riccardo Mancini --- src/vmm/src/devices/virtio/mem/device.rs | 508 +++++++++++++++++++++- src/vmm/src/devices/virtio/mem/metrics.rs | 7 +- src/vmm/src/devices/virtio/mem/request.rs | 96 +++- src/vmm/src/devices/virtio/test_utils.rs | 5 + 4 files changed, 605 insertions(+), 11 deletions(-) diff --git a/src/vmm/src/devices/virtio/mem/device.rs b/src/vmm/src/devices/virtio/mem/device.rs index 7cb7dd97bf7..8c53df4751f 100644 --- a/src/vmm/src/devices/virtio/mem/device.rs +++ b/src/vmm/src/devices/virtio/mem/device.rs @@ -169,6 +169,10 @@ impl VirtioMem { VIRTIO_MEM_DEV_ID } + pub fn guest_address(&self) -> GuestAddress { + GuestAddress(self.config.addr) + } + /// Gets the total hotpluggable size. 
    pub fn total_size_mib(&self) -> usize {
         bytes_to_mib(u64_to_usize(self.config.region_size))
@@ -644,10 +648,34 @@ impl VirtioDevice for VirtioMem {
 #[cfg(test)]
 pub(crate) mod test_utils {
     use super::*;
+    use crate::devices::virtio::test_utils::test::VirtioTestDevice;
+    use crate::test_utils::single_region_mem;
+    use crate::vmm_config::machine_config::HugePageConfig;
+    use crate::vstate::memory;
     use crate::vstate::vm::tests::setup_vm_with_memory;
 
+    impl VirtioTestDevice for VirtioMem {
+        fn set_queues(&mut self, queues: Vec<Queue>) {
+            self.queues = queues;
+        }
+
+        fn num_queues() -> usize {
+            MEM_NUM_QUEUES
+        }
+    }
+
     pub(crate) fn default_virtio_mem() -> VirtioMem {
-        let (_, vm) = setup_vm_with_memory(0x1000);
+        let (_, mut vm) = setup_vm_with_memory(0x1000);
+        vm.register_hotpluggable_memory_region(
+            memory::anonymous(
+                std::iter::once((VIRTIO_MEM_GUEST_ADDRESS, mib_to_bytes(1024))),
+                false,
+                HugePageConfig::None,
+            )
+            .unwrap()
+            .pop()
+            .unwrap(),
+        )
+        .unwrap();
         let vm = Arc::new(vm);
         VirtioMem::new(vm, 1024, 2, 128).unwrap()
     }
@@ -657,11 +685,15 @@ mod tests {
     use std::ptr::null_mut;
 
+    use serde_json::de;
+    use vm_memory::guest_memory;
     use vm_memory::mmap::MmapRegionBuilder;
 
     use super::*;
     use crate::devices::virtio::device::VirtioDevice;
     use crate::devices::virtio::mem::device::test_utils::default_virtio_mem;
+    use crate::devices::virtio::queue::VIRTQ_DESC_F_WRITE;
+    use crate::devices::virtio::test_utils::test::VirtioTestHelper;
     use crate::vstate::vm::tests::setup_vm_with_memory;
 
     #[test]
@@ -796,4 +828,478 @@ mod tests {
             }
         );
     }
+
+    #[allow(clippy::cast_possible_truncation)]
+    const REQ_SIZE: u32 = std::mem::size_of::<virtio_mem::virtio_mem_req>() as u32;
+    #[allow(clippy::cast_possible_truncation)]
+    const RESP_SIZE: u32 = std::mem::size_of::<virtio_mem::virtio_mem_resp>() as u32;
+
+    fn test_helper<'a>(
+        mut dev: VirtioMem,
+        mem: &'a GuestMemoryMmap,
+    ) -> VirtioTestHelper<'a, VirtioMem> {
+        dev.set_acked_features(dev.avail_features);
+
+        let mut th = VirtioTestHelper::<VirtioMem>::new(mem, dev);
+        th.activate_device(mem);
+        th
+    }
+
+    fn emulate_request(
+        th: &mut VirtioTestHelper<VirtioMem>,
+        mem: &GuestMemoryMmap,
+        req: Request,
+    ) -> Response {
+        th.add_desc_chain(
+            MEM_QUEUE,
+            0,
+            &[(0, REQ_SIZE, 0), (1, RESP_SIZE, VIRTQ_DESC_F_WRITE)],
+        );
+        mem.write_obj(
+            virtio_mem::virtio_mem_req::from(req),
+            th.desc_address(MEM_QUEUE, 0),
+        )
+        .unwrap();
+        assert_eq!(th.emulate_for_msec(100).unwrap(), 1);
+        mem.read_obj::<virtio_mem::virtio_mem_resp>(th.desc_address(MEM_QUEUE, 1))
+            .unwrap()
+            .into()
+    }
+
+    #[test]
+    fn test_event_fail_descriptor_chain_too_short() {
+        let mut mem_dev = default_virtio_mem();
+        let guest_mem = mem_dev.vm.guest_memory().clone();
+        let mut th = test_helper(mem_dev, &guest_mem);
+
+        let queue_event_count = METRICS.queue_event_count.count();
+        let queue_event_fails = METRICS.queue_event_fails.count();
+
+        th.add_desc_chain(MEM_QUEUE, 0, &[(0, REQ_SIZE, 0)]);
+        assert_eq!(th.emulate_for_msec(100).unwrap(), 1);
+
+        assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1);
+        assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1);
+    }
+
+    #[test]
+    fn test_event_fail_descriptor_length_too_small() {
+        let mut mem_dev = default_virtio_mem();
+        let guest_mem = mem_dev.vm.guest_memory().clone();
+        let mut th = test_helper(mem_dev, &guest_mem);
+
+        let queue_event_count = METRICS.queue_event_count.count();
+        let queue_event_fails = METRICS.queue_event_fails.count();
+
+        th.add_desc_chain(MEM_QUEUE, 0, &[(0, 1, 0)]);
+        assert_eq!(th.emulate_for_msec(100).unwrap(), 1);
+
+        assert_eq!(METRICS.queue_event_count.count(),
queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn test_event_fail_unexpected_writeonly_descriptor() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain(MEM_QUEUE, 0, &[(0, REQ_SIZE, VIRTQ_DESC_F_WRITE)]); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn test_event_fail_unexpected_readonly_descriptor() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain(MEM_QUEUE, 0, &[(0, REQ_SIZE, 0), (1, RESP_SIZE, 0)]); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn test_event_fail_response_descriptor_length_too_small() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain( + MEM_QUEUE, + 0, + &[(0, REQ_SIZE, 0), (1, 1, VIRTQ_DESC_F_WRITE)], + ); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn test_update_requested_size_device_not_active() { + let mut mem_dev = default_virtio_mem(); + let result = mem_dev.update_requested_size(512); + assert!(matches!(result, Err(VirtioMemError::DeviceNotActive))); + } + + #[test] + fn test_update_requested_size_invalid_size() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + // Size not multiple of block size + let result = th.device().update_requested_size(3); + assert!(matches!(result, Err(VirtioMemError::InvalidSize(_)))); + + // Size too large + let result = th.device().update_requested_size(2048); + assert!(matches!(result, Err(VirtioMemError::InvalidSize(_)))); + } + + #[test] + fn test_update_requested_size_success() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + th.device().update_requested_size(512).unwrap(); + assert_eq!(th.device().requested_size_mib(), 512); + } + + #[test] + fn test_plug_request_success() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + let plug_count = METRICS.plug_count.count(); + let plug_bytes = METRICS.plug_bytes.count(); + let plug_fails = 
METRICS.plug_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 2); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails); + assert_eq!(METRICS.plug_count.count(), plug_count + 1); + assert_eq!(METRICS.plug_bytes.count(), plug_bytes + (2 << 20)); + assert_eq!(METRICS.plug_fails.count(), plug_fails); + } + + #[test] + fn test_plug_request_too_big() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(2); + let addr = th.device().guest_address(); + + let plug_count = METRICS.plug_count.count(); + let plug_bytes = METRICS.plug_bytes.count(); + let plug_fails = METRICS.plug_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 2 }), + ); + assert!(resp.is_error()); + + assert_eq!(METRICS.plug_count.count(), plug_count + 1); + assert_eq!(METRICS.plug_bytes.count(), plug_bytes); + assert_eq!(METRICS.plug_fails.count(), plug_fails + 1); + } + + #[test] + fn test_plug_request_already_plugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + // First plug succeeds + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + + // Second plug fails + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_error()); + } + + #[test] + fn test_unplug_request_success() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let unplug_count = METRICS.unplug_count.count(); + let unplug_bytes = METRICS.unplug_bytes.count(); + let unplug_fails = METRICS.unplug_fails.count(); + + // First plug + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 2); + + // Then unplug + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Unplug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 0); + + assert_eq!(METRICS.unplug_count.count(), unplug_count + 1); + assert_eq!(METRICS.unplug_bytes.count(), unplug_bytes + (2 << 20)); + assert_eq!(METRICS.unplug_fails.count(), unplug_fails); + } + + #[test] + fn test_unplug_request_not_plugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let unplug_count = METRICS.unplug_count.count(); + let unplug_bytes = METRICS.unplug_bytes.count(); + let unplug_fails = METRICS.unplug_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Unplug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_error()); + + 
assert_eq!(METRICS.unplug_count.count(), unplug_count + 1); + assert_eq!(METRICS.unplug_bytes.count(), unplug_bytes); + assert_eq!(METRICS.unplug_fails.count(), unplug_fails + 1); + } + + #[test] + fn test_unplug_all_request() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let unplug_all_count = METRICS.unplug_all_count.count(); + let unplug_all_fails = METRICS.unplug_all_fails.count(); + + // Plug some blocks + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 2 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 4); + + // Unplug all + let resp = emulate_request(&mut th, &guest_mem, Request::UnplugAll); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 0); + + assert_eq!(METRICS.unplug_all_count.count(), unplug_all_count + 1); + assert_eq!(METRICS.unplug_all_fails.count(), unplug_all_fails); + } + + #[test] + fn test_state_request_unplugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let state_count = METRICS.state_count.count(); + let state_fails = METRICS.state_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 1 }), + ); + assert_eq!(resp, Response::ack_with_state(BlockRangeState::Unplugged)); + + assert_eq!(METRICS.state_count.count(), state_count + 1); + assert_eq!(METRICS.state_fails.count(), state_fails); + } + + #[test] + fn test_state_request_plugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + // Plug first + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + + // Check state + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 1 }), + ); + assert_eq!(resp, Response::ack_with_state(BlockRangeState::Plugged)); + } + + #[test] + fn test_state_request_mixed() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + // Plug first block only + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + + // Check state of 2 blocks (one plugged, one unplugged) + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 2 }), + ); + assert_eq!(resp, Response::ack_with_state(BlockRangeState::Mixed)); + } + + #[test] + fn test_invalid_range_unaligned() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address().unchecked_add(1); + + let state_count = METRICS.state_count.count(); + let state_fails = METRICS.state_fails.count(); + + let resp = 
emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_error()); + + assert_eq!(METRICS.state_count.count(), state_count + 1); + assert_eq!(METRICS.state_fails.count(), state_fails + 1); + } + + #[test] + fn test_invalid_range_zero_blocks() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 0 }), + ); + assert!(resp.is_error()); + } + + #[test] + fn test_invalid_range_out_of_bounds() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(4); + let addr = th.device().guest_address(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { + addr, + nb_blocks: 1024, + }), + ); + assert!(resp.is_error()); + } + + #[test] + fn test_unsupported_request() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain( + MEM_QUEUE, + 0, + &[(0, REQ_SIZE, 0), (1, RESP_SIZE, VIRTQ_DESC_F_WRITE)], + ); + guest_mem + .write_obj( + virtio_mem::virtio_mem_req::from(Request::Unsupported(999)), + th.desc_address(MEM_QUEUE, 0), + ) + .unwrap(); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } } diff --git a/src/vmm/src/devices/virtio/mem/metrics.rs b/src/vmm/src/devices/virtio/mem/metrics.rs index f2a6f58b92d..d69255d44ec 100644 --- a/src/vmm/src/devices/virtio/mem/metrics.rs +++ b/src/vmm/src/devices/virtio/mem/metrics.rs @@ -111,13 +111,8 @@ pub mod tests { #[test] fn test_memory_hotplug_metrics() { let mem_metrics: VirtioMemDeviceMetrics = VirtioMemDeviceMetrics::new(); - let mem_metrics_local: String = serde_json::to_string(&mem_metrics).unwrap(); - // the 1st serialize flushes the metrics and resets values to 0 so that - // we can compare the values with local metrics. - serde_json::to_string(&METRICS).unwrap(); - let mem_metrics_global: String = serde_json::to_string(&METRICS).unwrap(); - assert_eq!(mem_metrics_local, mem_metrics_global); mem_metrics.queue_event_count.inc(); assert_eq!(mem_metrics.queue_event_count.count(), 1); + let _ = serde_json::to_string(&mem_metrics).unwrap(); } } diff --git a/src/vmm/src/devices/virtio/mem/request.rs b/src/vmm/src/devices/virtio/mem/request.rs index e3c620f10af..a55bdb2bbf6 100644 --- a/src/vmm/src/devices/virtio/mem/request.rs +++ b/src/vmm/src/devices/virtio/mem/request.rs @@ -1,7 +1,7 @@ // Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
// SPDX-License-Identifier: Apache-2.0
 
-use vm_memory::{ByteValued, GuestAddress};
+use vm_memory::{Address, ByteValued, GuestAddress};
 
 use crate::devices::virtio::generated::virtio_mem;
 
@@ -53,7 +53,7 @@ impl From<virtio_mem::virtio_mem_req> for Request {
     }
 }
 
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, Eq, PartialEq)]
 pub enum ResponseType {
     Ack,
     Nack,
@@ -74,7 +74,7 @@ impl From<ResponseType> for u16 {
     }
 }
 
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, Eq, PartialEq)]
 pub enum BlockRangeState {
     Plugged,
     Unplugged,
@@ -95,7 +95,7 @@ impl From<BlockRangeState> for virtio_mem::virtio_mem_resp_state {
     }
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Eq, PartialEq)]
 pub struct Response {
     pub resp_type: ResponseType,
     // Only for State requests
@@ -123,6 +123,14 @@ impl Response {
             state: Some(state),
         }
     }
+
+    pub(crate) fn is_ack(&self) -> bool {
+        self.resp_type == ResponseType::Ack
+    }
+
+    pub(crate) fn is_error(&self) -> bool {
+        self.resp_type == ResponseType::Error
+    }
 }
 
 // SAFETY: Plain data structures
@@ -140,3 +148,83 @@ impl From<Response> for virtio_mem::virtio_mem_resp {
         out
     }
 }
+
+#[cfg(test)]
+mod test_util {
+    use super::*;
+
+    // Implement the reverse conversions to use in test code.
+
+    impl From<Request> for virtio_mem::virtio_mem_req {
+        fn from(req: Request) -> virtio_mem::virtio_mem_req {
+            match req {
+                Request::Plug(r) => virtio_mem::virtio_mem_req {
+                    type_: virtio_mem::VIRTIO_MEM_REQ_PLUG.try_into().unwrap(),
+                    u: virtio_mem::virtio_mem_req__bindgen_ty_1 {
+                        plug: virtio_mem::virtio_mem_req_plug {
+                            addr: r.addr.raw_value(),
+                            nb_blocks: r.nb_blocks.try_into().unwrap(),
+                            ..Default::default()
+                        },
+                    },
+                    ..Default::default()
+                },
+                Request::Unplug(r) => virtio_mem::virtio_mem_req {
+                    type_: virtio_mem::VIRTIO_MEM_REQ_UNPLUG.try_into().unwrap(),
+                    u: virtio_mem::virtio_mem_req__bindgen_ty_1 {
+                        unplug: virtio_mem::virtio_mem_req_unplug {
+                            addr: r.addr.raw_value(),
+                            nb_blocks: r.nb_blocks.try_into().unwrap(),
+                            ..Default::default()
+                        },
+                    },
+                    ..Default::default()
+                },
+                Request::UnplugAll => virtio_mem::virtio_mem_req {
+                    type_: virtio_mem::VIRTIO_MEM_REQ_UNPLUG_ALL.try_into().unwrap(),
+                    ..Default::default()
+                },
+                Request::State(r) => virtio_mem::virtio_mem_req {
+                    type_: virtio_mem::VIRTIO_MEM_REQ_STATE.try_into().unwrap(),
+                    u: virtio_mem::virtio_mem_req__bindgen_ty_1 {
+                        state: virtio_mem::virtio_mem_req_state {
+                            addr: r.addr.raw_value(),
+                            nb_blocks: r.nb_blocks.try_into().unwrap(),
+                            ..Default::default()
+                        },
+                    },
+                    ..Default::default()
+                },
+                Request::Unsupported(t) => virtio_mem::virtio_mem_req {
+                    type_: t.try_into().unwrap(),
+                    ..Default::default()
+                },
+            }
+        }
+    }
+
+    impl From<virtio_mem::virtio_mem_resp> for Response {
+        fn from(resp: virtio_mem::virtio_mem_resp) -> Self {
+            Response {
+                resp_type: match resp.type_.into() {
+                    virtio_mem::VIRTIO_MEM_RESP_ACK => ResponseType::Ack,
+                    virtio_mem::VIRTIO_MEM_RESP_NACK => ResponseType::Nack,
+                    virtio_mem::VIRTIO_MEM_RESP_BUSY => ResponseType::Busy,
+                    virtio_mem::VIRTIO_MEM_RESP_ERROR => ResponseType::Error,
+                    t => panic!("Invalid response type: {:?}", t),
+                },
+                // There is no way to know whether this is present or not as it depends on the
+                // request type. Callers should ignore this value if the request wasn't STATE.
+                // SAFETY: test code only. Uninitialized values are 0 and recognized as PLUGGED.
+ state: Some(unsafe { + match resp.u.state.state.into() { + virtio_mem::VIRTIO_MEM_STATE_PLUGGED => BlockRangeState::Plugged, + virtio_mem::VIRTIO_MEM_STATE_UNPLUGGED => BlockRangeState::Unplugged, + virtio_mem::VIRTIO_MEM_STATE_MIXED => BlockRangeState::Mixed, + t => panic!("Invalid state: {:?}", t), + } + }), + } + } + } +} diff --git a/src/vmm/src/devices/virtio/test_utils.rs b/src/vmm/src/devices/virtio/test_utils.rs index 6f1489dd380..0c7978504e7 100644 --- a/src/vmm/src/devices/virtio/test_utils.rs +++ b/src/vmm/src/devices/virtio/test_utils.rs @@ -442,6 +442,11 @@ pub(crate) mod test { self.virtqueues.last().unwrap().end().raw_value() } + /// Get the address of a descriptor + pub fn desc_address(&self, queue: usize, index: usize) -> GuestAddress { + GuestAddress(self.virtqueues[queue].dtable[index].addr.get()) + } + /// Add a new Descriptor in one of the device's queues /// /// This function adds in one of the queues of the device a DescriptorChain at some offset From 0c8c0458769c0e753cf72bc437ff6a2c057e2640 Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Fri, 26 Sep 2025 17:35:29 +0100 Subject: [PATCH 08/14] test(metrics): add virtio-mem device metrics to validation Add the virtio-mem device metrics to the integ test validation. Signed-off-by: Riccardo Mancini --- tests/host_tools/fcmetrics.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/host_tools/fcmetrics.py b/tests/host_tools/fcmetrics.py index 6b110a70f96..3b65901b5aa 100644 --- a/tests/host_tools/fcmetrics.py +++ b/tests/host_tools/fcmetrics.py @@ -302,6 +302,21 @@ def validate_fc_metrics(metrics): "activate_fails", "queue_event_fails", "queue_event_count", + "plug_count", + "plug_bytes", + "plug_fails", + {"plug_agg": latency_agg_metrics_fields}, + "unplug_count", + "unplug_bytes", + "unplug_fails", + "unplug_discard_fails", + {"unplug_agg": latency_agg_metrics_fields}, + "state_count", + "state_fails", + {"state_agg": latency_agg_metrics_fields}, + "unplug_all_count", + "unplug_all_fails", + {"unplug_all_agg": latency_agg_metrics_fields}, ], } From d4b0f9a166c9a484b43a3738cdb86bff48bd3f57 Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Fri, 26 Sep 2025 15:25:25 +0100 Subject: [PATCH 09/14] fix(examples/uffd): unregister range on UFFD Remove event If the handler receives a UFFD remove event, it currently stores the PFN and will reply with a zero page whenever it receives a pagefault event for that page. This works well with 4k pages, but zeropage is not supported on hugepages. In order to support hugepages, let's just unregister from UFFD whenever we get a remove event. By doing so, the handler won't receive a notification for the removed page, and the VM will get a new zero page from the kernel. 
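In other words, the handler's event loop reduces to something like the
following sketch (using only the userfaultfd crate calls the example
already relies on; error handling and the pagefault path are elided,
so this is illustrative rather than the exact code):

    match uffd_handler.read_event() {
        Ok(Some(userfaultfd::Event::Pagefault { .. })) => {
            // Serve the fault from the backing memory, as before.
        }
        Ok(Some(userfaultfd::Event::Remove { start, end })) => {
            // Drop the UFFD registration for the removed range: later
            // faults there are resolved by the kernel with fresh zero
            // pages, for both 4k and huge pages.
            uffd_handler.unregister_range(start, end)
        }
        event => panic!("Unexpected event on userfaultfd: {:?}", event),
    }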
Signed-off-by: Riccardo Mancini
---
 .../examples/uffd/on_demand_handler.rs        |  2 +-
 src/firecracker/examples/uffd/uffd_utils.rs   | 30 +++++--------------
 2 files changed, 8 insertions(+), 24 deletions(-)

diff --git a/src/firecracker/examples/uffd/on_demand_handler.rs b/src/firecracker/examples/uffd/on_demand_handler.rs
index 3be958b3578..3101aa19253 100644
--- a/src/firecracker/examples/uffd/on_demand_handler.rs
+++ b/src/firecracker/examples/uffd/on_demand_handler.rs
@@ -87,7 +87,7 @@ fn main() {
                         }
                     }
                     userfaultfd::Event::Remove { start, end } => {
-                        uffd_handler.mark_range_removed(start as u64, end as u64)
+                        uffd_handler.unregister_range(start, end)
                     }
                     _ => panic!("Unexpected event on userfaultfd"),
                 }
diff --git a/src/firecracker/examples/uffd/uffd_utils.rs b/src/firecracker/examples/uffd/uffd_utils.rs
index 97c6150b65b..30aa16f040a 100644
--- a/src/firecracker/examples/uffd/uffd_utils.rs
+++ b/src/firecracker/examples/uffd/uffd_utils.rs
@@ -9,7 +9,7 @@
     dead_code
 )]
 
-use std::collections::{HashMap, HashSet};
+use std::collections::HashMap;
 use std::ffi::c_void;
 use std::fs::File;
 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
@@ -54,7 +54,6 @@ pub struct UffdHandler {
     pub page_size: usize,
     backing_buffer: *const u8,
     uffd: Uffd,
-    removed_pages: HashSet<u64>,
 }
 
 impl UffdHandler {
@@ -125,7 +124,6 @@ impl UffdHandler {
             page_size,
             backing_buffer,
             uffd,
-            removed_pages: HashSet::new(),
         }
     }
 
@@ -133,24 +131,18 @@ impl UffdHandler {
         self.uffd.read_event()
     }
 
-    pub fn mark_range_removed(&mut self, start: u64, end: u64) {
-        let pfn_start = start / self.page_size as u64;
-        let pfn_end = end / self.page_size as u64;
-
-        for pfn in pfn_start..pfn_end {
-            self.removed_pages.insert(pfn);
-        }
+    pub fn unregister_range(&mut self, start: *mut c_void, end: *mut c_void) {
+        // SAFETY: start and end are valid and provided by UFFD
+        let len = unsafe { end.offset_from_unsigned(start) };
+        self.uffd
+            .unregister(start, len)
+            .expect("range should be valid");
     }
 
     pub fn serve_pf(&mut self, addr: *mut u8, len: usize) -> bool {
         // Find the start of the page that the current faulting address belongs to.
         let dst = (addr as usize & !(self.page_size - 1)) as *mut libc::c_void;
         let fault_page_addr = dst as u64;
-        let fault_pfn = fault_page_addr / self.page_size as u64;
-
-        if self.removed_pages.contains(&fault_pfn) {
-            return self.zero_out(fault_page_addr);
-        }
 
         for region in self.mem_regions.iter() {
             if region.contains(fault_page_addr) {
@@ -193,14 +185,6 @@ impl UffdHandler {
 
         true
     }
-
-    fn zero_out(&mut self, addr: u64) -> bool {
-        match unsafe { self.uffd.zeropage(addr as *mut _, self.page_size, true) } {
-            Ok(_) => true,
-            Err(Error::ZeropageFailed(error)) if error as i32 == libc::EAGAIN => false,
-            r => panic!("Unexpected zeropage result: {:?}", r),
-        }
-    }
 }
 
 #[derive(Debug)]
From 9efa1b9cf7b043fe4fb7fe3c748e004cb051e9f7 Mon Sep 17 00:00:00 2001
From: Riccardo Mancini
Date: Fri, 26 Sep 2025 15:35:40 +0100
Subject: [PATCH 10/14] feat(test/balloon): include HugePages in RSS
 measurements

This moves the logic to measure RSS to framework.utils and adds logic
to also include huge pages in the measurement.

Furthermore, this also adds caching for the firecracker_pid, as well as
a new property to get the corresponding psutil.Process.
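The kernel accounts HugeTLB mappings separately from anonymous RSS, so
psutil's memory_info().rss alone misses them. /proc/<pid>/status exposes
the missing piece; for a microVM backed by 2M huge pages it contains,
for example (values in KiB, numbers illustrative):

    VmRSS:             12345 kB
    HugetlbPages:     524288 kB

The new get_resident_memory() helper below simply sums the HugetlbPages
reading with the RSS reported by psutil, both converted to KiB.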
Signed-off-by: Riccardo Mancini --- tests/framework/microvm.py | 10 ++- tests/framework/utils.py | 14 ++++ .../functional/test_balloon.py | 82 +++++++------------ .../test_snapshot_restore_cross_kernel.py | 11 +-- 4 files changed, 57 insertions(+), 60 deletions(-) diff --git a/tests/framework/microvm.py b/tests/framework/microvm.py index fa9dea79b82..69f00e4b94b 100644 --- a/tests/framework/microvm.py +++ b/tests/framework/microvm.py @@ -23,10 +23,11 @@ from collections import namedtuple from dataclasses import dataclass from enum import Enum, auto -from functools import lru_cache +from functools import cached_property, lru_cache from pathlib import Path from typing import Optional +import psutil from tenacity import Retrying, retry, stop_after_attempt, wait_fixed import host_tools.cargo_build as build_tools @@ -472,7 +473,7 @@ def state(self): """Get the InstanceInfo property and return the state field.""" return self.api.describe.get().json()["state"] - @property + @cached_property def firecracker_pid(self): """Return Firecracker's PID @@ -491,6 +492,11 @@ def firecracker_pid(self): with attempt: return int(self.jailer.pid_file.read_text(encoding="ascii")) + @cached_property + def ps(self): + """Returns a handle to the psutil.Process for this VM""" + return psutil.Process(self.firecracker_pid) + @property def dimensions(self): """Gets a default set of cloudwatch dimensions describing the configuration of this microvm""" diff --git a/tests/framework/utils.py b/tests/framework/utils.py index 64bc9526e5c..d592daec84f 100644 --- a/tests/framework/utils.py +++ b/tests/framework/utils.py @@ -14,6 +14,7 @@ import typing from collections import defaultdict, namedtuple from contextlib import contextmanager +from pathlib import Path from typing import Dict import psutil @@ -129,6 +130,19 @@ def track_cpu_utilization( return cpu_utilization +def get_resident_memory(process: psutil.Process): + """Returns current memory utilization in KiB, including used HugeTLBFS""" + + proc_status = Path("/proc", str(process.pid), "status").read_text("utf-8") + for line in proc_status.splitlines(): + if line.startswith("HugetlbPages:"): # entry is in KiB + hugetlbfs_usage = int(line.split()[1]) + break + else: + assert False, f"HugetlbPages not found in {str(proc_status)}" + return hugetlbfs_usage + process.memory_info().rss // 1024 + + @contextmanager def chroot(path): """ diff --git a/tests/integration_tests/functional/test_balloon.py b/tests/integration_tests/functional/test_balloon.py index f8960bedb6d..59c87358c42 100644 --- a/tests/integration_tests/functional/test_balloon.py +++ b/tests/integration_tests/functional/test_balloon.py @@ -9,12 +9,12 @@ import pytest import requests -from framework.utils import check_output, get_free_mem_ssh +from framework.utils import get_free_mem_ssh, get_resident_memory STATS_POLLING_INTERVAL_S = 1 -def get_stable_rss_mem_by_pid(pid, percentage_delta=1): +def get_stable_rss_mem(uvm, percentage_delta=1): """ Get the RSS memory that a guest uses, given the pid of the guest. @@ -22,22 +22,16 @@ def get_stable_rss_mem_by_pid(pid, percentage_delta=1): Or print a warning if this does not happen. 
""" - # All values are reported as KiB - - def get_rss_from_pmap(): - _, output, _ = check_output("pmap -X {}".format(pid)) - return int(output.split("\n")[-2].split()[1], 10) - first_rss = 0 second_rss = 0 for _ in range(5): - first_rss = get_rss_from_pmap() + first_rss = get_resident_memory(uvm.ps) time.sleep(1) - second_rss = get_rss_from_pmap() + second_rss = get_resident_memory(uvm.ps) abs_diff = abs(first_rss - second_rss) abs_delta = abs_diff / first_rss * 100 print( - f"RSS readings: old: {first_rss} new: {second_rss} abs_diff: {abs_diff} abs_delta: {abs_delta}" + f"RSS readings (bytes): old: {first_rss} new: {second_rss} abs_diff: {abs_diff} abs_delta: {abs_delta}" ) if abs_delta < percentage_delta: return second_rss @@ -87,25 +81,24 @@ def make_guest_dirty_memory(ssh_connection, amount_mib=32): def _test_rss_memory_lower(test_microvm): """Check inflating the balloon makes guest use less rss memory.""" # Get the firecracker pid, and open an ssh connection. - firecracker_pid = test_microvm.firecracker_pid ssh_connection = test_microvm.ssh # Using deflate_on_oom, get the RSS as low as possible test_microvm.api.balloon.patch(amount_mib=200) # Get initial rss consumption. - init_rss = get_stable_rss_mem_by_pid(firecracker_pid) + init_rss = get_stable_rss_mem(test_microvm) # Get the balloon back to 0. test_microvm.api.balloon.patch(amount_mib=0) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Dirty memory, then inflate balloon and get ballooned rss consumption. make_guest_dirty_memory(ssh_connection, amount_mib=32) test_microvm.api.balloon.patch(amount_mib=200) - balloon_rss = get_stable_rss_mem_by_pid(firecracker_pid) + balloon_rss = get_stable_rss_mem(test_microvm) # Check that the ballooning reclaimed the memory. assert balloon_rss - init_rss <= 15000 @@ -149,7 +142,6 @@ def test_inflate_reduces_free(uvm_plain_any): # Start the microvm test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # Get the free memory before ballooning. available_mem_deflated = get_free_mem_ssh(test_microvm.ssh) @@ -157,7 +149,7 @@ def test_inflate_reduces_free(uvm_plain_any): # Inflate 64 MB == 16384 page balloon. test_microvm.api.balloon.patch(amount_mib=64) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get the free memory after ballooning. available_mem_inflated = get_free_mem_ssh(test_microvm.ssh) @@ -195,19 +187,18 @@ def test_deflate_on_oom(uvm_plain_any, deflate_on_oom): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # We get an initial reading of the RSS, then calculate the amount # we need to inflate the balloon with by subtracting it from the # VM size and adding an offset of 50 MiB in order to make sure we # get a lower reading than the initial one. - initial_rss = get_stable_rss_mem_by_pid(firecracker_pid) + initial_rss = get_stable_rss_mem(test_microvm) inflate_size = 256 - (int(initial_rss / 1024) + 50) # Inflate the balloon test_microvm.api.balloon.patch(amount_mib=inflate_size) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Check that using memory leads to the balloon device automatically # deflate (or not). @@ -250,39 +241,38 @@ def test_reinflate_balloon(uvm_plain_any): # Start the microvm. 
test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # First inflate the balloon to free up the uncertain amount of memory # used by the kernel at boot and establish a baseline, then give back # the memory. test_microvm.api.balloon.patch(amount_mib=200) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) test_microvm.api.balloon.patch(amount_mib=0) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get the guest to dirty memory. make_guest_dirty_memory(test_microvm.ssh, amount_mib=32) - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(test_microvm) # Now inflate the balloon. test_microvm.api.balloon.patch(amount_mib=200) - second_reading = get_stable_rss_mem_by_pid(firecracker_pid) + second_reading = get_stable_rss_mem(test_microvm) # Now deflate the balloon. test_microvm.api.balloon.patch(amount_mib=0) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Now have the guest dirty memory again. make_guest_dirty_memory(test_microvm.ssh, amount_mib=32) - third_reading = get_stable_rss_mem_by_pid(firecracker_pid) + third_reading = get_stable_rss_mem(test_microvm) # Now inflate the balloon again. test_microvm.api.balloon.patch(amount_mib=200) - fourth_reading = get_stable_rss_mem_by_pid(firecracker_pid) + fourth_reading = get_stable_rss_mem(test_microvm) # Check that the memory used is the same after regardless of the previous # inflate history of the balloon (with the third reading being allowed @@ -309,10 +299,9 @@ def test_size_reduction(uvm_plain_any): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # Check memory usage. - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(test_microvm) # Have the guest drop its caches. test_microvm.ssh.run("sync; echo 3 > /proc/sys/vm/drop_caches") @@ -328,7 +317,7 @@ def test_size_reduction(uvm_plain_any): test_microvm.api.balloon.patch(amount_mib=inflate_size) # Check memory usage again. - second_reading = get_stable_rss_mem_by_pid(firecracker_pid) + second_reading = get_stable_rss_mem(test_microvm) # There should be a reduction of at least 10MB. assert first_reading - second_reading >= 10000 @@ -353,7 +342,6 @@ def test_stats(uvm_plain_any): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # Give Firecracker enough time to poll the stats at least once post-boot time.sleep(STATS_POLLING_INTERVAL_S * 2) @@ -371,7 +359,7 @@ def test_stats(uvm_plain_any): make_guest_dirty_memory(test_microvm.ssh, amount_mib=10) time.sleep(1) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Make sure that the stats catch the page faults. after_workload_stats = test_microvm.api.balloon_stats.get().json() @@ -380,7 +368,7 @@ def test_stats(uvm_plain_any): # Now inflate the balloon with 10MB of pages. test_microvm.api.balloon.patch(amount_mib=10) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get another reading of the stats after the polling interval has passed. 
inflated_stats = test_microvm.api.balloon_stats.get().json() @@ -393,7 +381,7 @@ def test_stats(uvm_plain_any): # available memory. test_microvm.api.balloon.patch(amount_mib=0) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get another reading of the stats after the polling interval has passed. deflated_stats = test_microvm.api.balloon_stats.get().json() @@ -421,13 +409,12 @@ def test_stats_update(uvm_plain_any): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # Dirty 30MB of pages. make_guest_dirty_memory(test_microvm.ssh, amount_mib=30) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get an initial reading of the stats. initial_stats = test_microvm.api.balloon_stats.get().json() @@ -477,17 +464,14 @@ def test_balloon_snapshot(uvm_plain_any, microvm_factory): make_guest_dirty_memory(vm.ssh, amount_mib=60) time.sleep(1) - # Get the firecracker pid, and open an ssh connection. - firecracker_pid = vm.firecracker_pid - # Check memory usage. - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(vm) # Now inflate the balloon with 20MB of pages. vm.api.balloon.patch(amount_mib=20) # Check memory usage again. - second_reading = get_stable_rss_mem_by_pid(firecracker_pid) + second_reading = get_stable_rss_mem(vm) # There should be a reduction in RSS, but it's inconsistent. # We only test that the reduction happens. @@ -496,28 +480,25 @@ def test_balloon_snapshot(uvm_plain_any, microvm_factory): snapshot = vm.snapshot_full() microvm = microvm_factory.build_from_snapshot(snapshot) - # Get the firecracker from snapshot pid, and open an ssh connection. - firecracker_pid = microvm.firecracker_pid - # Wait out the polling interval, then get the updated stats. time.sleep(STATS_POLLING_INTERVAL_S * 2) stats_after_snap = microvm.api.balloon_stats.get().json() # Check memory usage. - third_reading = get_stable_rss_mem_by_pid(firecracker_pid) + third_reading = get_stable_rss_mem(microvm) # Dirty 60MB of pages. make_guest_dirty_memory(microvm.ssh, amount_mib=60) # Check memory usage. - fourth_reading = get_stable_rss_mem_by_pid(firecracker_pid) + fourth_reading = get_stable_rss_mem(microvm) assert fourth_reading > third_reading # Inflate the balloon with another 20MB of pages. microvm.api.balloon.patch(amount_mib=40) - fifth_reading = get_stable_rss_mem_by_pid(firecracker_pid) + fifth_reading = get_stable_rss_mem(microvm) # There should be a reduction in RSS, but it's inconsistent. # We only test that the reduction happens. @@ -557,15 +538,14 @@ def test_memory_scrub(uvm_plain_any): microvm.api.balloon.patch(amount_mib=60) # Get the firecracker pid, and open an ssh connection. - firecracker_pid = microvm.firecracker_pid # Wait for the inflate to complete. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(microvm) # Deflate the balloon completely. microvm.api.balloon.patch(amount_mib=0) # Wait for the deflate to complete. 
- _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(microvm) microvm.ssh.check_output("/usr/local/bin/readmem {} {}".format(60, 1)) diff --git a/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py b/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py index bfe5316d9e5..253502a2d1f 100644 --- a/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py +++ b/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py @@ -20,7 +20,7 @@ from framework.utils_cpu_templates import get_supported_cpu_templates from framework.utils_vsock import check_vsock_device from integration_tests.functional.test_balloon import ( - get_stable_rss_mem_by_pid, + get_stable_rss_mem, make_guest_dirty_memory, ) @@ -28,21 +28,18 @@ def _test_balloon(microvm): - # Get the firecracker pid. - firecracker_pid = microvm.firecracker_pid - # Check memory usage. - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(microvm) # Dirty 300MB of pages. make_guest_dirty_memory(microvm.ssh, amount_mib=300) # Check memory usage again. - second_reading = get_stable_rss_mem_by_pid(firecracker_pid) + second_reading = get_stable_rss_mem(microvm) assert second_reading > first_reading # Inflate the balloon. Get back 200MB. microvm.api.balloon.patch(amount_mib=200) - third_reading = get_stable_rss_mem_by_pid(firecracker_pid) + third_reading = get_stable_rss_mem(microvm) # Ensure that there is a reduction in RSS. assert second_reading > third_reading From 9b27bcfe2dffcf19c97ebd3a99d8f1629f1b5007 Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Fri, 26 Sep 2025 15:39:55 +0100 Subject: [PATCH 11/14] refactor(test/balloon): move logic to get guest avail mem to framework Move the logic to get the MemAvailable from /proc/meminfo inside the guest to a new guest_stats module in the test framework. This provides a new class MeminfoGuest that can be used to retrieve this information (and more!). Signed-off-by: Riccardo Mancini --- tests/framework/guest_stats.py | 79 +++++++++++++++++++ tests/framework/utils.py | 19 ----- .../functional/test_balloon.py | 8 +- 3 files changed, 84 insertions(+), 22 deletions(-) create mode 100644 tests/framework/guest_stats.py diff --git a/tests/framework/guest_stats.py b/tests/framework/guest_stats.py new file mode 100644 index 00000000000..570b9a4ea63 --- /dev/null +++ b/tests/framework/guest_stats.py @@ -0,0 +1,79 @@ +# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Classes for querying guest stats inside microVMs. 
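+
+Illustrative usage (assuming a booted framework Microvm object `uvm`):
+
+    mem_available_mib = MeminfoGuest(uvm).get().mem_available.mib()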
+""" + + +class ByteUnit: + """Represents a byte unit that can be converted to other units.""" + + value_bytes: int + + def __init__(self, value_bytes: int): + self.value_bytes = value_bytes + + @classmethod + def from_kib(cls, value_kib: int): + """Creates a ByteUnit from a value in KiB.""" + if value_kib < 0: + raise ValueError("value_kib must be non-negative") + return ByteUnit(value_kib * 1024) + + def bytes(self) -> float: + """Returns the value in B.""" + return self.value_bytes + + def kib(self) -> float: + """Returns the value in KiB.""" + return self.value_bytes / 1024 + + def mib(self) -> float: + """Returns the value in MiB.""" + return self.value_bytes / (1 << 20) + + def gib(self) -> float: + """Returns the value in GiB.""" + return self.value_bytes / (1 << 30) + + +class Meminfo: + """Represents the contents of /proc/meminfo inside the guest""" + + mem_total: ByteUnit + mem_free: ByteUnit + mem_available: ByteUnit + buffers: ByteUnit + cached: ByteUnit + + def __init__(self): + self.mem_total = ByteUnit(0) + self.mem_free = ByteUnit(0) + self.mem_available = ByteUnit(0) + self.buffers = ByteUnit(0) + self.cached = ByteUnit(0) + + +class MeminfoGuest: + """Queries /proc/meminfo inside the guest""" + + def __init__(self, vm): + self.vm = vm + + def get(self) -> Meminfo: + """Returns the contents of /proc/meminfo inside the guest""" + meminfo = Meminfo() + for line in self.vm.ssh.check_output("cat /proc/meminfo").stdout.splitlines(): + parts = line.split() + if parts[0] == "MemTotal:": + meminfo.mem_total = ByteUnit.from_kib(int(parts[1])) + elif parts[0] == "MemFree:": + meminfo.mem_free = ByteUnit.from_kib(int(parts[1])) + elif parts[0] == "MemAvailable:": + meminfo.mem_available = ByteUnit.from_kib(int(parts[1])) + elif parts[0] == "Buffers:": + meminfo.buffers = ByteUnit.from_kib(int(parts[1])) + elif parts[0] == "Cached:": + meminfo.cached = ByteUnit.from_kib(int(parts[1])) + + return meminfo diff --git a/tests/framework/utils.py b/tests/framework/utils.py index d592daec84f..448b351fd86 100644 --- a/tests/framework/utils.py +++ b/tests/framework/utils.py @@ -254,25 +254,6 @@ def search_output_from_cmd(cmd: str, find_regex: typing.Pattern) -> typing.Match ) -def get_free_mem_ssh(ssh_connection): - """ - Get how much free memory in kB a guest sees, over ssh. 
- - :param ssh_connection: connection to the guest - :return: available mem column output of 'free' - """ - _, stdout, stderr = ssh_connection.run("cat /proc/meminfo | grep MemAvailable") - assert stderr == "" - - # Split "MemAvailable: 123456 kB" and validate it - meminfo_data = stdout.split() - if len(meminfo_data) == 3: - # Return the middle element in the array - return int(meminfo_data[1]) - - raise Exception("Available memory not found in `/proc/meminfo") - - def _format_output_message(proc, stdout, stderr): output_message = f"\n[{proc.pid}] Command:\n{proc.args}" # Append stdout/stderr to the output message diff --git a/tests/integration_tests/functional/test_balloon.py b/tests/integration_tests/functional/test_balloon.py index 59c87358c42..19b1651c72a 100644 --- a/tests/integration_tests/functional/test_balloon.py +++ b/tests/integration_tests/functional/test_balloon.py @@ -9,7 +9,8 @@ import pytest import requests -from framework.utils import get_free_mem_ssh, get_resident_memory +from framework.guest_stats import MeminfoGuest +from framework.utils import get_resident_memory STATS_POLLING_INTERVAL_S = 1 @@ -142,9 +143,10 @@ def test_inflate_reduces_free(uvm_plain_any): # Start the microvm test_microvm.start() + meminfo = MeminfoGuest(test_microvm) # Get the free memory before ballooning. - available_mem_deflated = get_free_mem_ssh(test_microvm.ssh) + available_mem_deflated = meminfo.get().mem_free.kib() # Inflate 64 MB == 16384 page balloon. test_microvm.api.balloon.patch(amount_mib=64) @@ -152,7 +154,7 @@ def test_inflate_reduces_free(uvm_plain_any): _ = get_stable_rss_mem(test_microvm) # Get the free memory after ballooning. - available_mem_inflated = get_free_mem_ssh(test_microvm.ssh) + available_mem_inflated = meminfo.get().mem_free.kib() # Assert that ballooning reclaimed about 64 MB of memory. assert available_mem_inflated <= available_mem_deflated - 85 * 64000 / 100 From 9fca25d9670db718d090895146c855aa9ea0e1ce Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Fri, 26 Sep 2025 15:44:05 +0100 Subject: [PATCH 12/14] test(virtio-mem): add functional integration tests for device Add integration tests for the new device: - check that the device is detected - check that hotplugging and unplugging works - check that memory can be used after hotplugging - check that memory is freed on hotunplug - check different config combinations - check different uvm types - check that contents are preserved across snapshot-restore Signed-off-by: Riccardo Mancini --- tests/framework/microvm.py | 28 ++ .../functional/test_memory_hp.py | 271 +++++++++++++++++- 2 files changed, 288 insertions(+), 11 deletions(-) diff --git a/tests/framework/microvm.py b/tests/framework/microvm.py index 69f00e4b94b..853de6fc4ef 100644 --- a/tests/framework/microvm.py +++ b/tests/framework/microvm.py @@ -1186,6 +1186,22 @@ def wait_for_ssh_up(self): # run commands. 
The actual connection retry loop happens in SSHConnection._init_connection _ = self.ssh_iface(0) + def hotplug_memory( + self, requested_size_mib: int, timeout: int = 60, poll: float = 0.1 + ): + """Send a hot(un)plug request and wait up to timeout seconds for completion polling every poll seconds""" + self.api.memory_hotplug.patch(requested_size_mib=requested_size_mib) + # Wait for the hotplug to complete + deadline = time.time() + timeout + while time.time() < deadline: + if ( + self.api.memory_hotplug.get().json()["plugged_size_mib"] + == requested_size_mib + ): + return + time.sleep(poll) + raise TimeoutError(f"Hotplug did not complete within {timeout} seconds") + class MicroVMFactory: """MicroVM factory""" @@ -1300,6 +1316,18 @@ def build_n_from_snapshot( last_snapshot.delete() current_snapshot.delete() + def clone_uvm(self, uvm, uffd_handler_name=None): + """ + Clone the given VM and start it. + """ + snapshot = uvm.snapshot_full() + restored_vm = self.build() + restored_vm.spawn() + restored_vm.restore_from_snapshot( + snapshot, resume=True, uffd_handler_name=uffd_handler_name + ) + return restored_vm + def kill(self): """Clean up all built VMs""" for vm in self.vms: diff --git a/tests/integration_tests/functional/test_memory_hp.py b/tests/integration_tests/functional/test_memory_hp.py index b2132d6c9ed..317acb8d915 100644 --- a/tests/integration_tests/functional/test_memory_hp.py +++ b/tests/integration_tests/functional/test_memory_hp.py @@ -3,21 +3,120 @@ """Tests for verifying the virtio-mem is working correctly""" +import pytest +from packaging import version +from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed + +from framework.guest_stats import MeminfoGuest +from framework.microvm import HugePagesConfig +from framework.utils import get_kernel_version, get_resident_memory + +MEMHP_BOOTARGS = "console=ttyS0 reboot=k panic=1 memhp_default_state=online_movable" +DEFAULT_CONFIG = {"total_size_mib": 1024, "slot_size_mib": 128, "block_size_mib": 2} + + +def uvm_booted_memhp( + uvm, rootfs, _microvm_factory, vhost_user, memhp_config, huge_pages, _uffd_handler +): + """Boots a VM with the given memory hotplugging config""" -def test_virtio_mem_detected(uvm_plain_6_1): - """ - Check that the guest kernel has enabled PV steal time. 
-    """
-    uvm = uvm_plain_6_1
     uvm.spawn()
     uvm.memory_monitor = None
-    uvm.basic_config(
-        boot_args="console=ttyS0 reboot=k panic=1 memhp_default_state=online_movable"
-    )
+    if vhost_user:
+        # We need to set up ssh keys manually because we did not specify rootfs
+        # in the microvm_factory.build method
+        ssh_key = rootfs.with_suffix(".id_rsa")
+        uvm.ssh_key = ssh_key
+        uvm.basic_config(
+            boot_args=MEMHP_BOOTARGS, add_root_device=False, huge_pages=huge_pages
+        )
+        uvm.add_vhost_user_drive(
+            "rootfs", rootfs, is_root_device=True, is_read_only=True
+        )
+    else:
+        uvm.basic_config(boot_args=MEMHP_BOOTARGS, huge_pages=huge_pages)
+
+    uvm.api.memory_hotplug.put(**memhp_config)
     uvm.add_net_iface()
-    uvm.api.memory_hotplug.put(total_size_mib=1024)
     uvm.start()
+    return uvm
+
+
+def uvm_resumed_memhp(
+    uvm_plain,
+    rootfs,
+    microvm_factory,
+    vhost_user,
+    memhp_config,
+    huge_pages,
+    uffd_handler,
+):
+    """Restores a VM with the given memory hotplugging config after booting and snapshotting"""
+    if vhost_user:
+        pytest.skip("vhost-user doesn't support snapshot/restore")
+    if huge_pages and huge_pages != HugePagesConfig.NONE and not uffd_handler:
+        pytest.skip("Hugepages requires a UFFD handler")
+    uvm = uvm_booted_memhp(
+        uvm_plain, rootfs, microvm_factory, vhost_user, memhp_config, huge_pages, None
+    )
+    return microvm_factory.clone_uvm(uvm, uffd_handler_name=uffd_handler)
+
+
+@pytest.fixture(
+    params=[
+        (uvm_booted_memhp, False, HugePagesConfig.NONE, None),
+        (uvm_booted_memhp, False, HugePagesConfig.HUGETLBFS_2MB, None),
+        (uvm_booted_memhp, True, HugePagesConfig.NONE, None),
+        (uvm_resumed_memhp, False, HugePagesConfig.NONE, None),
+        (uvm_resumed_memhp, False, HugePagesConfig.NONE, "on_demand"),
+        (uvm_resumed_memhp, False, HugePagesConfig.HUGETLBFS_2MB, "on_demand"),
+    ],
+    ids=[
+        "booted",
+        "booted-huge-pages",
+        "booted-vhost-user",
+        "resumed",
+        "resumed-uffd",
+        "resumed-uffd-huge-pages",
+    ],
+)
+def uvm_any_memhp(request, uvm_plain_6_1, rootfs, microvm_factory):
+    """Fixture that yields a booted or resumed VM with memory hotplugging"""
+    ctor, vhost_user, huge_pages, uffd_handler = request.param
+    yield ctor(
+        uvm_plain_6_1,
+        rootfs,
+        microvm_factory,
+        vhost_user,
+        DEFAULT_CONFIG,
+        huge_pages,
+        uffd_handler,
+    )
+
+
+def supports_hugetlbfs_discard():
+    """Returns True if the kernel supports hugetlbfs discard"""
+    return version.parse(get_kernel_version()) >= version.parse("5.18.0")
+
+
+def validate_metrics(uvm):
+    """Validates that there are no fails in the metrics"""
+    metrics_to_check = ["plug_fails", "unplug_fails", "unplug_all_fails", "state_fails"]
+    if supports_hugetlbfs_discard():
+        metrics_to_check.append("unplug_discard_fails")
+    uvm.flush_metrics()
+    for metrics in uvm.get_all_metrics():
+        for k in metrics_to_check:
+            assert (
+                metrics["memory_hotplug"][k] == 0
+            ), f"{k}={metrics['memory_hotplug'][k]} is greater than zero"
+
+
+def check_device_detected(uvm):
+    """
+    Check that the guest kernel has enabled virtio-mem.
+ """ + hp_config = uvm.api.memory_hotplug.get().json() _, stdout, _ = uvm.ssh.check_output("dmesg | grep 'virtio_mem'") for line in stdout.splitlines(): _, key, value = line.strip().split(":") @@ -27,12 +126,162 @@ def test_virtio_mem_detected(uvm_plain_6_1): case "start address": assert value == (512 << 30), "start address doesn't match" case "region size": - assert value == 1024 << 20, "region size doesn't match" + assert ( + value == hp_config["total_size_mib"] << 20 + ), "region size doesn't match" case "device block size": - assert value == 2 << 20, "block size doesn't match" + assert ( + value == hp_config["block_size_mib"] << 20 + ), "block size doesn't match" case "plugged size": assert value == 0, "plugged size doesn't match" case "requested size": assert value == 0, "requested size doesn't match" case _: continue + + +def check_memory_usable(uvm): + """Allocates memory to verify it's usable (5% margin to avoid OOM-kill)""" + mem_available = MeminfoGuest(uvm).get().mem_available.bytes() + # number of 64b ints to allocate as 95% of available memory + count = mem_available * 95 // 100 // 8 + + uvm.ssh.check_output( + f"python3 -c 'Q = 0x0123456789abcdef; a = [Q] * {count}; assert all(q == Q for q in a)'" + ) + + +def check_hotplug(uvm, requested_size_mib): + """Verifies memory can be hot(un)plugged""" + meminfo = MeminfoGuest(uvm) + mem_total_fixed = ( + meminfo.get().mem_total.mib() + - uvm.api.memory_hotplug.get().json()["plugged_size_mib"] + ) + uvm.hotplug_memory(requested_size_mib) + + # verify guest driver received the request + _, stdout, _ = uvm.ssh.check_output( + "dmesg | grep 'virtio_mem' | grep 'requested size' | tail -1" + ) + assert ( + int(stdout.strip().split(":")[-1].strip(), base=0) == requested_size_mib << 20 + ) + + for attempt in Retrying( + retry=retry_if_exception_type(AssertionError), + stop=stop_after_delay(5), + wait=wait_fixed(1), + reraise=True, + ): + with attempt: + # verify guest driver executed the request + mem_total_after = meminfo.get().mem_total.mib() + assert mem_total_after == mem_total_fixed + requested_size_mib + + +def check_hotunplug(uvm, requested_size_mib): + """Verifies memory can be hotunplugged and gets released""" + + rss_before = get_resident_memory(uvm.ps) + + check_hotplug(uvm, requested_size_mib) + + rss_after = get_resident_memory(uvm.ps) + + print(f"RSS before: {rss_before}, after: {rss_after}") + + huge_pages = HugePagesConfig(uvm.api.machine_config.get().json()["huge_pages"]) + if huge_pages == HugePagesConfig.HUGETLBFS_2MB and supports_hugetlbfs_discard(): + assert rss_after < rss_before, "RSS didn't decrease" + + +def test_virtio_mem_hotplug_hotunplug(uvm_any_memhp): + """ + Check that memory can be hotplugged into the VM. 
+    """
+    uvm = uvm_any_memhp
+    check_device_detected(uvm)
+
+    check_hotplug(uvm, 1024)
+    check_memory_usable(uvm)
+
+    check_hotunplug(uvm, 0)
+
+    # Check it works again
+    check_hotplug(uvm, 1024)
+    check_memory_usable(uvm)
+
+    validate_metrics(uvm)
+
+
+@pytest.mark.parametrize(
+    "memhp_config",
+    [
+        {"total_size_mib": 256, "slot_size_mib": 128, "block_size_mib": 64},
+        {"total_size_mib": 256, "slot_size_mib": 128, "block_size_mib": 128},
+        {"total_size_mib": 256, "slot_size_mib": 256, "block_size_mib": 64},
+        {"total_size_mib": 256, "slot_size_mib": 256, "block_size_mib": 256},
+    ],
+    ids=["all_different", "slot_sized_block", "single_slot", "single_block"],
+)
+def test_virtio_mem_configs(uvm_plain_6_1, memhp_config):
+    """
+    Check that the virtio-mem device is working as expected for different configs
+    """
+    uvm = uvm_booted_memhp(uvm_plain_6_1, None, None, False, memhp_config, None, None)
+    if not uvm.pci_enabled:
+        pytest.skip(
+            "Skip tests on MMIO transport to save time as we don't expect any difference."
+        )
+
+    check_device_detected(uvm)
+
+    for size in range(
+        0, memhp_config["total_size_mib"] + 1, memhp_config["block_size_mib"]
+    ):
+        check_hotplug(uvm, size)
+
+    check_memory_usable(uvm)
+
+    for size in range(
+        memhp_config["total_size_mib"] - memhp_config["block_size_mib"],
+        -1,
+        -memhp_config["block_size_mib"],
+    ):
+        check_hotunplug(uvm, size)
+
+    validate_metrics(uvm)
+
+
+def test_snapshot_restore_persistence(uvm_plain_6_1, microvm_factory):
+    """
+    Check that hotplugged memory is persisted across snapshot/restore.
+    """
+    if not uvm_plain_6_1.pci_enabled:
+        pytest.skip(
+            "Skip tests on MMIO transport to save time as we don't expect any difference."
+        )
+    uvm = uvm_booted_memhp(
+        uvm_plain_6_1, None, microvm_factory, False, DEFAULT_CONFIG, None, None
+    )
+
+    uvm.hotplug_memory(1024)
+
+    # Increase /dev/shm size as it defaults to half of the boot memory
+    uvm.ssh.check_output("mount -o remount,size=1024M -t tmpfs tmpfs /dev/shm")
+
+    uvm.ssh.check_output("dd if=/dev/urandom of=/dev/shm/mem_hp_test bs=1M count=1024")
+
+    _, checksum_before, _ = uvm.ssh.check_output("sha256sum /dev/shm/mem_hp_test")
+
+    restored_vm = microvm_factory.clone_uvm(uvm)
+
+    _, checksum_after, _ = restored_vm.ssh.check_output(
+        "sha256sum /dev/shm/mem_hp_test"
+    )
+
+    assert checksum_before == checksum_after, "Checksums didn't match"
+
+    validate_metrics(restored_vm)

From 2069fc3132fe61861039188a817696de1f749f5f Mon Sep 17 00:00:00 2001
From: Riccardo Mancini
Date: Thu, 2 Oct 2025 17:17:29 +0100
Subject: [PATCH 13/14] chore(test/virtio-mem): move tests under performance

Since these tests need to be run on an ag=1 host, move them under the
"performance" folder.

Signed-off-by: Riccardo Mancini
---
 .../test_hotplug_memory.py}                   | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)
 rename tests/integration_tests/{functional/test_memory_hp.py => performance/test_hotplug_memory.py} (97%)

diff --git a/tests/integration_tests/functional/test_memory_hp.py b/tests/integration_tests/performance/test_hotplug_memory.py
similarity index 97%
rename from tests/integration_tests/functional/test_memory_hp.py
rename to tests/integration_tests/performance/test_hotplug_memory.py
index 317acb8d915..6aa8649c5d1 100644
--- a/tests/integration_tests/functional/test_memory_hp.py
+++ b/tests/integration_tests/performance/test_hotplug_memory.py
@@ -1,7 +1,12 @@
 # Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # SPDX-License-Identifier: Apache-2.0
 
-"""Tests for verifying the virtio-mem is working correctly"""
+"""
+Tests for verifying the virtio-mem is working correctly
+
+This file also contains functional tests for virtio-mem because they need to be
+run on an ag=1 host due to the use of HugePages.
+"""
 
 import pytest
 from packaging import version

From 370eb4e34b5a71016096fa4a9b25cec9074d5af6 Mon Sep 17 00:00:00 2001
From: Riccardo Mancini
Date: Mon, 6 Oct 2025 16:29:57 +0100
Subject: [PATCH 14/14] test(virtio-mem): add rust integration tests

These tests add unit test coverage to the builder.rs and vm.rs files
which were previously untested in the memory hotplug case.

Signed-off-by: Riccardo Mancini
---
 src/vmm/src/test_utils/mod.rs      | 24 ++++-----
 src/vmm/tests/integration_tests.rs | 79 ++++++++++++++++++------------
 2 files changed, 60 insertions(+), 43 deletions(-)

diff --git a/src/vmm/src/test_utils/mod.rs b/src/vmm/src/test_utils/mod.rs
index 6fe66cdbadb..887acc54d38 100644
--- a/src/vmm/src/test_utils/mod.rs
+++ b/src/vmm/src/test_utils/mod.rs
@@ -16,6 +16,7 @@ use crate::vm_memory_vendored::GuestRegionCollection;
 use crate::vmm_config::boot_source::BootSourceConfig;
 use crate::vmm_config::instance_info::InstanceInfo;
 use crate::vmm_config::machine_config::HugePageConfig;
+use crate::vmm_config::memory_hotplug::MemoryHotplugConfig;
 use crate::vstate::memory::{self, GuestMemoryMmap, GuestRegionMmap, GuestRegionMmapExt};
 use crate::{EventManager, Vmm};
 
@@ -73,6 +74,7 @@ pub fn create_vmm(
     is_diff: bool,
     boot_microvm: bool,
     pci_enabled: bool,
+    memory_hotplug_enabled: bool,
 ) -> (Arc<Mutex<Vmm>>, EventManager) {
     let mut event_manager = EventManager::new().unwrap();
     let empty_seccomp_filters = get_empty_filters();
@@ -96,6 +98,14 @@ pub fn create_vmm(
 
     resources.pci_enabled = pci_enabled;
 
+    if memory_hotplug_enabled {
+        resources.memory_hotplug = Some(MemoryHotplugConfig {
+            total_size_mib: 1024,
+            block_size_mib: 2,
+            slot_size_mib: 128,
+        });
+    }
+
     let vmm = build_microvm_for_boot(
         &InstanceInfo::default(),
         &resources,
@@ -112,23 +122,15 @@ pub fn create_vmm(
 }
 
 pub fn default_vmm(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, true, false)
+    create_vmm(kernel_image, false, true, false, false)
 }
 
 pub fn default_vmm_no_boot(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, false, false)
-}
-
-pub fn default_vmm_pci_no_boot(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, false, true)
+    create_vmm(kernel_image, false, false, false, false)
 }
 
 pub fn dirty_tracking_vmm(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, true, true, false)
-}
-
-pub fn default_vmm_pci(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, true, false)
+    create_vmm(kernel_image, true, true, false, false)
 }
 
 #[allow(clippy::undocumented_unsafe_blocks)]
diff --git a/src/vmm/tests/integration_tests.rs b/src/vmm/tests/integration_tests.rs
index 4abbedc4530..6a5e6a08a14 100644
--- a/src/vmm/tests/integration_tests.rs
+++ b/src/vmm/tests/integration_tests.rs
@@ -18,9 +18,7 @@ use vmm::rpc_interface::{
 use vmm::seccomp::get_empty_filters;
 use vmm::snapshot::Snapshot;
 use vmm::test_utils::mock_resources::{MockVmResources, NOISY_KERNEL_IMAGE};
-use vmm::test_utils::{
-    create_vmm, default_vmm, default_vmm_no_boot, default_vmm_pci, default_vmm_pci_no_boot,
-};
+use vmm::test_utils::{create_vmm, default_vmm, default_vmm_no_boot};
 use
 src/vmm/src/test_utils/mod.rs      | 24 ++++-----
 src/vmm/tests/integration_tests.rs | 79 ++++++++++++++++++------------
 2 files changed, 60 insertions(+), 43 deletions(-)

diff --git a/src/vmm/src/test_utils/mod.rs b/src/vmm/src/test_utils/mod.rs
index 6fe66cdbadb..887acc54d38 100644
--- a/src/vmm/src/test_utils/mod.rs
+++ b/src/vmm/src/test_utils/mod.rs
@@ -16,6 +16,7 @@
 use crate::vm_memory_vendored::GuestRegionCollection;
 use crate::vmm_config::boot_source::BootSourceConfig;
 use crate::vmm_config::instance_info::InstanceInfo;
 use crate::vmm_config::machine_config::HugePageConfig;
+use crate::vmm_config::memory_hotplug::MemoryHotplugConfig;
 use crate::vstate::memory::{self, GuestMemoryMmap, GuestRegionMmap, GuestRegionMmapExt};
 use crate::{EventManager, Vmm};
@@ -73,6 +74,7 @@ pub fn create_vmm(
     is_diff: bool,
     boot_microvm: bool,
     pci_enabled: bool,
+    memory_hotplug_enabled: bool,
 ) -> (Arc<Mutex<Vmm>>, EventManager) {
     let mut event_manager = EventManager::new().unwrap();
     let empty_seccomp_filters = get_empty_filters();
@@ -96,6 +98,14 @@ pub fn create_vmm(
 
     resources.pci_enabled = pci_enabled;
 
+    if memory_hotplug_enabled {
+        resources.memory_hotplug = Some(MemoryHotplugConfig {
+            total_size_mib: 1024,
+            block_size_mib: 2,
+            slot_size_mib: 128,
+        });
+    }
+
     let vmm = build_microvm_for_boot(
         &InstanceInfo::default(),
         &resources,
@@ -112,23 +122,15 @@ pub fn create_vmm(
 }
 
 pub fn default_vmm(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, true, false)
+    create_vmm(kernel_image, false, true, false, false)
 }
 
 pub fn default_vmm_no_boot(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, false, false)
-}
-
-pub fn default_vmm_pci_no_boot(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, false, true)
+    create_vmm(kernel_image, false, false, false, false)
 }
 
 pub fn dirty_tracking_vmm(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, true, true, false)
-}
-
-pub fn default_vmm_pci(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, true, false)
+    create_vmm(kernel_image, true, true, false, false)
 }
 
 #[allow(clippy::undocumented_unsafe_blocks)]
diff --git a/src/vmm/tests/integration_tests.rs b/src/vmm/tests/integration_tests.rs
index 4abbedc4530..6a5e6a08a14 100644
--- a/src/vmm/tests/integration_tests.rs
+++ b/src/vmm/tests/integration_tests.rs
@@ -18,9 +18,7 @@ use vmm::rpc_interface::{
 use vmm::seccomp::get_empty_filters;
 use vmm::snapshot::Snapshot;
 use vmm::test_utils::mock_resources::{MockVmResources, NOISY_KERNEL_IMAGE};
-use vmm::test_utils::{
-    create_vmm, default_vmm, default_vmm_no_boot, default_vmm_pci, default_vmm_pci_no_boot,
-};
+use vmm::test_utils::{create_vmm, default_vmm, default_vmm_no_boot};
 use vmm::vmm_config::balloon::BalloonDeviceConfig;
 use vmm::vmm_config::boot_source::BootSourceConfig;
 use vmm::vmm_config::drive::BlockDeviceConfig;
@@ -66,13 +64,12 @@ fn test_build_and_boot_microvm() {
         assert_eq!(format!("{:?}", vmm_ret.err()), "Some(MissingKernelConfig)");
     }
 
-    // Success case.
-    let (vmm, evmgr) = default_vmm(None);
-    check_booted_microvm(vmm, evmgr);
-
-    // microVM with PCI
-    let (vmm, evmgr) = default_vmm_pci(None);
-    check_booted_microvm(vmm, evmgr);
+    for pci_enabled in [false, true] {
+        for memory_hotplug in [false, true] {
+            let (vmm, evmgr) = create_vmm(None, false, true, pci_enabled, memory_hotplug);
+            check_booted_microvm(vmm, evmgr);
+        }
+    }
 }
 
 #[allow(unused_mut, unused_variables)]
 fn check_build_microvm(vmm: Arc<Mutex<Vmm>>, mut evmgr: EventManager) {
@@ -96,10 +93,12 @@
 
 #[test]
 fn test_build_microvm() {
-    let (vmm, evtmgr) = default_vmm_no_boot(None);
-    check_build_microvm(vmm, evtmgr);
-    let (vmm, evtmgr) = default_vmm_pci_no_boot(None);
-    check_build_microvm(vmm, evtmgr);
+    for pci_enabled in [false, true] {
+        for memory_hotplug in [false, true] {
+            let (vmm, evmgr) = create_vmm(None, false, false, pci_enabled, memory_hotplug);
+            check_build_microvm(vmm, evmgr);
+        }
+    }
 }
 
 fn pause_resume_microvm(vmm: Arc<Mutex<Vmm>>) {
@@ -118,13 +117,14 @@
 
 #[test]
 fn test_pause_resume_microvm() {
-    // Tests that pausing and resuming a microVM work as expected.
-    let (vmm, _) = default_vmm(None);
+    for pci_enabled in [false, true] {
+        for memory_hotplug in [false, true] {
+            // Tests that pausing and resuming a microVM work as expected.
+            let (vmm, _) = create_vmm(None, false, true, pci_enabled, memory_hotplug);
 
-    pause_resume_microvm(vmm);
-
-    let (vmm, _) = default_vmm_pci(None);
-    pause_resume_microvm(vmm);
+            pause_resume_microvm(vmm);
+        }
+    }
 }
 
 #[test]
@@ -195,11 +195,21 @@ fn test_disallow_dump_cpu_config_without_pausing() {
     vmm.lock().unwrap().stop(FcExitCode::Ok);
 }
 
-fn verify_create_snapshot(is_diff: bool, pci_enabled: bool) -> (TempFile, TempFile) {
+fn verify_create_snapshot(
+    is_diff: bool,
+    pci_enabled: bool,
+    memory_hotplug: bool,
+) -> (TempFile, TempFile) {
     let snapshot_file = TempFile::new().unwrap();
     let memory_file = TempFile::new().unwrap();
-    let (vmm, _) = create_vmm(Some(NOISY_KERNEL_IMAGE), is_diff, true, pci_enabled);
+    let (vmm, _) = create_vmm(
+        Some(NOISY_KERNEL_IMAGE),
+        is_diff,
+        true,
+        pci_enabled,
+        memory_hotplug,
+    );
 
     let resources = VmResources {
         machine_config: MachineConfig {
             mem_size_mib: 1,
@@ -303,14 +313,19 @@
 
 #[test]
 fn test_create_and_load_snapshot() {
-    for (diff_snap, pci_enabled) in [(false, false), (false, true), (true, false), (true, true)] {
-        // Create snapshot.
-        let (snapshot_file, memory_file) = verify_create_snapshot(diff_snap, pci_enabled);
-        // Create a new microVm from snapshot. This only tests code-level logic; it verifies
-        // that a microVM can be built with no errors from given snapshot.
-        // It does _not_ verify that the guest is actually restored properly. We're using
-        // python integration tests for that.
-        verify_load_snapshot(snapshot_file, memory_file);
+    for diff_snap in [false, true] {
+        for pci_enabled in [false, true] {
+            for memory_hotplug in [false, true] {
+                // Create snapshot.
+                let (snapshot_file, memory_file) =
+                    verify_create_snapshot(diff_snap, pci_enabled, memory_hotplug);
+                // Create a new microVm from snapshot. This only tests code-level logic; it verifies
+                // that a microVM can be built with no errors from given snapshot.
+                // It does _not_ verify that the guest is actually restored properly. We're using
+                // python integration tests for that.
+                verify_load_snapshot(snapshot_file, memory_file);
+            }
+        }
     }
 }
 
@@ -338,7 +353,7 @@ fn check_snapshot(mut microvm_state: MicrovmState) {
 
 fn get_microvm_state_from_snapshot(pci_enabled: bool) -> MicrovmState {
     // Create a diff snapshot
-    let (snapshot_file, _) = verify_create_snapshot(true, pci_enabled);
+    let (snapshot_file, _) = verify_create_snapshot(true, pci_enabled, false);
 
     // Deserialize the microVM state.
     snapshot_file.as_file().seek(SeekFrom::Start(0)).unwrap();
@@ -346,7 +361,7 @@ fn get_microvm_state_from_snapshot(pci_enabled: bool) -> MicrovmState {
 
 fn verify_load_snap_disallowed_after_boot_resources(res: VmmAction, res_name: &str) {
-    let (snapshot_file, memory_file) = verify_create_snapshot(false, false);
+    let (snapshot_file, memory_file) = verify_create_snapshot(false, false, false);
 
     let mut event_manager = EventManager::new().unwrap();
     let empty_seccomp_filters = get_empty_filters();