diff --git a/.buildkite/pipeline_perf.py b/.buildkite/pipeline_perf.py index a64504cc99d..fff3d944f23 100755 --- a/.buildkite/pipeline_perf.py +++ b/.buildkite/pipeline_perf.py @@ -80,6 +80,11 @@ "tests": "integration_tests/performance/test_mmds.py", "devtool_opts": "-c 1-10 -m 0", }, + "memory-hotplug": { + "label": "memory-hotplug", + "tests": "integration_tests/performance/test_hotplug_memory.py", + "devtool_opts": "-c 1-10 -m 0", + }, } REVISION_A = os.environ.get("REVISION_A") diff --git a/Cargo.lock b/Cargo.lock index 6534e429943..516b94b3a5e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -206,6 +206,19 @@ version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "serde", + "tap", + "wyz", +] + [[package]] name = "bumpalo" version = "3.19.0" @@ -568,6 +581,12 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + [[package]] name = "gdbstub" version = "0.7.8" @@ -1013,6 +1032,12 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + [[package]] name = "rand" version = "0.9.2" @@ -1281,6 +1306,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + [[package]] name = "thiserror" version = "1.0.69" @@ -1525,6 +1556,7 @@ dependencies = [ "base64", "bincode", "bitflags 2.10.0", + "bitvec", "byteorder", "crc64", "criterion", @@ -1834,6 +1866,15 @@ version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + [[package]] name = "zerocopy" version = "0.8.27" diff --git a/docs/device-api.md b/docs/device-api.md index f622c2137af..f05b160048d 100644 --- a/docs/device-api.md +++ b/docs/device-api.md @@ -14,23 +14,24 @@ BadRequest - HTTP response. 
## API Endpoints -| Endpoint | keyboard | serial console | virtio-block | vhost-user-block | virtio-net | virtio-vsock | virtio-rng | virtio-pmem | -| ------------------------- | :------: | :------------: | :----------: | :--------------: | :--------: | :----------: | :--------: | :---------: | -| `boot-source` | O | O | O | O | O | O | O | O | -| `cpu-config` | O | O | O | O | O | O | O | O | -| `drives/{id}` | O | O | **R** | **R** | O | O | O | O | -| `logger` | O | O | O | O | O | O | O | O | -| `machine-config` | O | O | O | O | O | O | O | O | -| `metrics` | O | O | O | O | O | O | O | O | -| `mmds` | O | O | O | O | **R** | O | O | O | -| `mmds/config` | O | O | O | O | **R** | O | O | O | -| `network-interfaces/{id}` | O | O | O | O | **R** | O | O | O | -| `snapshot/create` | O | O | O | O | O | O | O | O | -| `snapshot/load` | O | O | O | O | O | O | O | O | -| `vm` | O | O | O | O | O | O | O | O | -| `vsock` | O | O | O | O | O | O | O | O | -| `entropy` | O | O | O | O | O | O | **R** | O | -| `pmem/{id}` | O | O | O | O | O | O | O | **R** | +| Endpoint | keyboard | serial console | virtio-block | vhost-user-block | virtio-net | virtio-vsock | virtio-rng | virtio-pmem | virtio-mem | +| ------------------------- | :------: | :------------: | :----------: | :--------------: | :--------: | :----------: | :--------: | :---------: | :--------: | +| `boot-source` | O | O | O | O | O | O | O | O | O | +| `cpu-config` | O | O | O | O | O | O | O | O | O | +| `drives/{id}` | O | O | **R** | **R** | O | O | O | O | O | +| `hotplug/memory` | O | O | O | O | O | O | O | O | **R** | +| `logger` | O | O | O | O | O | O | O | O | O | +| `machine-config` | O | O | O | O | O | O | O | O | O | +| `metrics` | O | O | O | O | O | O | O | O | O | +| `mmds` | O | O | O | O | **R** | O | O | O | O | +| `mmds/config` | O | O | O | O | **R** | O | O | O | O | +| `network-interfaces/{id}` | O | O | O | O | **R** | O | O | O | O | +| `snapshot/create` | O | O | O | O | O | O | O | O | O | +| `snapshot/load` | O | O | O | O | O | O | O | O | O | +| `vm` | O | O | O | O | O | O | O | O | O | +| `vsock` | O | O | O | O | O | O | O | O | O | +| `entropy` | O | O | O | O | O | O | **R** | O | O | +| `pmem/{id}` | O | O | O | O | O | O | O | **R** | O | ## Input Schema @@ -38,73 +39,77 @@ All input schema fields can be found in the [Swagger](https://swagger.io) specification: [firecracker.yaml](./../src/firecracker/swagger/firecracker.yaml). 
-| Schema | Property | keyboard | serial console | virtio-block | vhost-user-block | virtio-net | virtio-vsock | virtio-rng | virtio-pmem | -| ------------------------- | ------------------ | :------: | :------------: | :----------: | :--------------: | :--------: | :----------: | :--------: | :---------: | -| `BootSource` | boot_args | O | O | O | O | O | O | O | O | -| | initrd_path | O | O | O | O | O | O | O | O | -| | kernel_image_path | O | O | O | O | O | O | O | O | -| `CpuConfig` | cpuid_modifiers | O | O | O | O | O | O | O | O | -| | msr_modifiers | O | O | O | O | O | O | O | O | -| | reg_modifiers | O | O | O | O | O | O | O | O | -| `CpuTemplate` | enum | O | O | O | O | O | O | O | O | -| `CreateSnapshotParams` | mem_file_path | O | O | O | O | O | O | O | O | -| | snapshot_path | O | O | O | O | O | O | O | O | -| | snapshot_type | O | O | O | O | O | O | O | O | -| | version | O | O | O | O | O | O | O | O | -| `Drive` | drive_id \* | O | O | **R** | **R** | O | O | O | O | -| | is_read_only | O | O | **R** | O | O | O | O | O | -| | is_root_device \* | O | O | **R** | **R** | O | O | O | O | -| | partuuid \* | O | O | **R** | **R** | O | O | O | O | -| | path_on_host | O | O | **R** | O | O | O | O | O | -| | rate_limiter | O | O | **R** | O | O | O | O | O | -| | socket | O | O | O | **R** | O | O | O | O | -| `InstanceActionInfo` | action_type | O | O | O | O | O | O | O | O | -| `LoadSnapshotParams` | track_dirty_pages | O | O | O | O | O | O | O | O | -| | mem_file_path | O | O | O | O | O | O | O | O | -| | mem_backend | O | O | O | O | O | O | O | O | -| | snapshot_path | O | O | O | O | O | O | O | O | -| | resume_vm | O | O | O | O | O | O | O | O | -| `Logger` | level | O | O | O | O | O | O | O | O | -| | log_path | O | O | O | O | O | O | O | O | -| | show_level | O | O | O | O | O | O | O | O | -| | show_log_origin | O | O | O | O | O | O | O | O | -| `MachineConfiguration` | cpu_template | O | O | O | O | O | O | O | O | -| | smt | O | O | O | O | O | O | O | O | -| | mem_size_mib | O | O | O | O | O | O | O | O | -| | track_dirty_pages | O | O | O | O | O | O | O | O | -| | vcpu_count | O | O | O | O | O | O | O | O | -| `Metrics` | metrics_path | O | O | O | O | O | O | O | O | -| `MmdsConfig` | network_interfaces | O | O | O | O | **R** | O | O | O | -| | version | O | O | O | O | **R** | O | O | O | -| | ipv4_address | O | O | O | O | **R** | O | O | O | -| | imds_compat | O | O | O | O | O | O | O | O | -| `NetworkInterface` | guest_mac | O | O | O | O | **R** | O | O | O | -| | host_dev_name | O | O | O | O | **R** | O | O | O | -| | iface_id | O | O | O | O | **R** | O | O | O | -| | rx_rate_limiter | O | O | O | O | **R** | O | O | O | -| | tx_rate_limiter | O | O | O | O | **R** | O | O | O | -| `PartialDrive` | drive_id | O | O | **R** | O | O | O | O | O | -| | path_on_host | O | O | **R** | O | O | O | O | O | -| `PartialNetworkInterface` | iface_id | O | O | O | O | **R** | O | O | O | -| | rx_rate_limiter | O | O | O | O | **R** | O | O | O | -| | tx_rate_limiter | O | O | O | O | **R** | O | O | O | -| `RateLimiter` | bandwidth | O | O | O | O | **R** | O | O | O | -| | ops | O | O | **R** | O | O | O | O | O | -| `TokenBucket` \*\* | one_time_burst | O | O | **R** | O | O | O | O | O | -| | refill_time | O | O | **R** | O | O | O | O | O | -| | size | O | O | **R** | O | O | O | O | O | -| `TokenBucket` \*\* | one_time_burst | O | O | O | O | **R** | O | O | O | -| | refill_time | O | O | O | O | **R** | O | O | O | -| | size | O | O | O | O | 
**R** | O | O | O | -| `Vm` | state | O | O | O | O | O | O | O | O | -| `Vsock` | guest_cid | O | O | O | O | O | **R** | O | O | -| | uds_path | O | O | O | O | O | **R** | O | O | -| | vsock_id | O | O | O | O | O | **R** | O | O | -| `EntropyDevice` | rate_limiter | O | O | O | O | O | O | **R** | O | -| `Pmem` | id | O | O | O | O | O | O | O | **R** | -| | path_on_host | O | O | O | O | O | O | O | **R** | -| | root_device | O | O | O | O | O | O | O | **R** | -| | read_only | O | O | O | O | O | O | O | **R** | +| Schema | Property | keyboard | serial console | virtio-block | vhost-user-block | virtio-net | virtio-vsock | virtio-rng | virtio-pmem | virtio-mem | +| ------------------------- | ------------------ | :------: | :------------: | :----------: | :--------------: | :--------: | :----------: | :--------: | :---------: | :--------: | +| `BootSource` | boot_args | O | O | O | O | O | O | O | O | O | +| | initrd_path | O | O | O | O | O | O | O | O | O | +| | kernel_image_path | O | O | O | O | O | O | O | O | O | +| `CpuConfig` | cpuid_modifiers | O | O | O | O | O | O | O | O | O | +| | msr_modifiers | O | O | O | O | O | O | O | O | O | +| | reg_modifiers | O | O | O | O | O | O | O | O | O | +| `CpuTemplate` | enum | O | O | O | O | O | O | O | O | O | +| `CreateSnapshotParams` | mem_file_path | O | O | O | O | O | O | O | O | O | +| | snapshot_path | O | O | O | O | O | O | O | O | O | +| | snapshot_type | O | O | O | O | O | O | O | O | O | +| | version | O | O | O | O | O | O | O | O | O | +| `Drive` | drive_id \* | O | O | **R** | **R** | O | O | O | O | O | +| | is_read_only | O | O | **R** | O | O | O | O | O | O | +| | is_root_device \* | O | O | **R** | **R** | O | O | O | O | O | +| | partuuid \* | O | O | **R** | **R** | O | O | O | O | O | +| | path_on_host | O | O | **R** | O | O | O | O | O | O | +| | rate_limiter | O | O | **R** | O | O | O | O | O | O | +| | socket | O | O | O | **R** | O | O | O | O | O | +| `InstanceActionInfo` | action_type | O | O | O | O | O | O | O | O | O | +| `LoadSnapshotParams` | track_dirty_pages | O | O | O | O | O | O | O | O | O | +| | mem_file_path | O | O | O | O | O | O | O | O | O | +| | mem_backend | O | O | O | O | O | O | O | O | O | +| | snapshot_path | O | O | O | O | O | O | O | O | O | +| | resume_vm | O | O | O | O | O | O | O | O | O | +| `Logger` | level | O | O | O | O | O | O | O | O | O | +| | log_path | O | O | O | O | O | O | O | O | O | +| | show_level | O | O | O | O | O | O | O | O | O | +| | show_log_origin | O | O | O | O | O | O | O | O | O | +| `MachineConfiguration` | cpu_template | O | O | O | O | O | O | O | O | O | +| | smt | O | O | O | O | O | O | O | O | O | +| | mem_size_mib | O | O | O | O | O | O | O | O | O | +| | track_dirty_pages | O | O | O | O | O | O | O | O | O | +| | vcpu_count | O | O | O | O | O | O | O | O | O | +| `Metrics` | metrics_path | O | O | O | O | O | O | O | O | O | +| `MmdsConfig` | network_interfaces | O | O | O | O | **R** | O | O | O | O | +| | version | O | O | O | O | **R** | O | O | O | O | +| | ipv4_address | O | O | O | O | **R** | O | O | O | O | +| | imds_compat | O | O | O | O | O | O | O | O | O | +| `NetworkInterface` | guest_mac | O | O | O | O | **R** | O | O | O | O | +| | host_dev_name | O | O | O | O | **R** | O | O | O | O | +| | iface_id | O | O | O | O | **R** | O | O | O | O | +| | rx_rate_limiter | O | O | O | O | **R** | O | O | O | O | +| | tx_rate_limiter | O | O | O | O | **R** | O | O | O | O | +| `PartialDrive` | drive_id | O | O | **R** | O | 
O | O | O | O | O |
+| | path_on_host | O | O | **R** | O | O | O | O | O | O |
+| `PartialNetworkInterface` | iface_id | O | O | O | O | **R** | O | O | O | O |
+| | rx_rate_limiter | O | O | O | O | **R** | O | O | O | O |
+| | tx_rate_limiter | O | O | O | O | **R** | O | O | O | O |
+| `RateLimiter` | bandwidth | O | O | O | O | **R** | O | O | O | O |
+| | ops | O | O | **R** | O | O | O | O | O | O |
+| `TokenBucket` \*\* | one_time_burst | O | O | **R** | O | O | O | O | O | O |
+| | refill_time | O | O | **R** | O | O | O | O | O | O |
+| | size | O | O | **R** | O | O | O | O | O | O |
+| `TokenBucket` \*\* | one_time_burst | O | O | O | O | **R** | O | O | O | O |
+| | refill_time | O | O | O | O | **R** | O | O | O | O |
+| | size | O | O | O | O | **R** | O | O | O | O |
+| `Vm` | state | O | O | O | O | O | O | O | O | O |
+| `Vsock` | guest_cid | O | O | O | O | O | **R** | O | O | O |
+| | uds_path | O | O | O | O | O | **R** | O | O | O |
+| | vsock_id | O | O | O | O | O | **R** | O | O | O |
+| `EntropyDevice` | rate_limiter | O | O | O | O | O | O | **R** | O | O |
+| `Pmem` | id | O | O | O | O | O | O | O | **R** | O |
+| | path_on_host | O | O | O | O | O | O | O | **R** | O |
+| | root_device | O | O | O | O | O | O | O | **R** | O |
+| | read_only | O | O | O | O | O | O | O | **R** | O |
+| `MemoryHotplugConfig` | total_size_mib | O | O | O | O | O | O | O | O | **R** |
+| | slot_size_mib | O | O | O | O | O | O | O | O | **R** |
+| | block_size_mib | O | O | O | O | O | O | O | O | **R** |
+| `MemoryHotplugSizeUpdate` | requested_size_mib | O | O | O | O | O | O | O | O | **R** |
 
 \* `Drive`'s `drive_id`, `is_root_device` and `partuuid` can be configured by
 either virtio-block or vhost-user-block devices.
 
@@ -118,18 +123,23 @@ All output schema fields can be found in the
 [Swagger](https://swagger.io) specification:
 [firecracker.yaml](./../src/firecracker/swagger/firecracker.yaml).
-| Schema | Property | keyboard | serial console | virtio-block | vhost-user-block | virtio-net | virtio-vsock |
-| ---------------------- | ----------------- | :------: | :------------: | :----------: | :--------------: | :--------: | :----------: |
-| `Error` | fault_message | O | O | O | O | O | O |
-| `InstanceInfo` | app_name | O | O | O | O | O | O |
-| | id | O | O | O | O | O | O |
-| | state | O | O | O | O | O | O |
-| | vmm_version | O | O | O | O | O | O |
-| `MachineConfiguration` | cpu_template | O | O | O | O | O | O |
-| | smt | O | O | O | O | O | O |
-| | mem_size_mib | O | O | O | O | O | O |
-| | track_dirty_pages | O | O | O | O | O | O |
-| | vcpu_count | O | O | O | O | O | O |
+| Schema | Property | keyboard | serial console | virtio-block | vhost-user-block | virtio-net | virtio-vsock | virtio-mem |
+| ---------------------- | ------------------ | :------: | :------------: | :----------: | :--------------: | :--------: | :----------: | :--------: |
+| `Error` | fault_message | O | O | O | O | O | O | O |
+| `InstanceInfo` | app_name | O | O | O | O | O | O | O |
+| | id | O | O | O | O | O | O | O |
+| | state | O | O | O | O | O | O | O |
+| | vmm_version | O | O | O | O | O | O | O |
+| `MachineConfiguration` | cpu_template | O | O | O | O | O | O | O |
+| | smt | O | O | O | O | O | O | O |
+| | mem_size_mib | O | O | O | O | O | O | O |
+| | track_dirty_pages | O | O | O | O | O | O | O |
+| | vcpu_count | O | O | O | O | O | O | O |
+| `MemoryHotplugStatus` | total_size_mib | O | O | O | O | O | O | **R** |
+| | slot_size_mib | O | O | O | O | O | O | **R** |
+| | block_size_mib | O | O | O | O | O | O | **R** |
+| | plugged_size_mib | O | O | O | O | O | O | **R** |
+| | requested_size_mib | O | O | O | O | O | O | **R** |
 
 ## Instance Actions
 
diff --git a/resources/seccomp/aarch64-unknown-linux-musl.json b/resources/seccomp/aarch64-unknown-linux-musl.json
index d81a1012599..26dd661e46b 100644
--- a/resources/seccomp/aarch64-unknown-linux-musl.json
+++ b/resources/seccomp/aarch64-unknown-linux-musl.json
@@ -445,6 +445,18 @@
         }
       ]
     },
+    {
+      "syscall": "ioctl",
+      "args": [
+        {
+          "index": 1,
+          "type": "dword",
+          "op": "eq",
+          "val": 1075883590,
+          "comment": "KVM_SET_USER_MEMORY_REGION, used to (un)plug memory for the virtio-mem device"
+        }
+      ]
+    },
     {
       "syscall": "sched_yield",
       "comment": "Used by the rust standard library in std::sync::mpmc. Firecracker uses mpsc channels from this module for inter-thread communication"
@@ -460,6 +472,10 @@
     {
       "syscall": "restart_syscall",
       "comment": "automatically issued by the kernel when specific timing-related syscalls (e.g. nanosleep) get interrupted by SIGSTOP"
+    },
+    {
+      "syscall": "mprotect",
+      "comment": "Used by memory hotplug to protect access to underlying host memory"
     }
   ]
 },
diff --git a/resources/seccomp/x86_64-unknown-linux-musl.json b/resources/seccomp/x86_64-unknown-linux-musl.json
index 66c986495fb..dcd6753a4c5 100644
--- a/resources/seccomp/x86_64-unknown-linux-musl.json
+++ b/resources/seccomp/x86_64-unknown-linux-musl.json
@@ -457,6 +457,18 @@
         }
       ]
     },
+    {
+      "syscall": "ioctl",
+      "args": [
+        {
+          "index": 1,
+          "type": "dword",
+          "op": "eq",
+          "val": 1075883590,
+          "comment": "KVM_SET_USER_MEMORY_REGION, used to (un)plug memory for the virtio-mem device"
+        }
+      ]
+    },
     {
       "syscall": "sched_yield",
       "comment": "Used by the rust standard library in std::sync::mpmc. Firecracker uses mpsc channels from this module for inter-thread communication"
@@ -472,6 +484,10 @@
     {
       "syscall": "restart_syscall",
       "comment": "automatically issued by the kernel when specific timing-related syscalls (e.g. nanosleep) get interrupted by SIGSTOP"
+    },
+    {
+      "syscall": "mprotect",
+      "comment": "Used by memory hotplug to protect access to underlying host memory"
     }
   ]
 },
diff --git a/src/firecracker/examples/uffd/on_demand_handler.rs b/src/firecracker/examples/uffd/on_demand_handler.rs
index 3be958b3578..3101aa19253 100644
--- a/src/firecracker/examples/uffd/on_demand_handler.rs
+++ b/src/firecracker/examples/uffd/on_demand_handler.rs
@@ -87,7 +87,7 @@ fn main() {
                         }
                     }
                     userfaultfd::Event::Remove { start, end } => {
-                        uffd_handler.mark_range_removed(start as u64, end as u64)
+                        uffd_handler.unregister_range(start, end)
                     }
                     _ => panic!("Unexpected event on userfaultfd"),
                 }
diff --git a/src/firecracker/examples/uffd/uffd_utils.rs b/src/firecracker/examples/uffd/uffd_utils.rs
index 97c6150b65b..ab28f6f4d2e 100644
--- a/src/firecracker/examples/uffd/uffd_utils.rs
+++ b/src/firecracker/examples/uffd/uffd_utils.rs
@@ -9,7 +9,7 @@
     dead_code
 )]
 
-use std::collections::{HashMap, HashSet};
+use std::collections::HashMap;
 use std::ffi::c_void;
 use std::fs::File;
 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
@@ -54,7 +54,6 @@ pub struct UffdHandler {
     pub page_size: usize,
     backing_buffer: *const u8,
     uffd: Uffd,
-    removed_pages: HashSet<u64>,
 }
 
 impl UffdHandler {
@@ -125,7 +124,6 @@ impl UffdHandler {
             page_size,
             backing_buffer,
             uffd,
-            removed_pages: HashSet::new(),
         }
     }
 
@@ -133,24 +131,23 @@ impl UffdHandler {
         self.uffd.read_event()
     }
 
-    pub fn mark_range_removed(&mut self, start: u64, end: u64) {
-        let pfn_start = start / self.page_size as u64;
-        let pfn_end = end / self.page_size as u64;
-
-        for pfn in pfn_start..pfn_end {
-            self.removed_pages.insert(pfn);
-        }
+    pub fn unregister_range(&mut self, start: *mut c_void, end: *mut c_void) {
+        assert!(
+            (start as usize).is_multiple_of(self.page_size)
+                && (end as usize).is_multiple_of(self.page_size)
+                && end > start
+        );
+        // SAFETY: start and end are valid and provided by UFFD
+        let len = unsafe { end.offset_from_unsigned(start) };
+        self.uffd
+            .unregister(start, len)
+            .expect("range should be valid");
     }
 
     pub fn serve_pf(&mut self, addr: *mut u8, len: usize) -> bool {
         // Find the start of the page that the current faulting address belongs to.
let dst = (addr as usize & !(self.page_size - 1)) as *mut libc::c_void; let fault_page_addr = dst as u64; - let fault_pfn = fault_page_addr / self.page_size as u64; - - if self.removed_pages.contains(&fault_pfn) { - return self.zero_out(fault_page_addr); - } for region in self.mem_regions.iter() { if region.contains(fault_page_addr) { @@ -193,14 +190,6 @@ impl UffdHandler { true } - - fn zero_out(&mut self, addr: u64) -> bool { - match unsafe { self.uffd.zeropage(addr as *mut _, self.page_size, true) } { - Ok(_) => true, - Err(Error::ZeropageFailed(error)) if error as i32 == libc::EAGAIN => false, - r => panic!("Unexpected zeropage result: {:?}", r), - } - } } #[derive(Debug)] diff --git a/src/firecracker/src/api_server/parsed_request.rs b/src/firecracker/src/api_server/parsed_request.rs index 9f1ab870061..f28db7db28d 100644 --- a/src/firecracker/src/api_server/parsed_request.rs +++ b/src/firecracker/src/api_server/parsed_request.rs @@ -28,6 +28,9 @@ use super::request::pmem::parse_put_pmem; use super::request::snapshot::{parse_patch_vm_state, parse_put_snapshot}; use super::request::version::parse_get_version; use super::request::vsock::parse_put_vsock; +use crate::api_server::request::hotplug::memory::{ + parse_get_memory_hotplug, parse_patch_memory_hotplug, parse_put_memory_hotplug, +}; use crate::api_server::request::serial::parse_put_serial; #[derive(Debug)] @@ -85,6 +88,9 @@ impl TryFrom<&Request> for ParsedRequest { } (Method::Get, "machine-config", None) => parse_get_machine_config(), (Method::Get, "mmds", None) => parse_get_mmds(), + (Method::Get, "hotplug", None) if path_tokens.next() == Some("memory") => { + parse_get_memory_hotplug() + } (Method::Get, _, Some(_)) => method_to_error(Method::Get), (Method::Put, "actions", Some(body)) => parse_put_actions(body), (Method::Put, "balloon", Some(body)) => parse_put_balloon(body), @@ -103,6 +109,9 @@ impl TryFrom<&Request> for ParsedRequest { (Method::Put, "snapshot", Some(body)) => parse_put_snapshot(body, path_tokens.next()), (Method::Put, "vsock", Some(body)) => parse_put_vsock(body), (Method::Put, "entropy", Some(body)) => parse_put_entropy(body), + (Method::Put, "hotplug", Some(body)) if path_tokens.next() == Some("memory") => { + parse_put_memory_hotplug(body) + } (Method::Put, _, None) => method_to_error(Method::Put), (Method::Patch, "balloon", Some(body)) => parse_patch_balloon(body, path_tokens.next()), (Method::Patch, "drives", Some(body)) => parse_patch_drive(body, path_tokens.next()), @@ -112,6 +121,9 @@ impl TryFrom<&Request> for ParsedRequest { parse_patch_net(body, path_tokens.next()) } (Method::Patch, "vm", Some(body)) => parse_patch_vm_state(body), + (Method::Patch, "hotplug", Some(body)) if path_tokens.next() == Some("memory") => { + parse_patch_memory_hotplug(body) + } (Method::Patch, _, None) => method_to_error(Method::Patch), (method, unknown_uri, _) => Err(RequestError::InvalidPathMethod( unknown_uri.to_string(), @@ -175,6 +187,7 @@ impl ParsedRequest { Self::success_response_with_data(balloon_config) } VmmData::BalloonStats(stats) => Self::success_response_with_data(stats), + VmmData::VirtioMemStatus(data) => Self::success_response_with_data(data), VmmData::InstanceInformation(info) => Self::success_response_with_data(info), VmmData::VmmVersion(version) => Self::success_response_with_data( &serde_json::json!({ "firecracker_version": version.as_str() }), @@ -559,6 +572,9 @@ pub mod tests { VmmData::BalloonStats(stats) => { http_response(&serde_json::to_string(stats).unwrap(), 200) } + 
VmmData::VirtioMemStatus(data) => {
+                http_response(&serde_json::to_string(data).unwrap(), 200)
+            }
             VmmData::Empty => http_response("", 204),
             VmmData::FullVmConfig(cfg) => {
                 http_response(&serde_json::to_string(cfg).unwrap(), 200)
diff --git a/src/firecracker/src/api_server/request/hotplug/memory.rs b/src/firecracker/src/api_server/request/hotplug/memory.rs
new file mode 100644
index 00000000000..5ec514ca964
--- /dev/null
+++ b/src/firecracker/src/api_server/request/hotplug/memory.rs
@@ -0,0 +1,118 @@
+// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use micro_http::Body;
+use vmm::logger::{IncMetric, METRICS};
+use vmm::rpc_interface::VmmAction;
+use vmm::vmm_config::memory_hotplug::{MemoryHotplugConfig, MemoryHotplugSizeUpdate};
+
+use crate::api_server::parsed_request::{ParsedRequest, RequestError};
+
+pub(crate) fn parse_put_memory_hotplug(body: &Body) -> Result<ParsedRequest, RequestError> {
+    METRICS.put_api_requests.hotplug_memory_count.inc();
+    let config = serde_json::from_slice::<MemoryHotplugConfig>(body.raw()).inspect_err(|_| {
+        METRICS.put_api_requests.hotplug_memory_fails.inc();
+    })?;
+    Ok(ParsedRequest::new_sync(VmmAction::SetMemoryHotplugDevice(
+        config,
+    )))
+}
+
+pub(crate) fn parse_get_memory_hotplug() -> Result<ParsedRequest, RequestError> {
+    METRICS.get_api_requests.hotplug_memory_count.inc();
+    Ok(ParsedRequest::new_sync(VmmAction::GetMemoryHotplugStatus))
+}
+
+pub(crate) fn parse_patch_memory_hotplug(body: &Body) -> Result<ParsedRequest, RequestError> {
+    METRICS.patch_api_requests.hotplug_memory_count.inc();
+    let config =
+        serde_json::from_slice::<MemoryHotplugSizeUpdate>(body.raw()).inspect_err(|_| {
+            METRICS.patch_api_requests.hotplug_memory_fails.inc();
+        })?;
+    Ok(ParsedRequest::new_sync(VmmAction::UpdateMemoryHotplugSize(
+        config,
+    )))
+}
+
+#[cfg(test)]
+mod tests {
+    use vmm::devices::virtio::mem::{
+        VIRTIO_MEM_DEFAULT_BLOCK_SIZE_MIB, VIRTIO_MEM_DEFAULT_SLOT_SIZE_MIB,
+    };
+    use vmm::vmm_config::memory_hotplug::MemoryHotplugSizeUpdate;
+
+    use super::*;
+    use crate::api_server::parsed_request::tests::vmm_action_from_request;
+
+    #[test]
+    fn test_parse_put_memory_hotplug_request() {
+        parse_put_memory_hotplug(&Body::new("invalid_payload")).unwrap_err();
+
+        // PUT with invalid fields.
+        let body = r#"{
+            "total_size_mib": "bar"
+        }"#;
+        parse_put_memory_hotplug(&Body::new(body)).unwrap_err();
+
+        // PUT with valid input fields with defaults.
+        let body = r#"{
+            "total_size_mib": 2048
+        }"#;
+        let expected_config = MemoryHotplugConfig {
+            total_size_mib: 2048,
+            block_size_mib: VIRTIO_MEM_DEFAULT_BLOCK_SIZE_MIB,
+            slot_size_mib: VIRTIO_MEM_DEFAULT_SLOT_SIZE_MIB,
+        };
+        assert_eq!(
+            vmm_action_from_request(parse_put_memory_hotplug(&Body::new(body)).unwrap()),
+            VmmAction::SetMemoryHotplugDevice(expected_config)
+        );
+
+        // PUT with valid input fields.
+        let body = r#"{
+            "total_size_mib": 2048,
+            "block_size_mib": 64,
+            "slot_size_mib": 64
+        }"#;
+        let expected_config = MemoryHotplugConfig {
+            total_size_mib: 2048,
+            block_size_mib: 64,
+            slot_size_mib: 64,
+        };
+        assert_eq!(
+            vmm_action_from_request(parse_put_memory_hotplug(&Body::new(body)).unwrap()),
+            VmmAction::SetMemoryHotplugDevice(expected_config)
+        );
+    }
+
+    #[test]
+    fn test_parse_get_memory_hotplug_request() {
+        assert_eq!(
+            vmm_action_from_request(parse_get_memory_hotplug().unwrap()),
+            VmmAction::GetMemoryHotplugStatus
+        );
+    }
+
+    #[test]
+    fn test_parse_patch_memory_hotplug_request() {
+        parse_patch_memory_hotplug(&Body::new("invalid_payload")).unwrap_err();
+
+        // PATCH with invalid fields.
+        let body = r#"{
+            "requested_size_mib": "bar"
+        }"#;
+        parse_patch_memory_hotplug(&Body::new(body)).unwrap_err();
+
+        // PATCH with valid input fields.
+        let body = r#"{
+            "requested_size_mib": 2048
+        }"#;
+        let expected_config = MemoryHotplugSizeUpdate {
+            requested_size_mib: 2048,
+        };
+        assert_eq!(
+            vmm_action_from_request(parse_patch_memory_hotplug(&Body::new(body)).unwrap()),
+            VmmAction::UpdateMemoryHotplugSize(expected_config)
+        );
+    }
+}
diff --git a/src/firecracker/src/api_server/request/hotplug/mod.rs b/src/firecracker/src/api_server/request/hotplug/mod.rs
new file mode 100644
index 00000000000..50b97ea2b80
--- /dev/null
+++ b/src/firecracker/src/api_server/request/hotplug/mod.rs
@@ -0,0 +1,4 @@
+// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod memory;
diff --git a/src/firecracker/src/api_server/request/mod.rs b/src/firecracker/src/api_server/request/mod.rs
index 276a89d5a4e..9be4617bd8e 100644
--- a/src/firecracker/src/api_server/request/mod.rs
+++ b/src/firecracker/src/api_server/request/mod.rs
@@ -7,6 +7,7 @@ pub mod boot_source;
 pub mod cpu_configuration;
 pub mod drive;
 pub mod entropy;
+pub mod hotplug;
 pub mod instance_info;
 pub mod logger;
 pub mod machine_configuration;
diff --git a/src/firecracker/swagger/firecracker.yaml b/src/firecracker/swagger/firecracker.yaml
index 5bf55108b09..ae6d5fe3a55 100644
--- a/src/firecracker/swagger/firecracker.yaml
+++ b/src/firecracker/swagger/firecracker.yaml
@@ -482,6 +482,7 @@ paths:
           description: The MMDS data store JSON.
           schema:
             type: object
+            additionalProperties: true
         404:
           description: The MMDS data store content can not be found.
           schema:
             $ref: "#/definitions/Error"
@@ -560,6 +561,63 @@ paths:
           schema:
             $ref: "#/definitions/Error"
 
+  /hotplug/memory:
+    put:
+      summary: Configures the hotpluggable memory
+      operationId: putMemoryHotplug
+      description:
+        Configure the hotpluggable memory, a virtio-mem device with an associated memory area
+        that can be hot(un)plugged in the guest on demand using the PATCH API.
+      parameters:
+        - name: body
+          in: body
+          description: Hotpluggable memory configuration
+          required: true
+          schema:
+            $ref: "#/definitions/MemoryHotplugConfig"
+      responses:
+        204:
+          description: Hotpluggable memory configured
+        default:
+          description: Internal server error
+          schema:
+            $ref: "#/definitions/Error"
+    patch:
+      summary: Updates the size of the hotpluggable memory region
+      operationId: patchMemoryHotplug
+      description:
+        Updates the size of the hotpluggable memory region. The guest will plug and unplug
+        memory blocks to reach the requested size.
+      parameters:
+        - name: body
+          in: body
+          description: Hotpluggable memory size update
+          required: true
+          schema:
+            $ref: "#/definitions/MemoryHotplugSizeUpdate"
+      responses:
+        204:
+          description: Hotpluggable memory size updated
+        default:
+          description: Internal server error
+          schema:
+            $ref: "#/definitions/Error"
+    get:
+      summary: Retrieves the status of the hotpluggable memory
+      operationId: getMemoryHotplug
+      description:
+        Returns the status of the hotpluggable memory. This can be used to follow the guest's
+        progress after a PATCH request.
+      responses:
+        200:
+          description: OK
+          schema:
+            $ref: "#/definitions/MemoryHotplugStatus"
+        default:
+          description: Internal server error
+          schema:
+            $ref: "#/definitions/Error"
+
   /network-interfaces/{iface_id}:
     put:
       summary: Creates a network interface. Pre-boot only.
@@ -898,21 +956,119 @@ definitions:
       The CPU configuration template defines a set of bit maps as modifiers of flags
       accessed by register to be disabled/enabled for the microvm.
     properties:
+      kvm_capabilities:
+        type: array
+        description: A collection of KVM capabilities to be added or removed (both x86_64 and aarch64)
+        items:
+          type: string
+          description: KVM capability as a numeric string. Prefix with '!' to remove capability. Example "121" (add) or "!121" (remove)
       cpuid_modifiers:
-        type: object
-        description: A collection of CPUIDs to be modified. (x86_64)
+        type: array
+        description: A collection of CPUID leaf modifiers (x86_64 only)
+        items:
+          $ref: "#/definitions/CpuidLeafModifier"
       msr_modifiers:
-        type: object
-        description: A collection of model specific registers to be modified. (x86_64)
+        type: array
+        description: A collection of model specific register modifiers (x86_64 only)
+        items:
+          $ref: "#/definitions/MsrModifier"
       reg_modifiers:
-        type: object
-        description: A collection of registers to be modified. (aarch64)
+        type: array
+        description: A collection of register modifiers (aarch64 only)
+        items:
+          $ref: "#/definitions/ArmRegisterModifier"
       vcpu_features:
-        type: object
-        description: A collection of vcpu features to be modified. (aarch64)
-      kvm_capabilities:
-        type: object
-        description: A collection of kvm capabilities to be modified. (aarch64)
+        type: array
+        description: A collection of vCPU features to be modified (aarch64 only)
+        items:
+          $ref: "#/definitions/VcpuFeatures"
+
+  CpuidLeafModifier:
+    type: object
+    description: Modifier for a CPUID leaf and subleaf (x86_64)
+    required:
+      - leaf
+      - subleaf
+      - flags
+      - modifiers
+    properties:
+      leaf:
+        type: string
+        description: CPUID leaf index as hex, binary, or decimal string (e.g., "0x0", "0b0", "0")
+      subleaf:
+        type: string
+        description: CPUID subleaf index as hex, binary, or decimal string (e.g., "0x0", "0b0", "0")
+      flags:
+        type: integer
+        format: int32
+        description: KVM feature flags for this leaf-subleaf pair
+      modifiers:
+        type: array
+        description: Register modifiers for this CPUID leaf
+        items:
+          $ref: "#/definitions/CpuidRegisterModifier"
+
+  CpuidRegisterModifier:
+    type: object
+    description: Modifier for a specific CPUID register within a leaf (x86_64)
+    required:
+      - register
+      - bitmap
+    properties:
+      register:
+        type: string
+        description: Target CPUID register name
+        enum:
+          - eax
+          - ebx
+          - ecx
+          - edx
+      bitmap:
+        type: string
+        description: 32-bit bitmap string defining which bits to modify. Format is "0b" followed by 32 characters where '0' = clear bit, '1' = set bit, 'x' = don't modify. Example "0b00000000000000000000000000000001" or "0bxxxxxxxxxxxxxxxxxxxxxxxxxxxx0001"
+
+  MsrModifier:
+    type: object
+    description: Modifier for a model specific register (x86_64)
+    required:
+      - addr
+      - bitmap
+    properties:
+      addr:
+        type: string
+        description: 32-bit MSR address as hex, binary, or decimal string (e.g., "0x10a", "0b100001010", "266")
+      bitmap:
+        type: string
+        description: 64-bit bitmap string defining which bits to modify. Format is "0b" followed by 64 characters where '0' = clear bit, '1' = set bit, 'x' = don't modify. Underscores can be used for readability. Example "0b0000000000000000000000000000000000000000000000000000000000000001"
+
+  ArmRegisterModifier:
+    type: object
+    description: Modifier for an ARM register (aarch64)
+    required:
+      - addr
+      - bitmap
+    properties:
+      addr:
+        type: string
+        description: 64-bit register address as hex, binary, or decimal string (e.g., "0x0", "0b0", "0")
+      bitmap:
+        type: string
+        description: 128-bit bitmap string defining which bits to modify. Format is "0b" followed by up to 128 characters where '0' = clear bit, '1' = set bit, 'x' = don't modify. Underscores can be used for readability. Example "0b0000000000000000000000000000000000000000000000000000000000000001"
+
+  VcpuFeatures:
+    type: object
+    description: vCPU feature modifier (aarch64)
+    required:
+      - index
+      - bitmap
+    properties:
+      index:
+        type: integer
+        format: int32
+        description: Index in the kvm_vcpu_init.features array
+      bitmap:
+        type: string
+        description: 32-bit bitmap string defining which bits to modify. Format is "0b" followed by 32 characters where '0' = clear bit, '1' = set bit, 'x' = don't modify. Example "0b00000000000000000000000001100000"
 
   Drive:
     type: object
@@ -1025,6 +1181,11 @@
       description: Configurations for all net devices.
         items:
           $ref: "#/definitions/NetworkInterface"
+      pmem:
+        type: array
+        description: Configurations for all pmem devices.
+        items:
+          $ref: "#/definitions/Pmem"
       vsock:
         $ref: "#/definitions/Vsock"
       entropy:
@@ -1213,6 +1374,7 @@
     type: object
     description: Describes the contents of MMDS in JSON format.
+    additionalProperties: true
 
   NetworkInterface:
     type: object
@@ -1416,10 +1578,61 @@
       description: The configuration of the serial device
     properties:
-      output_path:
+      serial_out_path:
         type: string
         description: Path to a file or named pipe on the host to which serial output should be written.
 
+  MemoryHotplugConfig:
+    type: object
+    description:
+      The configuration of the hotpluggable memory device (virtio-mem)
+    properties:
+      total_size_mib:
+        type: integer
+        description: Total size of the hotpluggable memory in MiB.
+      slot_size_mib:
+        type: integer
+        default: 128
+        minimum: 128
+        description: Slot size for the hotpluggable memory in MiB. This determines the granularity at which
+          memory is hot-plugged from the host. Refer to the device documentation on how to tune this value.
+      block_size_mib:
+        type: integer
+        default: 2
+        minimum: 2
+        description: (Logical) Block size for the hotpluggable memory in MiB. This determines the logical
+          granularity of hot-plugged memory for the guest. Refer to the device documentation on how to tune this value.
+
+  MemoryHotplugSizeUpdate:
+    type: object
+    description:
+      An update to the size of the hotpluggable memory region.
+    properties:
+      requested_size_mib:
+        type: integer
+        description: New target size of the plugged memory in MiB.
+
+  MemoryHotplugStatus:
+    type: object
+    description:
+      The status of the hotpluggable memory device (virtio-mem)
+    properties:
+      total_size_mib:
+        type: integer
+        description: Total size of the hotpluggable memory in MiB.
+      slot_size_mib:
+        type: integer
+        description: Slot size for the hotpluggable memory in MiB.
+      block_size_mib:
+        type: integer
+        description: (Logical) Block size for the hotpluggable memory in MiB.
+      plugged_size_mib:
+        type: integer
+        description: Plugged size for the hotpluggable memory in MiB.
+      requested_size_mib:
+        type: integer
+        description: Requested size for the hotpluggable memory in MiB.
+
   FirecrackerVersion:
     type: object
     description:
diff --git a/src/vmm/Cargo.toml b/src/vmm/Cargo.toml
index 3b1f78a9005..dde0bb4e211 100644
--- a/src/vmm/Cargo.toml
+++ b/src/vmm/Cargo.toml
@@ -21,6 +21,7 @@ aws-lc-rs = { version = "1.14.1", features = ["bindgen"] }
 base64 = "0.22.1"
 bincode = { version = "2.0.1", features = ["serde"] }
 bitflags = "2.10.0"
+bitvec = { version = "1.0.1", features = ["atomic", "serde"] }
 byteorder = "1.5.0"
 crc64 = "2.0.0"
 derive_more = { version = "2.0.1", default-features = false, features = [
diff --git a/src/vmm/src/builder.rs b/src/vmm/src/builder.rs
index fb57080e5b7..6d36056933f 100644
--- a/src/vmm/src/builder.rs
+++ b/src/vmm/src/builder.rs
@@ -13,7 +13,7 @@ use event_manager::SubscriberOps;
 use linux_loader::cmdline::Cmdline as LoaderKernelCmdline;
 use userfaultfd::Uffd;
 use utils::time::TimestampUs;
-#[cfg(target_arch = "aarch64")]
+use vm_allocator::AllocPolicy;
 use vm_memory::GuestAddress;
 
 #[cfg(target_arch = "aarch64")]
@@ -31,6 +31,7 @@ use crate::device_manager::{
 };
 use crate::devices::virtio::balloon::Balloon;
 use crate::devices::virtio::block::device::Block;
+use crate::devices::virtio::mem::{VIRTIO_MEM_DEFAULT_SLOT_SIZE_MIB, VirtioMem};
 use crate::devices::virtio::net::Net;
 use crate::devices::virtio::pmem::device::Pmem;
 use crate::devices::virtio::rng::Entropy;
@@ -43,8 +44,10 @@ use crate::persist::{MicrovmState, MicrovmStateError};
 use crate::resources::VmResources;
 use crate::seccomp::BpfThreadMap;
 use crate::snapshot::Persist;
+use crate::utils::mib_to_bytes;
 use crate::vmm_config::instance_info::InstanceInfo;
 use crate::vmm_config::machine_config::MachineConfigError;
+use crate::vmm_config::memory_hotplug::MemoryHotplugConfig;
 use crate::vstate::kvm::{Kvm, KvmError};
 use crate::vstate::memory::GuestRegionMmap;
 #[cfg(target_arch = "aarch64")]
@@ -170,6 +173,22 @@ pub fn build_microvm_for_boot(
     let (mut vcpus, vcpus_exit_evt) = vm.create_vcpus(vm_resources.machine_config.vcpu_count)?;
 
     vm.register_dram_memory_regions(guest_memory)?;
 
+    // Allocate memory as soon as possible to make hotpluggable memory available to all consumers,
+    // before they clone the GuestMemoryMmap object
+    let virtio_mem_addr = if let Some(memory_hotplug) = &vm_resources.memory_hotplug {
+        let addr = allocate_virtio_mem_address(&vm, memory_hotplug.total_size_mib)?;
+        let hotplug_memory_region = vm_resources
+            .allocate_memory_region(addr, mib_to_bytes(memory_hotplug.total_size_mib))
+            .map_err(StartMicrovmError::GuestMemory)?;
+        vm.register_hotpluggable_memory_region(
+            hotplug_memory_region,
+            mib_to_bytes(memory_hotplug.slot_size_mib),
+        )?;
+        Some(addr)
+    } else {
+        None
+    };
+
     let mut device_manager = DeviceManager::new(
         event_manager,
         &vcpus_exit_evt,
@@ -247,6 +266,18 @@ pub fn build_microvm_for_boot(
         )?;
     }
 
+    // Attach virtio-mem device if configured
+    if let Some(memory_hotplug) = &vm_resources.memory_hotplug {
+        attach_virtio_mem_device(
+            &mut device_manager,
+            &vm,
+            &mut boot_cmdline,
+            memory_hotplug,
+            event_manager,
+            virtio_mem_addr.expect("address should be allocated"),
+        )?;
+    }
+
     #[cfg(target_arch = "aarch64")]
     device_manager.attach_legacy_devices_aarch64(
         &vm,
@@ -573,6 +604,47 @@ fn attach_entropy_device(
     device_manager.attach_virtio_device(vm, id, entropy_device.clone(), cmdline, false)
 }
 
+fn allocate_virtio_mem_address(
+    vm: &Vm,
+    total_size_mib: usize,
+) -> Result<GuestAddress, StartMicrovmError> {
+    let addr = vm
+        .resource_allocator()
+        .past_mmio64_memory
+        .allocate(
+            mib_to_bytes(total_size_mib) as u64,
+            mib_to_bytes(VIRTIO_MEM_DEFAULT_SLOT_SIZE_MIB) as u64,
+            AllocPolicy::FirstMatch,
+        )?
+        .start();
+    Ok(GuestAddress(addr))
+}
+
+fn attach_virtio_mem_device(
+    device_manager: &mut DeviceManager,
+    vm: &Arc<Vm>,
+    cmdline: &mut LoaderKernelCmdline,
+    config: &MemoryHotplugConfig,
+    event_manager: &mut EventManager,
+    addr: GuestAddress,
+) -> Result<(), StartMicrovmError> {
+    let virtio_mem = Arc::new(Mutex::new(
+        VirtioMem::new(
+            Arc::clone(vm),
+            addr,
+            config.total_size_mib,
+            config.block_size_mib,
+            config.slot_size_mib,
+        )
+        .map_err(|e| StartMicrovmError::Internal(VmmError::VirtioMem(e)))?,
+    ));
+
+    let id = virtio_mem.lock().expect("Poisoned lock").id().to_string();
+    event_manager.add_subscriber(virtio_mem.clone());
+    device_manager.attach_virtio_device(vm, id, virtio_mem.clone(), cmdline, false)?;
+    Ok(())
+}
+
 fn attach_block_devices<'a, I: Iterator<Item = &'a Arc<Mutex<Block>>> + Debug>(
     device_manager: &mut DeviceManager,
     vm: &Arc<Vm>,
@@ -1280,4 +1352,43 @@ pub(crate) mod tests {
             "virtio_mmio.device=4K@0xc0001000:5"
         ));
     }
+
+    pub(crate) fn insert_virtio_mem_device(
+        vmm: &mut Vmm,
+        cmdline: &mut Cmdline,
+        event_manager: &mut EventManager,
+        config: MemoryHotplugConfig,
+    ) {
+        attach_virtio_mem_device(
+            &mut vmm.device_manager,
+            &vmm.vm,
+            cmdline,
+            &config,
+            event_manager,
+            GuestAddress(512 << 30),
+        )
+        .unwrap();
+    }
+
+    #[test]
+    fn test_attach_virtio_mem_device() {
+        let mut event_manager = EventManager::new().expect("Unable to create EventManager");
+        let mut vmm = default_vmm();
+
+        let config = MemoryHotplugConfig {
+            total_size_mib: 1024,
+            block_size_mib: 2,
+            slot_size_mib: 128,
+        };
+
+        let mut cmdline = default_kernel_cmdline();
+        insert_virtio_mem_device(&mut vmm, &mut cmdline, &mut event_manager, config);
+
+        // Check that the virtio-mem device is described in kernel_cmdline.
+        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+        assert!(cmdline_contains(
+            &cmdline,
+            "virtio_mmio.device=4K@0xc0001000:5"
+        ));
+    }
 }
diff --git a/src/vmm/src/device_manager/pci_mngr.rs b/src/vmm/src/device_manager/pci_mngr.rs
index 6c89540381a..a2270097686 100644
--- a/src/vmm/src/device_manager/pci_mngr.rs
+++ b/src/vmm/src/device_manager/pci_mngr.rs
@@ -18,6 +18,8 @@ use crate::devices::virtio::block::device::Block;
 use crate::devices::virtio::block::persist::{BlockConstructorArgs, BlockState};
 use crate::devices::virtio::device::VirtioDevice;
 use crate::devices::virtio::generated::virtio_ids;
+use crate::devices::virtio::mem::VirtioMem;
+use crate::devices::virtio::mem::persist::{VirtioMemConstructorArgs, VirtioMemState};
 use crate::devices::virtio::net::Net;
 use crate::devices::virtio::net::persist::{NetConstructorArgs, NetState};
 use crate::devices::virtio::pmem::device::Pmem;
@@ -34,6 +36,7 @@ use crate::devices::virtio::vsock::{Vsock, VsockUnixBackend};
 use crate::pci::bus::PciRootError;
 use crate::resources::VmResources;
 use crate::snapshot::Persist;
+use crate::vmm_config::memory_hotplug::MemoryHotplugConfig;
 use crate::vmm_config::mmds::MmdsConfigError;
 use crate::vstate::bus::BusError;
 use crate::vstate::interrupts::InterruptError;
@@ -242,6 +245,8 @@ pub struct PciDevicesState {
     pub entropy_device: Option<VirtioDeviceState<EntropyState>>,
     /// Pmem device states.
     pub pmem_devices: Vec<VirtioDeviceState<PmemState>>,
+    /// Memory device state.
+    pub memory_device: Option<VirtioDeviceState<VirtioMemState>>,
 }
 
 pub struct PciDevicesConstructorArgs<'a> {
@@ -401,6 +406,20 @@ impl<'a> Persist<'a> for PciDevices {
                     transport_state,
                 });
             }
+            virtio_ids::VIRTIO_ID_MEM => {
+                let mem_dev = locked_virtio_dev
+                    .as_mut_any()
+                    .downcast_mut::<VirtioMem>()
+                    .unwrap();
+                let device_state = mem_dev.save();
+
+                state.memory_device = Some(VirtioDeviceState {
+                    device_id: mem_dev.id().to_string(),
+                    pci_device_bdf,
+                    device_state,
+                    transport_state,
+                })
+            }
             _ => unreachable!(),
         }
     }
@@ -604,6 +623,28 @@ impl<'a> Persist<'a> for PciDevices {
                 .unwrap()
         }
 
+        if let Some(memory_device) = &state.memory_device {
+            let ctor_args = VirtioMemConstructorArgs::new(Arc::clone(constructor_args.vm));
+            let device = VirtioMem::restore(ctor_args, &memory_device.device_state).unwrap();
+
+            constructor_args.vm_resources.memory_hotplug = Some(MemoryHotplugConfig {
+                total_size_mib: device.total_size_mib(),
+                block_size_mib: device.block_size_mib(),
+                slot_size_mib: device.slot_size_mib(),
+            });
+
+            let arcd_device = Arc::new(Mutex::new(device));
+            pci_devices
+                .restore_pci_device(
+                    constructor_args.vm,
+                    arcd_device,
+                    &memory_device.device_id,
+                    &memory_device.transport_state,
+                    constructor_args.event_manager,
+                )
+                .unwrap()
+        }
+
         Ok(pci_devices)
     }
 }
@@ -621,6 +662,7 @@ mod tests {
     use crate::snapshot::Snapshot;
     use crate::vmm_config::balloon::BalloonDeviceConfig;
     use crate::vmm_config::entropy::EntropyDeviceConfig;
+    use crate::vmm_config::memory_hotplug::MemoryHotplugConfig;
     use crate::vmm_config::net::NetworkInterfaceConfig;
     use crate::vmm_config::pmem::PmemConfig;
     use crate::vmm_config::vsock::VsockDeviceConfig;
@@ -695,6 +737,18 @@ mod tests {
         _pmem_files =
             insert_pmem_devices(&mut vmm, &mut cmdline, &mut event_manager, pmem_configs);
 
+        let memory_hotplug_config = MemoryHotplugConfig {
+            total_size_mib: 1024,
+            block_size_mib: 2,
+            slot_size_mib: 128,
+        };
+        insert_virtio_mem_device(
+            &mut vmm,
+            &mut cmdline,
+            &mut event_manager,
+            memory_hotplug_config,
+        );
+
         Snapshot::new(vmm.device_manager.save())
             .save(&mut buf.as_mut_slice())
             .unwrap();
@@ -789,7 +843,12 @@ mod tests {
                     "root_device": true,
                     "read_only": true
                 }}
-            ]
+            ],
+            "memory-hotplug": {{
+                "total_size_mib": 1024,
+                "block_size_mib": 2,
+                "slot_size_mib": 128
+            }}
         }}"#,
         _block_files.last().unwrap().as_path().to_str().unwrap(),
         tmp_sock_file.as_path().to_str().unwrap(),
diff --git a/src/vmm/src/device_manager/persist.rs b/src/vmm/src/device_manager/persist.rs
index fa83aae9e37..eca7b73b25b 100644
--- a/src/vmm/src/device_manager/persist.rs
+++ b/src/vmm/src/device_manager/persist.rs
@@ -28,6 +28,10 @@ use crate::devices::virtio::block::device::Block;
 use crate::devices::virtio::block::persist::{BlockConstructorArgs, BlockState};
 use crate::devices::virtio::device::VirtioDevice;
 use crate::devices::virtio::generated::virtio_ids;
+use crate::devices::virtio::mem::VirtioMem;
+use crate::devices::virtio::mem::persist::{
+    VirtioMemConstructorArgs, VirtioMemPersistError, VirtioMemState,
+};
 use crate::devices::virtio::net::Net;
 use crate::devices::virtio::net::persist::{
     NetConstructorArgs, NetPersistError as NetError, NetState,
@@ -49,6 +53,7 @@ use crate::devices::virtio::vsock::{Vsock, VsockError, VsockUnixBackend, VsockUn
 use crate::mmds::data_store::MmdsVersion;
 use crate::resources::VmResources;
 use crate::snapshot::Persist;
+use crate::vmm_config::memory_hotplug::MemoryHotplugConfig;
 use crate::vmm_config::mmds::MmdsConfigError;
 use crate::vstate::bus::BusError;
 use crate::vstate::memory::GuestMemoryMmap;
@@ -82,6 +87,8 @@ pub enum DevicePersistError {
     Entropy(#[from] EntropyError),
     /// Pmem: {0}
     Pmem(#[from] PmemError),
+    /// virtio-mem: {0}
+    VirtioMem(#[from] VirtioMemPersistError),
     /// Could not activate device: {0}
     DeviceActivation(#[from] ActivateError),
 }
@@ -135,6 +142,8 @@ pub struct DeviceStates {
     pub entropy_device: Option<VirtioDeviceState<EntropyState>>,
     /// Pmem device states.
     pub pmem_devices: Vec<VirtioDeviceState<PmemState>>,
+    /// Memory device state.
+    pub memory_device: Option<VirtioDeviceState<VirtioMemState>>,
 }
 
 pub struct MMIODevManagerConstructorArgs<'a> {
@@ -328,6 +337,20 @@ impl<'a> Persist<'a> for MMIODeviceManager {
                         device_info,
                     })
                 }
+                virtio_ids::VIRTIO_ID_MEM => {
+                    let mem = locked_device
+                        .as_mut_any()
+                        .downcast_mut::<VirtioMem>()
+                        .unwrap();
+                    let device_state = mem.save();
+
+                    states.memory_device = Some(VirtioDeviceState {
+                        device_id,
+                        device_state,
+                        transport_state,
+                        device_info,
+                    });
+                }
                 _ => unreachable!(),
             };
@@ -570,6 +593,30 @@ impl<'a> Persist<'a> for MMIODeviceManager {
             )?;
         }
 
+        if let Some(memory_state) = &state.memory_device {
+            let ctor_args = VirtioMemConstructorArgs::new(Arc::clone(vm));
+            let device = VirtioMem::restore(ctor_args, &memory_state.device_state)?;
+
+            constructor_args.vm_resources.memory_hotplug = Some(MemoryHotplugConfig {
+                total_size_mib: device.total_size_mib(),
+                block_size_mib: device.block_size_mib(),
+                slot_size_mib: device.slot_size_mib(),
+            });
+
+            let arcd_device = Arc::new(Mutex::new(device));
+
+            restore_helper(
+                arcd_device.clone(),
+                memory_state.device_state.virtio_state.activated,
+                false,
+                arcd_device,
+                &memory_state.device_id,
+                &memory_state.transport_state,
+                &memory_state.device_info,
+                constructor_args.event_manager,
+            )?;
+        }
+
         Ok(dev_manager)
     }
 }
@@ -586,6 +633,7 @@ mod tests {
     use crate::snapshot::Snapshot;
     use crate::vmm_config::balloon::BalloonDeviceConfig;
     use crate::vmm_config::entropy::EntropyDeviceConfig;
+    use crate::vmm_config::memory_hotplug::MemoryHotplugConfig;
     use crate::vmm_config::net::NetworkInterfaceConfig;
     use crate::vmm_config::pmem::PmemConfig;
     use crate::vmm_config::vsock::VsockDeviceConfig;
@@ -604,6 +652,7 @@ mod tests {
             && self.net_devices == other.net_devices
             && self.vsock_device == other.vsock_device
             && self.entropy_device == other.entropy_device
+            && self.memory_device == other.memory_device
     }
 }
@@ -699,6 +748,18 @@ mod tests {
         _pmem_files =
             insert_pmem_devices(&mut vmm, &mut cmdline, &mut event_manager, pmem_configs);
 
+        let memory_hotplug_config = MemoryHotplugConfig {
+            total_size_mib: 1024,
+            block_size_mib: 2,
+            slot_size_mib: 128,
+        };
+        insert_virtio_mem_device(
+            &mut vmm,
+            &mut cmdline,
+            &mut event_manager,
+            memory_hotplug_config,
+        );
+
         Snapshot::new(vmm.device_manager.save())
             .save(&mut buf.as_mut_slice())
             .unwrap();
@@ -789,7 +850,12 @@ mod tests {
                     "root_device": true,
                     "read_only": true
                 }}
-            ]
+            ],
+            "memory-hotplug": {{
+                "total_size_mib": 1024,
+                "block_size_mib": 2,
+                "slot_size_mib": 128
+            }}
         }}"#,
         _block_files.last().unwrap().as_path().to_str().unwrap(),
         tmp_sock_file.as_path().to_str().unwrap(),
diff --git a/src/vmm/src/devices/virtio/generated/mod.rs b/src/vmm/src/devices/virtio/generated/mod.rs
index a9d1f08f88f..712284e9359 100644
--- a/src/vmm/src/devices/virtio/generated/mod.rs
+++ b/src/vmm/src/devices/virtio/generated/mod.rs
@@ -13,5 +13,6 @@
 pub mod virtio_blk;
 pub mod virtio_config;
 pub mod virtio_ids;
+pub mod virtio_mem;
 pub mod virtio_net;
 pub mod virtio_ring;
diff --git a/src/vmm/src/devices/virtio/generated/virtio_mem.rs b/src/vmm/src/devices/virtio/generated/virtio_mem.rs
new file mode 100644
index 00000000000..e3c25ff0bb9
--- /dev/null
+++ b/src/vmm/src/devices/virtio/generated/virtio_mem.rs
@@ -0,0 +1,244 @@
+// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+// automatically generated by tools/bindgen.sh
+
+#![allow(
+    non_camel_case_types,
+    non_upper_case_globals,
+    dead_code,
+    non_snake_case,
+    clippy::ptr_as_ptr,
+    clippy::undocumented_unsafe_blocks,
+    missing_debug_implementations,
+    clippy::tests_outside_test_module,
+    unsafe_op_in_unsafe_fn,
+    clippy::redundant_static_lifetimes
+)]
+
+pub const VIRTIO_MEM_F_ACPI_PXM: u32 = 0;
+pub const VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE: u32 = 1;
+pub const VIRTIO_MEM_F_PERSISTENT_SUSPEND: u32 = 2;
+pub const VIRTIO_MEM_REQ_PLUG: u32 = 0;
+pub const VIRTIO_MEM_REQ_UNPLUG: u32 = 1;
+pub const VIRTIO_MEM_REQ_UNPLUG_ALL: u32 = 2;
+pub const VIRTIO_MEM_REQ_STATE: u32 = 3;
+pub const VIRTIO_MEM_RESP_ACK: u32 = 0;
+pub const VIRTIO_MEM_RESP_NACK: u32 = 1;
+pub const VIRTIO_MEM_RESP_BUSY: u32 = 2;
+pub const VIRTIO_MEM_RESP_ERROR: u32 = 3;
+pub const VIRTIO_MEM_STATE_PLUGGED: u32 = 0;
+pub const VIRTIO_MEM_STATE_UNPLUGGED: u32 = 1;
+pub const VIRTIO_MEM_STATE_MIXED: u32 = 2;
+pub type __u8 = ::std::os::raw::c_uchar;
+pub type __u16 = ::std::os::raw::c_ushort;
+pub type __u64 = ::std::os::raw::c_ulonglong;
+pub type __le16 = __u16;
+pub type __le64 = __u64;
+pub type __virtio16 = __u16;
+pub type __virtio64 = __u64;
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone, PartialEq)]
+pub struct virtio_mem_req_plug {
+    pub addr: __virtio64,
+    pub nb_blocks: __virtio16,
+    pub padding: [__virtio16; 3usize],
+}
+#[allow(clippy::unnecessary_operation, clippy::identity_op)]
+const _: () = {
+    ["Size of virtio_mem_req_plug"][::std::mem::size_of::<virtio_mem_req_plug>() - 16usize];
+    ["Alignment of virtio_mem_req_plug"][::std::mem::align_of::<virtio_mem_req_plug>() - 8usize];
+    ["Offset of field: virtio_mem_req_plug::addr"]
+        [::std::mem::offset_of!(virtio_mem_req_plug, addr) - 0usize];
+    ["Offset of field: virtio_mem_req_plug::nb_blocks"]
+        [::std::mem::offset_of!(virtio_mem_req_plug, nb_blocks) - 8usize];
+    ["Offset of field: virtio_mem_req_plug::padding"]
+        [::std::mem::offset_of!(virtio_mem_req_plug, padding) - 10usize];
+};
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone, PartialEq)]
+pub struct virtio_mem_req_unplug {
+    pub addr: __virtio64,
+    pub nb_blocks: __virtio16,
+    pub padding: [__virtio16; 3usize],
+}
+#[allow(clippy::unnecessary_operation, clippy::identity_op)]
+const _: () = {
+    ["Size of virtio_mem_req_unplug"][::std::mem::size_of::<virtio_mem_req_unplug>() - 16usize];
+    ["Alignment of virtio_mem_req_unplug"]
+        [::std::mem::align_of::<virtio_mem_req_unplug>() - 8usize];
+    ["Offset of field: virtio_mem_req_unplug::addr"]
+        [::std::mem::offset_of!(virtio_mem_req_unplug, addr) - 0usize];
+    ["Offset of field: virtio_mem_req_unplug::nb_blocks"]
+        [::std::mem::offset_of!(virtio_mem_req_unplug, nb_blocks) - 8usize];
+    ["Offset of field: virtio_mem_req_unplug::padding"]
+        [::std::mem::offset_of!(virtio_mem_req_unplug, padding) - 10usize];
+};
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone, PartialEq)]
+pub struct virtio_mem_req_state {
+    pub addr: __virtio64,
+    pub nb_blocks: __virtio16,
+    pub padding: [__virtio16; 3usize],
+}
+#[allow(clippy::unnecessary_operation, clippy::identity_op)]
+const _: () = {
+    ["Size of virtio_mem_req_state"][::std::mem::size_of::<virtio_mem_req_state>() - 16usize];
+    ["Alignment of virtio_mem_req_state"][::std::mem::align_of::<virtio_mem_req_state>() - 8usize];
+    ["Offset of field: virtio_mem_req_state::addr"]
+        [::std::mem::offset_of!(virtio_mem_req_state, addr) - 0usize];
+    ["Offset of field: virtio_mem_req_state::nb_blocks"]
+        [::std::mem::offset_of!(virtio_mem_req_state, nb_blocks) - 8usize];
+    ["Offset of field: virtio_mem_req_state::padding"]
+        [::std::mem::offset_of!(virtio_mem_req_state, padding) - 10usize];
+};
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub struct virtio_mem_req {
+    pub type_: __virtio16,
+    pub padding: [__virtio16; 3usize],
+    pub u: virtio_mem_req__bindgen_ty_1,
+}
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub union virtio_mem_req__bindgen_ty_1 {
+    pub plug: virtio_mem_req_plug,
+    pub unplug: virtio_mem_req_unplug,
+    pub state: virtio_mem_req_state,
+}
+#[allow(clippy::unnecessary_operation, clippy::identity_op)]
+const _: () = {
+    ["Size of virtio_mem_req__bindgen_ty_1"]
+        [::std::mem::size_of::<virtio_mem_req__bindgen_ty_1>() - 16usize];
+    ["Alignment of virtio_mem_req__bindgen_ty_1"]
+        [::std::mem::align_of::<virtio_mem_req__bindgen_ty_1>() - 8usize];
+    ["Offset of field: virtio_mem_req__bindgen_ty_1::plug"]
+        [::std::mem::offset_of!(virtio_mem_req__bindgen_ty_1, plug) - 0usize];
+    ["Offset of field: virtio_mem_req__bindgen_ty_1::unplug"]
+        [::std::mem::offset_of!(virtio_mem_req__bindgen_ty_1, unplug) - 0usize];
+    ["Offset of field: virtio_mem_req__bindgen_ty_1::state"]
+        [::std::mem::offset_of!(virtio_mem_req__bindgen_ty_1, state) - 0usize];
+};
+impl Default for virtio_mem_req__bindgen_ty_1 {
+    fn default() -> Self {
+        let mut s = ::std::mem::MaybeUninit::<Self>::uninit();
+        unsafe {
+            ::std::ptr::write_bytes(s.as_mut_ptr(), 0, 1);
+            s.assume_init()
+        }
+    }
+}
+#[allow(clippy::unnecessary_operation, clippy::identity_op)]
+const _: () = {
+    ["Size of virtio_mem_req"][::std::mem::size_of::<virtio_mem_req>() - 24usize];
+    ["Alignment of virtio_mem_req"][::std::mem::align_of::<virtio_mem_req>() - 8usize];
+    ["Offset of field: virtio_mem_req::type_"]
+        [::std::mem::offset_of!(virtio_mem_req, type_) - 0usize];
+    ["Offset of field: virtio_mem_req::padding"]
+        [::std::mem::offset_of!(virtio_mem_req, padding) - 2usize];
+    ["Offset of field: virtio_mem_req::u"][::std::mem::offset_of!(virtio_mem_req, u) - 8usize];
+};
+impl Default for virtio_mem_req {
+    fn default() -> Self {
+        let mut s = ::std::mem::MaybeUninit::<Self>::uninit();
+        unsafe {
+            ::std::ptr::write_bytes(s.as_mut_ptr(), 0, 1);
+            s.assume_init()
+        }
+    }
+}
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone, PartialEq)]
+pub struct virtio_mem_resp_state {
+    pub state: __virtio16,
+}
+#[allow(clippy::unnecessary_operation, clippy::identity_op)]
+const _: () = {
+    ["Size of virtio_mem_resp_state"][::std::mem::size_of::<virtio_mem_resp_state>() - 2usize];
+    ["Alignment of virtio_mem_resp_state"]
+        [::std::mem::align_of::<virtio_mem_resp_state>() - 2usize];
+    ["Offset of field: virtio_mem_resp_state::state"]
+        [::std::mem::offset_of!(virtio_mem_resp_state, state) - 0usize];
+};
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub struct virtio_mem_resp {
+    pub type_: __virtio16,
+    pub padding: [__virtio16; 3usize],
+    pub u: virtio_mem_resp__bindgen_ty_1,
+}
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub union virtio_mem_resp__bindgen_ty_1 {
+    pub state: virtio_mem_resp_state,
+}
+#[allow(clippy::unnecessary_operation, clippy::identity_op)]
+const _: () = {
+    ["Size of virtio_mem_resp__bindgen_ty_1"]
+        [::std::mem::size_of::<virtio_mem_resp__bindgen_ty_1>() - 2usize];
+    ["Alignment of virtio_mem_resp__bindgen_ty_1"]
+        [::std::mem::align_of::<virtio_mem_resp__bindgen_ty_1>() - 2usize];
+    ["Offset of field: virtio_mem_resp__bindgen_ty_1::state"]
+        [::std::mem::offset_of!(virtio_mem_resp__bindgen_ty_1, state) - 0usize];
+};
+impl Default for virtio_mem_resp__bindgen_ty_1 {
+    fn default() -> Self {
+        let mut s = ::std::mem::MaybeUninit::<Self>::uninit();
+        unsafe {
+            ::std::ptr::write_bytes(s.as_mut_ptr(), 0, 1);
+            s.assume_init()
+        }
+    }
+}
+#[allow(clippy::unnecessary_operation, clippy::identity_op)]
+const _: () = {
+    ["Size of virtio_mem_resp"][::std::mem::size_of::<virtio_mem_resp>() - 10usize];
+    ["Alignment of virtio_mem_resp"][::std::mem::align_of::<virtio_mem_resp>() - 2usize];
+    ["Offset of field: virtio_mem_resp::type_"]
+        [::std::mem::offset_of!(virtio_mem_resp, type_) - 0usize];
+    ["Offset of field: virtio_mem_resp::padding"]
+        [::std::mem::offset_of!(virtio_mem_resp, padding) - 2usize];
+    ["Offset of field: virtio_mem_resp::u"][::std::mem::offset_of!(virtio_mem_resp, u) - 8usize];
+};
+impl Default for virtio_mem_resp {
+    fn default() -> Self {
+        let mut s = ::std::mem::MaybeUninit::<Self>::uninit();
+        unsafe {
+            ::std::ptr::write_bytes(s.as_mut_ptr(), 0, 1);
+            s.assume_init()
+        }
+    }
+}
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone, PartialEq)]
+pub struct virtio_mem_config {
+    pub block_size: __le64,
+    pub node_id: __le16,
+    pub padding: [__u8; 6usize],
+    pub addr: __le64,
+    pub region_size: __le64,
+    pub usable_region_size: __le64,
+    pub plugged_size: __le64,
+    pub requested_size: __le64,
+}
+#[allow(clippy::unnecessary_operation, clippy::identity_op)]
+const _: () = {
+    ["Size of virtio_mem_config"][::std::mem::size_of::<virtio_mem_config>() - 56usize];
+    ["Alignment of virtio_mem_config"][::std::mem::align_of::<virtio_mem_config>() - 8usize];
+    ["Offset of field: virtio_mem_config::block_size"]
+        [::std::mem::offset_of!(virtio_mem_config, block_size) - 0usize];
+    ["Offset of field: virtio_mem_config::node_id"]
+        [::std::mem::offset_of!(virtio_mem_config, node_id) - 8usize];
+    ["Offset of field: virtio_mem_config::padding"]
+        [::std::mem::offset_of!(virtio_mem_config, padding) - 10usize];
+    ["Offset of field: virtio_mem_config::addr"]
+        [::std::mem::offset_of!(virtio_mem_config, addr) - 16usize];
+    ["Offset of field: virtio_mem_config::region_size"]
+        [::std::mem::offset_of!(virtio_mem_config, region_size) - 24usize];
+    ["Offset of field: virtio_mem_config::usable_region_size"]
+        [::std::mem::offset_of!(virtio_mem_config, usable_region_size) - 32usize];
+    ["Offset of field: virtio_mem_config::plugged_size"]
+        [::std::mem::offset_of!(virtio_mem_config, plugged_size) - 40usize];
+    ["Offset of field: virtio_mem_config::requested_size"]
+        [::std::mem::offset_of!(virtio_mem_config, requested_size) - 48usize];
+};
diff --git a/src/vmm/src/devices/virtio/mem/device.rs b/src/vmm/src/devices/virtio/mem/device.rs
new file mode 100644
index 00000000000..adce0331541
--- /dev/null
+++ b/src/vmm/src/devices/virtio/mem/device.rs
@@ -0,0 +1,1350 @@
+// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use std::io;
+use std::ops::{Deref, Range};
+use std::sync::Arc;
+use std::sync::atomic::AtomicU32;
+
+use bitvec::vec::BitVec;
+use log::info;
+use serde::{Deserialize, Serialize};
+use vm_memory::{
+    Address, Bytes, GuestAddress, GuestMemory, GuestMemoryError, GuestMemoryRegion, GuestUsize,
+};
+use vmm_sys_util::eventfd::EventFd;
+
+use super::{MEM_NUM_QUEUES, MEM_QUEUE};
+use crate::devices::virtio::ActivateError;
+use crate::devices::virtio::device::{ActiveState, DeviceState, VirtioDevice};
+use crate::devices::virtio::generated::virtio_config::VIRTIO_F_VERSION_1;
+use crate::devices::virtio::generated::virtio_ids::VIRTIO_ID_MEM;
+use crate::devices::virtio::generated::virtio_mem::{
+    self, VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE, virtio_mem_config,
+};
+use crate::devices::virtio::iov_deque::IovDequeError;
+use crate::devices::virtio::mem::VIRTIO_MEM_DEV_ID;
+use crate::devices::virtio::mem::metrics::METRICS;
+use crate::devices::virtio::mem::request::{BlockRangeState, Request, RequestedRange, Response};
+use crate::devices::virtio::queue::{
+    DescriptorChain, FIRECRACKER_MAX_QUEUE_SIZE, InvalidAvailIdx, Queue, QueueError,
+};
+use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType};
+use crate::logger::{IncMetric, debug, error};
+use crate::utils::{bytes_to_mib, mib_to_bytes, u64_to_usize, usize_to_u64};
+use crate::vstate::interrupts::InterruptError;
+use crate::vstate::memory::{
+    ByteValued, GuestMemoryExtension, GuestMemoryMmap, GuestRegionMmap, GuestRegionType,
+};
+use crate::vstate::vm::VmError;
+use crate::{Vm, impl_device_type};
+
+// SAFETY: virtio_mem_config only contains plain data types
+unsafe impl ByteValued for virtio_mem_config {}
+
+#[derive(Debug, thiserror::Error, displaydoc::Display)]
+pub enum VirtioMemError {
+    /// Error while handling an Event file descriptor: {0}
+    EventFd(#[from] io::Error),
+    /// Received error while sending an interrupt: {0}
+    InterruptError(#[from] InterruptError),
+    /// Size {0} is invalid: it must be a multiple of the block size and no larger than the total size
+    InvalidSize(u64),
+    /// Device is not active
+    DeviceNotActive,
+    /// Descriptor is write-only
+    UnexpectedWriteOnlyDescriptor,
+    /// Error writing virtio descriptor
+    DescriptorWriteFailed,
+    /// Error reading virtio descriptor
+    DescriptorReadFailed,
+    /// Unknown request type: {0}
+    UnknownRequestType(u32),
+    /// Descriptor chain is too short
+    DescriptorChainTooShort,
+    /// Descriptor is too small
+    DescriptorLengthTooSmall,
+    /// Descriptor is read-only
+    UnexpectedReadOnlyDescriptor,
+    /// Error popping from virtio queue: {0}
+    InvalidAvailIdx(#[from] InvalidAvailIdx),
+    /// Error adding used queue: {0}
+    QueueError(#[from] QueueError),
+    /// Invalid requested range: {0:?}.
+    InvalidRange(RequestedRange),
+    /// The requested range cannot be plugged because it's {0:?}.
+    PlugRequestBlockStateInvalid(BlockRangeState),
+    /// Plug request rejected as plugged_size would be greater than requested_size
+    PlugRequestIsTooBig,
+    /// The requested range cannot be unplugged because it's {0:?}.
+    UnplugRequestBlockStateInvalid(BlockRangeState),
+    /// There was an error updating the KVM slot.
+    UpdateKvmSlot(VmError),
+}
+
+#[derive(Debug)]
+pub struct VirtioMem {
+    // VirtIO fields
+    avail_features: u64,
+    acked_features: u64,
+    activate_event: EventFd,
+
+    // Transport fields
+    device_state: DeviceState,
+    pub(crate) queues: Vec<Queue>,
+    queue_events: Vec<EventFd>,
+
+    // Device specific fields
+    pub(crate) config: virtio_mem_config,
+    pub(crate) slot_size: usize,
+    // Bitmap to track which blocks are plugged
+    pub(crate) plugged_blocks: BitVec,
+    vm: Arc<Vm>,
+}
+
+/// Memory hotplug device status information.
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
+#[serde(deny_unknown_fields)]
+pub struct VirtioMemStatus {
+    /// Block size in MiB.
+    pub block_size_mib: usize,
+    /// Total memory size in MiB that can be hotplugged.
+    pub total_size_mib: usize,
+    /// Size of the KVM slots in MiB.
+    pub slot_size_mib: usize,
+    /// Currently plugged memory size in MiB.
+    pub plugged_size_mib: usize,
+    /// Requested memory size in MiB.
+    pub requested_size_mib: usize,
+}
+
+impl VirtioMem {
+    pub fn new(
+        vm: Arc<Vm>,
+        addr: GuestAddress,
+        total_size_mib: usize,
+        block_size_mib: usize,
+        slot_size_mib: usize,
+    ) -> Result<Self, VirtioMemError> {
+        let queues = vec![Queue::new(FIRECRACKER_MAX_QUEUE_SIZE); MEM_NUM_QUEUES];
+        let config = virtio_mem_config {
+            addr: addr.raw_value(),
+            region_size: mib_to_bytes(total_size_mib) as u64,
+            block_size: mib_to_bytes(block_size_mib) as u64,
+            ..Default::default()
+        };
+        let plugged_blocks = BitVec::repeat(false, total_size_mib / block_size_mib);
+
+        Self::from_state(
+            vm,
+            queues,
+            config,
+            mib_to_bytes(slot_size_mib),
+            plugged_blocks,
+        )
+    }
+
+    pub fn from_state(
+        vm: Arc<Vm>,
+        queues: Vec<Queue>,
+        config: virtio_mem_config,
+        slot_size: usize,
+        plugged_blocks: BitVec,
+    ) -> Result<Self, VirtioMemError> {
+        let activate_event = EventFd::new(libc::EFD_NONBLOCK)?;
+        let queue_events = (0..MEM_NUM_QUEUES)
+            .map(|_| EventFd::new(libc::EFD_NONBLOCK))
+            .collect::<Result<Vec<EventFd>, io::Error>>()?;
+
+        Ok(Self {
+            avail_features: (1 << VIRTIO_F_VERSION_1) | (1 << VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE),
+            acked_features: 0u64,
+            activate_event,
+            device_state: DeviceState::Inactive,
+            queues,
+            queue_events,
+            config,
+            vm,
+            slot_size,
+            plugged_blocks,
+        })
+    }
+
+    pub fn id(&self) -> &str {
+        VIRTIO_MEM_DEV_ID
+    }
+
+    pub fn guest_address(&self) -> GuestAddress {
+        GuestAddress(self.config.addr)
+    }
+
+    /// Gets the total hotpluggable size.
+    pub fn total_size_mib(&self) -> usize {
+        bytes_to_mib(u64_to_usize(self.config.region_size))
+    }
+
+    /// Gets the block size.
+    pub fn block_size_mib(&self) -> usize {
+        bytes_to_mib(u64_to_usize(self.config.block_size))
+    }
+
+    /// Gets the KVM slot size.
+    pub fn slot_size_mib(&self) -> usize {
+        bytes_to_mib(self.slot_size)
+    }
+
+    /// Gets the total size of the plugged memory blocks.
+    pub fn plugged_size_mib(&self) -> usize {
+        bytes_to_mib(u64_to_usize(self.config.plugged_size))
+    }
+
+    /// Gets the requested size.
+    pub fn requested_size_mib(&self) -> usize {
+        bytes_to_mib(u64_to_usize(self.config.requested_size))
+    }
+
+    pub fn status(&self) -> VirtioMemStatus {
+        VirtioMemStatus {
+            block_size_mib: self.block_size_mib(),
+            total_size_mib: self.total_size_mib(),
+            slot_size_mib: self.slot_size_mib(),
+            plugged_size_mib: self.plugged_size_mib(),
+            requested_size_mib: self.requested_size_mib(),
+        }
+    }
+
+    fn signal_used_queue(&self) -> Result<(), VirtioMemError> {
+        self.interrupt_trigger()
+            .trigger(VirtioInterruptType::Queue(MEM_QUEUE.try_into().unwrap()))
+            .map_err(VirtioMemError::InterruptError)
+    }
+
+    fn guest_memory(&self) -> &GuestMemoryMmap {
+        &self.device_state.active_state().unwrap().mem
+    }
+
+    fn nb_blocks_to_len(&self, nb_blocks: usize) -> usize {
+        nb_blocks * u64_to_usize(self.config.block_size)
+    }
+
+    /// Returns the state of all the blocks in the given range.
+    ///
+    /// Note: the range passed to this function must be within the device memory to avoid
+    /// out-of-bound panics.
+    fn range_state(&self, range: &RequestedRange) -> BlockRangeState {
+        let plugged_count = self.plugged_blocks[self.unchecked_block_range(range)].count_ones();
+
+        match plugged_count {
+            nb_blocks if nb_blocks == range.nb_blocks => BlockRangeState::Plugged,
+            0 => BlockRangeState::Unplugged,
+            _ => BlockRangeState::Mixed,
+        }
+    }
+
+    fn parse_request(
+        &self,
+        avail_desc: &DescriptorChain,
+    ) -> Result<(Request, GuestAddress, u16), VirtioMemError> {
+        // The head contains the request type which MUST be readable.
+        if avail_desc.is_write_only() {
+            return Err(VirtioMemError::UnexpectedWriteOnlyDescriptor);
+        }
+
+        if (avail_desc.len as usize) < size_of::<virtio_mem::virtio_mem_req>() {
+            return Err(VirtioMemError::DescriptorLengthTooSmall);
+        }
+
+        let request: virtio_mem::virtio_mem_req = self
+            .guest_memory()
+            .read_obj(avail_desc.addr)
+            .map_err(|_| VirtioMemError::DescriptorReadFailed)?;
+
+        let resp_desc = avail_desc
+            .next_descriptor()
+            .ok_or(VirtioMemError::DescriptorChainTooShort)?;
+
+        // The response MUST always be writable.
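+        // (Per the virtio-mem spec, the driver places virtio_mem_req in a device-readable
+        // descriptor and supplies a separate device-writable descriptor for virtio_mem_resp.)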
+        if !resp_desc.is_write_only() {
+            return Err(VirtioMemError::UnexpectedReadOnlyDescriptor);
+        }
+
+        if (resp_desc.len as usize) < std::mem::size_of::<virtio_mem::virtio_mem_resp>() {
+            return Err(VirtioMemError::DescriptorLengthTooSmall);
+        }
+
+        Ok((request.into(), resp_desc.addr, avail_desc.index))
+    }
+
+    fn write_response(
+        &mut self,
+        resp: Response,
+        resp_addr: GuestAddress,
+        used_idx: u16,
+    ) -> Result<(), VirtioMemError> {
+        debug!("virtio-mem: Response: {:?}", resp);
+        self.guest_memory()
+            .write_obj(virtio_mem::virtio_mem_resp::from(resp), resp_addr)
+            .map_err(|_| VirtioMemError::DescriptorWriteFailed)
+            .map(|_| size_of::<virtio_mem::virtio_mem_resp>())?;
+        self.queues[MEM_QUEUE]
+            .add_used(
+                used_idx,
+                u32::try_from(std::mem::size_of::<virtio_mem::virtio_mem_resp>()).unwrap(),
+            )
+            .map_err(VirtioMemError::QueueError)
+    }
+
+    /// Checks that the range provided by the driver is within the usable memory region
+    fn validate_range(&self, range: &RequestedRange) -> Result<(), VirtioMemError> {
+        // Ensure the range is aligned
+        if !range
+            .addr
+            .raw_value()
+            .is_multiple_of(self.config.block_size)
+        {
+            return Err(VirtioMemError::InvalidRange(*range));
+        }
+
+        if range.nb_blocks == 0 {
+            return Err(VirtioMemError::InvalidRange(*range));
+        }
+
+        // Ensure the start addr is within the usable region
+        let start_off = range
+            .addr
+            .checked_offset_from(self.guest_address())
+            .filter(|&off| off < self.config.usable_region_size)
+            .ok_or(VirtioMemError::InvalidRange(*range))?;
+
+        // Ensure the end offset (exclusive) is within the usable region
+        let end_off = start_off
+            .checked_add(usize_to_u64(self.nb_blocks_to_len(range.nb_blocks)))
+            .filter(|&end_off| end_off <= self.config.usable_region_size)
+            .ok_or(VirtioMemError::InvalidRange(*range))?;
+
+        Ok(())
+    }
+
+    fn unchecked_block_range(&self, range: &RequestedRange) -> Range<usize> {
+        let start_block = u64_to_usize((range.addr.0 - self.config.addr) / self.config.block_size);
+
+        start_block..(start_block + range.nb_blocks)
+    }
+
+    fn process_plug_request(&mut self, range: &RequestedRange) -> Result<(), VirtioMemError> {
+        self.validate_range(range)?;
+
+        if self.config.plugged_size + usize_to_u64(self.nb_blocks_to_len(range.nb_blocks))
+            > self.config.requested_size
+        {
+            return Err(VirtioMemError::PlugRequestIsTooBig);
+        }
+
+        match self.range_state(range) {
+            // the range was validated
+            BlockRangeState::Unplugged => self.update_range(range, true),
+            state => Err(VirtioMemError::PlugRequestBlockStateInvalid(state)),
+        }
+    }
+
+    fn handle_plug_request(
+        &mut self,
+        range: &RequestedRange,
+        resp_addr: GuestAddress,
+        used_idx: u16,
+    ) -> Result<(), VirtioMemError> {
+        METRICS.plug_count.inc();
+        let _metric = METRICS.plug_agg.record_latency_metrics();
+
+        let response = match self.process_plug_request(range) {
+            Err(err) => {
+                METRICS.plug_fails.inc();
+                error!("virtio-mem: Failed to plug range: {}", err);
+                Response::error()
+            }
+            Ok(_) => {
+                METRICS
+                    .plug_bytes
+                    .add(usize_to_u64(self.nb_blocks_to_len(range.nb_blocks)));
+                Response::ack()
+            }
+        };
+        self.write_response(response, resp_addr, used_idx)
+    }
+
+    fn process_unplug_request(&mut self, range: &RequestedRange) -> Result<(), VirtioMemError> {
+        self.validate_range(range)?;
+
+        match self.range_state(range) {
+            // the range was validated
+            BlockRangeState::Plugged => self.update_range(range, false),
+            state => Err(VirtioMemError::UnplugRequestBlockStateInvalid(state)),
+        }
+    }
+
+    fn handle_unplug_request(
+        &mut self,
+        range: &RequestedRange,
+        resp_addr: GuestAddress,
+        used_idx: u16,
+    ) -> Result<(), VirtioMemError> {
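+        // Unplug mirrors plug: validate the range, require it to be fully plugged, then
+        // clear the plugged bitmap, discard the backing pages and update the KVM slots.
+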
METRICS.unplug_count.inc(); + let _metric = METRICS.unplug_agg.record_latency_metrics(); + let response = match self.process_unplug_request(range) { + Err(err) => { + METRICS.unplug_fails.inc(); + error!("virtio-mem: Failed to unplug range: {}", err); + Response::error() + } + Ok(_) => { + METRICS + .unplug_bytes + .add(usize_to_u64(self.nb_blocks_to_len(range.nb_blocks))); + Response::ack() + } + }; + self.write_response(response, resp_addr, used_idx) + } + + fn handle_unplug_all_request( + &mut self, + resp_addr: GuestAddress, + used_idx: u16, + ) -> Result<(), VirtioMemError> { + METRICS.unplug_all_count.inc(); + let _metric = METRICS.unplug_all_agg.record_latency_metrics(); + let range = RequestedRange { + addr: self.guest_address(), + nb_blocks: self.plugged_blocks.len(), + }; + let response = match self.update_range(&range, false) { + Err(err) => { + METRICS.unplug_all_fails.inc(); + error!("virtio-mem: Failed to unplug all: {}", err); + Response::error() + } + Ok(_) => { + self.config.usable_region_size = 0; + Response::ack() + } + }; + self.write_response(response, resp_addr, used_idx) + } + + fn handle_state_request( + &mut self, + range: &RequestedRange, + resp_addr: GuestAddress, + used_idx: u16, + ) -> Result<(), VirtioMemError> { + METRICS.state_count.inc(); + let _metric = METRICS.state_agg.record_latency_metrics(); + let response = match self.validate_range(range) { + Err(err) => { + METRICS.state_fails.inc(); + error!("virtio-mem: Failed to retrieve state of range: {}", err); + Response::error() + } + // the range was validated + Ok(_) => Response::ack_with_state(self.range_state(range)), + }; + self.write_response(response, resp_addr, used_idx) + } + + fn process_mem_queue(&mut self) -> Result<(), VirtioMemError> { + while let Some(desc) = self.queues[MEM_QUEUE].pop()? 
{ + let index = desc.index; + + let (req, resp_addr, used_idx) = self.parse_request(&desc)?; + debug!("virtio-mem: Request: {:?}", req); + // Handle request and write response + match req { + Request::State(ref range) => self.handle_state_request(range, resp_addr, used_idx), + Request::Plug(ref range) => self.handle_plug_request(range, resp_addr, used_idx), + Request::Unplug(ref range) => { + self.handle_unplug_request(range, resp_addr, used_idx) + } + Request::UnplugAll => self.handle_unplug_all_request(resp_addr, used_idx), + Request::Unsupported(t) => Err(VirtioMemError::UnknownRequestType(t)), + }?; + } + + self.queues[MEM_QUEUE].advance_used_ring_idx(); + self.signal_used_queue()?; + + Ok(()) + } + + pub(crate) fn process_mem_queue_event(&mut self) { + METRICS.queue_event_count.inc(); + if let Err(err) = self.queue_events[MEM_QUEUE].read() { + METRICS.queue_event_fails.inc(); + error!("Failed to read mem queue event: {err}"); + return; + } + + if let Err(err) = self.process_mem_queue() { + METRICS.queue_event_fails.inc(); + error!("virtio-mem: Failed to process queue: {err}"); + } + } + + pub fn process_virtio_queues(&mut self) -> Result<(), VirtioMemError> { + self.process_mem_queue() + } + + pub(crate) fn set_avail_features(&mut self, features: u64) { + self.avail_features = features; + } + + pub(crate) fn set_acked_features(&mut self, features: u64) { + self.acked_features = features; + } + + pub(crate) fn activate_event(&self) -> &EventFd { + &self.activate_event + } + + fn update_kvm_slots(&self, updated_range: &RequestedRange) -> Result<(), VirtioMemError> { + let hp_region = self + .guest_memory() + .iter() + .find(|r| r.region_type == GuestRegionType::Hotpluggable) + .expect("there should be one and only one hotpluggable region"); + hp_region + .slots_intersecting_range( + updated_range.addr, + self.nb_blocks_to_len(updated_range.nb_blocks), + ) + .try_for_each(|slot| { + let slot_range = RequestedRange { + addr: slot.guest_addr, + nb_blocks: slot.slice.len() / u64_to_usize(self.config.block_size), + }; + match self.range_state(&slot_range) { + BlockRangeState::Mixed | BlockRangeState::Plugged => { + hp_region.update_slot(&self.vm, &slot, true) + } + BlockRangeState::Unplugged => hp_region.update_slot(&self.vm, &slot, false), + } + .map_err(VirtioMemError::UpdateKvmSlot) + }) + } + + /// Plugs/unplugs the given range + /// + /// Note: the range passed to this function must be within the device memory to avoid + /// out-of-bound panics. + fn update_range(&mut self, range: &RequestedRange, plug: bool) -> Result<(), VirtioMemError> { + // Update internal state + let block_range = self.unchecked_block_range(range); + let plugged_blocks_slice = &mut self.plugged_blocks[block_range]; + let plugged_before = plugged_blocks_slice.count_ones(); + plugged_blocks_slice.fill(plug); + let plugged_after = plugged_blocks_slice.count_ones(); + self.config.plugged_size -= usize_to_u64(self.nb_blocks_to_len(plugged_before)); + self.config.plugged_size += usize_to_u64(self.nb_blocks_to_len(plugged_after)); + + // If unplugging, discard the range + if !plug { + self.guest_memory() + .discard_range(range.addr, self.nb_blocks_to_len(range.nb_blocks)) + .inspect_err(|err| { + // Failure to discard is not fatal and is not reported to the driver. It only + // gets logged. + METRICS.unplug_discard_fails.inc(); + error!("virtio-mem: Failed to discard memory range: {}", err); + }); + } + + self.update_kvm_slots(range) + } + + /// Updates the requested size of the virtio-mem device. 
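+    ///
+    /// A usage sketch (the 512 MiB figure is illustrative only):
+    /// ```ignore
+    /// // Ask the guest to grow to 512 MiB: usable_region_size is raised to the next
+    /// // slot-size multiple if needed and the driver is notified via a config interrupt;
+    /// // the guest then sends PLUG requests until plugged_size reaches requested_size.
+    /// virtio_mem.update_requested_size(512)?;
+    /// ```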
+    pub fn update_requested_size(
+        &mut self,
+        requested_size_mib: usize,
+    ) -> Result<(), VirtioMemError> {
+        let requested_size = usize_to_u64(mib_to_bytes(requested_size_mib));
+        if !self.is_activated() {
+            return Err(VirtioMemError::DeviceNotActive);
+        }
+
+        if requested_size % self.config.block_size != 0 {
+            return Err(VirtioMemError::InvalidSize(requested_size));
+        }
+        if requested_size > self.config.region_size {
+            return Err(VirtioMemError::InvalidSize(requested_size));
+        }
+
+        // Increase the usable_region_size if it's not enough for the guest to plug new
+        // memory blocks.
+        // The device cannot decrease the usable_region_size unless the guest requests
+        // to reset it with an UNPLUG_ALL request.
+        if self.config.usable_region_size < requested_size {
+            self.config.usable_region_size =
+                requested_size.next_multiple_of(usize_to_u64(self.slot_size));
+            debug!(
+                "virtio-mem: Updated usable size to {} bytes",
+                self.config.usable_region_size
+            );
+        }
+
+        self.config.requested_size = requested_size;
+        debug!(
+            "virtio-mem: Updated requested size to {} bytes",
+            requested_size
+        );
+        self.interrupt_trigger()
+            .trigger(VirtioInterruptType::Config)
+            .map_err(VirtioMemError::InterruptError)
+    }
+}
+
+impl VirtioDevice for VirtioMem {
+    impl_device_type!(VIRTIO_ID_MEM);
+
+    fn queues(&self) -> &[Queue] {
+        &self.queues
+    }
+
+    fn queues_mut(&mut self) -> &mut [Queue] {
+        &mut self.queues
+    }
+
+    fn queue_events(&self) -> &[EventFd] {
+        &self.queue_events
+    }
+
+    fn interrupt_trigger(&self) -> &dyn VirtioInterrupt {
+        self.device_state
+            .active_state()
+            .expect("Device is not activated")
+            .interrupt
+            .deref()
+    }
+
+    fn avail_features(&self) -> u64 {
+        self.avail_features
+    }
+
+    fn acked_features(&self) -> u64 {
+        self.acked_features
+    }
+
+    fn set_acked_features(&mut self, acked_features: u64) {
+        self.acked_features = acked_features;
+    }
+
+    fn read_config(&self, offset: u64, data: &mut [u8]) {
+        let offset = u64_to_usize(offset);
+        self.config
+            .as_slice()
+            .get(offset..offset + data.len())
+            .map(|s| data.copy_from_slice(s))
+            .unwrap_or_else(|| {
+                error!(
+                    "virtio-mem: Config read offset+length {offset}+{} out of bounds",
+                    data.len()
+                )
+            })
+    }
+
+    fn write_config(&mut self, offset: u64, _data: &[u8]) {
+        error!("virtio-mem: Attempted write to read-only config space at offset {offset}");
+    }
+
+    fn is_activated(&self) -> bool {
+        self.device_state.is_activated()
+    }
+
+    fn activate(
+        &mut self,
+        mem: GuestMemoryMmap,
+        interrupt: Arc<dyn VirtioInterrupt>,
+    ) -> Result<(), ActivateError> {
+        if (self.acked_features & (1 << VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE)) == 0 {
+            error!(
+                "virtio-mem: VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE feature not acknowledged by guest"
+            );
+            METRICS.activate_fails.inc();
+            return Err(ActivateError::RequiredFeatureNotAcked(
+                "VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE",
+            ));
+        }
+
+        for q in self.queues.iter_mut() {
+            q.initialize(&mem)
+                .map_err(ActivateError::QueueMemoryError)?;
+        }
+
+        self.device_state = DeviceState::Activated(ActiveState { mem, interrupt });
+        if self.activate_event.write(1).is_err() {
+            METRICS.activate_fails.inc();
+            self.device_state = DeviceState::Inactive;
+            return Err(ActivateError::EventFd);
+        }
+
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use super::*;
+    use crate::devices::virtio::test_utils::test::VirtioTestDevice;
+    use crate::test_utils::single_region_mem;
+    use crate::vmm_config::machine_config::HugePageConfig;
+    use crate::vstate::memory;
+    use crate::vstate::vm::tests::setup_vm_with_memory;
+
+    impl VirtioTestDevice for VirtioMem {
+        fn set_queues(&mut self, queues: Vec<Queue>) {
+            self.queues = queues;
+        }
+
+        fn num_queues() -> usize {
+            MEM_NUM_QUEUES
+        }
+    }
+
+    pub(crate) fn default_virtio_mem() -> VirtioMem {
+        let (_, mut vm) = setup_vm_with_memory(0x1000);
+        let addr = GuestAddress(512 << 30);
+        vm.register_hotpluggable_memory_region(
+            memory::anonymous(
+                std::iter::once((addr, mib_to_bytes(1024))),
+                false,
+                HugePageConfig::None,
+            )
+            .unwrap()
+            .pop()
+            .unwrap(),
+            mib_to_bytes(128),
+        );
+        let vm = Arc::new(vm);
+        VirtioMem::new(vm, addr, 1024, 2, 128).unwrap()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::devices::virtio::device::VirtioDevice;
+    use crate::devices::virtio::mem::device::test_utils::default_virtio_mem;
+    use crate::devices::virtio::queue::VIRTQ_DESC_F_WRITE;
+    use crate::devices::virtio::test_utils::test::VirtioTestHelper;
+    use crate::vstate::vm::tests::setup_vm_with_memory;
+
+    #[test]
+    fn test_new() {
+        let mem = default_virtio_mem();
+
+        assert_eq!(mem.total_size_mib(), 1024);
+        assert_eq!(mem.block_size_mib(), 2);
+        assert_eq!(mem.plugged_size_mib(), 0);
+        assert_eq!(mem.id(), VIRTIO_MEM_DEV_ID);
+        assert_eq!(mem.device_type(), VIRTIO_ID_MEM);
+
+        let features = (1 << VIRTIO_F_VERSION_1) | (1 << VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE);
+        assert_eq!(mem.avail_features(), features);
+        assert_eq!(mem.acked_features(), 0);
+
+        assert!(!mem.is_activated());
+
+        assert_eq!(mem.queues().len(), MEM_NUM_QUEUES);
+        assert_eq!(mem.queue_events().len(), MEM_NUM_QUEUES);
+    }
+
+    #[test]
+    fn test_from_state() {
+        let (_, vm) = setup_vm_with_memory(0x1000);
+        let vm = Arc::new(vm);
+        let queues = vec![Queue::new(FIRECRACKER_MAX_QUEUE_SIZE); MEM_NUM_QUEUES];
+        let addr = 512 << 30;
+        let region_size_mib = 2048;
+        let block_size_mib = 2;
+        let slot_size_mib = 128;
+        let plugged_size_mib = 512;
+        let usable_region_size = mib_to_bytes(1024) as u64;
+        let config = virtio_mem_config {
+            addr,
+            region_size: mib_to_bytes(region_size_mib) as u64,
+            block_size: mib_to_bytes(block_size_mib) as u64,
+            plugged_size: mib_to_bytes(plugged_size_mib) as u64,
+            usable_region_size,
+            ..Default::default()
+        };
+        let plugged_blocks = BitVec::repeat(
+            false,
+            mib_to_bytes(region_size_mib) / mib_to_bytes(block_size_mib),
+        );
+        let mem = VirtioMem::from_state(
+            vm,
+            queues,
+            config,
+            mib_to_bytes(slot_size_mib),
+            plugged_blocks,
+        )
+        .unwrap();
+        assert_eq!(mem.config.addr, addr);
+        assert_eq!(mem.total_size_mib(), region_size_mib);
+        assert_eq!(mem.block_size_mib(), block_size_mib);
+        assert_eq!(mem.slot_size_mib(), slot_size_mib);
+        assert_eq!(mem.plugged_size_mib(), plugged_size_mib);
+        assert_eq!(mem.config.usable_region_size, usable_region_size);
+    }
+
+    #[test]
+    fn test_read_config() {
+        let mem = default_virtio_mem();
+        let mut data = [0u8; 8];
+
+        mem.read_config(0, &mut data);
+        assert_eq!(
+            u64::from_le_bytes(data),
+            mib_to_bytes(mem.block_size_mib()) as u64
+        );
+
+        mem.read_config(16, &mut data);
+        assert_eq!(u64::from_le_bytes(data), 512 << 30);
+
+        mem.read_config(24, &mut data);
+        assert_eq!(
+            u64::from_le_bytes(data),
+            mib_to_bytes(mem.total_size_mib()) as u64
+        );
+    }
+
+    #[test]
+    fn test_read_config_out_of_bounds() {
+        let mem = default_virtio_mem();
+
+        let mut data = [0u8; 8];
+        let config_size = std::mem::size_of::<virtio_mem_config>();
+        mem.read_config(config_size as u64, &mut data);
+        assert_eq!(data, [0u8; 8]); // Should remain unchanged
+
+        let mut data = vec![0u8; config_size];
+        mem.read_config(8, &mut data);
+        assert_eq!(data, vec![0u8; config_size]); // Should remain unchanged
+    }
+
+    #[test]
+    fn test_write_config() {
+        let mut mem = default_virtio_mem();
+        let data = [1u8; 8];
+        mem.write_config(0, &data); // Should log error but not crash
+
+        // should not change config
+        let mut data = [0u8; 8];
+        mem.read_config(0, &mut data);
+        let block_size = u64::from_le_bytes(data);
+        assert_eq!(block_size, mib_to_bytes(2) as u64);
+    }
+
+    #[test]
+    fn test_set_features() {
+        let mut mem = default_virtio_mem();
+        mem.set_avail_features(123);
+        assert_eq!(mem.avail_features(), 123);
+        mem.set_acked_features(456);
+        assert_eq!(mem.acked_features(), 456);
+    }
+
+    #[test]
+    fn test_status() {
+        let mut mem = default_virtio_mem();
+        let status = mem.status();
+        assert_eq!(
+            status,
+            VirtioMemStatus {
+                block_size_mib: 2,
+                total_size_mib: 1024,
+                slot_size_mib: 128,
+                plugged_size_mib: 0,
+                requested_size_mib: 0,
+            }
+        );
+    }
+
+    #[allow(clippy::cast_possible_truncation)]
+    const REQ_SIZE: u32 = std::mem::size_of::<virtio_mem::virtio_mem_req>() as u32;
+    #[allow(clippy::cast_possible_truncation)]
+    const RESP_SIZE: u32 = std::mem::size_of::<virtio_mem::virtio_mem_resp>() as u32;
+
+    fn test_helper<'a>(
+        mut dev: VirtioMem,
+        mem: &'a GuestMemoryMmap,
+    ) -> VirtioTestHelper<'a, VirtioMem> {
+        dev.set_acked_features(dev.avail_features);
+
+        let mut th = VirtioTestHelper::<VirtioMem>::new(mem, dev);
+        th.activate_device(mem);
+        th
+    }
+
+    fn emulate_request(
+        th: &mut VirtioTestHelper<VirtioMem>,
+        mem: &GuestMemoryMmap,
+        req: Request,
+    ) -> Response {
+        th.add_desc_chain(
+            MEM_QUEUE,
+            0,
+            &[(0, REQ_SIZE, 0), (1, RESP_SIZE, VIRTQ_DESC_F_WRITE)],
+        );
+        mem.write_obj(
+            virtio_mem::virtio_mem_req::from(req),
+            th.desc_address(MEM_QUEUE, 0),
+        )
+        .unwrap();
+        assert_eq!(th.emulate_for_msec(100).unwrap(), 1);
+        mem.read_obj::<virtio_mem::virtio_mem_resp>(th.desc_address(MEM_QUEUE, 1))
+            .unwrap()
+            .into()
+    }
+
+    #[test]
+    fn test_event_fail_descriptor_chain_too_short() {
+        let mut mem_dev = default_virtio_mem();
+        let guest_mem = mem_dev.vm.guest_memory().clone();
+        let mut th = test_helper(mem_dev, &guest_mem);
+
+        let queue_event_count = METRICS.queue_event_count.count();
+        let queue_event_fails = METRICS.queue_event_fails.count();
+
+        th.add_desc_chain(MEM_QUEUE, 0, &[(0, REQ_SIZE, 0)]);
+        assert_eq!(th.emulate_for_msec(100).unwrap(), 1);
+
+        assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1);
+        assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1);
+    }
+
+    #[test]
+    fn test_event_fail_descriptor_length_too_small() {
+        let mut mem_dev = default_virtio_mem();
+        let guest_mem = mem_dev.vm.guest_memory().clone();
+        let mut th = test_helper(mem_dev, &guest_mem);
+
+        let queue_event_count = METRICS.queue_event_count.count();
+        let queue_event_fails = METRICS.queue_event_fails.count();
+
+        th.add_desc_chain(MEM_QUEUE, 0, &[(0, 1, 0)]);
+        assert_eq!(th.emulate_for_msec(100).unwrap(), 1);
+
+        assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1);
+        assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1);
+    }
+
+    #[test]
+    fn test_event_fail_unexpected_writeonly_descriptor() {
+        let mut mem_dev = default_virtio_mem();
+        let guest_mem = mem_dev.vm.guest_memory().clone();
+        let mut th = test_helper(mem_dev, &guest_mem);
+
+        let queue_event_count = METRICS.queue_event_count.count();
+        let queue_event_fails = METRICS.queue_event_fails.count();
+
+        th.add_desc_chain(MEM_QUEUE, 0, &[(0, REQ_SIZE, VIRTQ_DESC_F_WRITE)]);
+        assert_eq!(th.emulate_for_msec(100).unwrap(), 1);
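+        // The malformed chain is consumed but no response is written; only the failure
+        // metric below is incremented.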
+ + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn test_event_fail_unexpected_readonly_descriptor() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain(MEM_QUEUE, 0, &[(0, REQ_SIZE, 0), (1, RESP_SIZE, 0)]); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn test_event_fail_response_descriptor_length_too_small() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain( + MEM_QUEUE, + 0, + &[(0, REQ_SIZE, 0), (1, 1, VIRTQ_DESC_F_WRITE)], + ); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn test_update_requested_size_device_not_active() { + let mut mem_dev = default_virtio_mem(); + let result = mem_dev.update_requested_size(512); + assert!(matches!(result, Err(VirtioMemError::DeviceNotActive))); + } + + #[test] + fn test_update_requested_size_invalid_size() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + // Size not multiple of block size + let result = th.device().update_requested_size(3); + assert!(matches!(result, Err(VirtioMemError::InvalidSize(_)))); + + // Size too large + let result = th.device().update_requested_size(2048); + assert!(matches!(result, Err(VirtioMemError::InvalidSize(_)))); + } + + #[test] + fn test_update_requested_size_success() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + th.device().update_requested_size(512).unwrap(); + assert_eq!(th.device().requested_size_mib(), 512); + } + + #[test] + fn test_plug_request_success() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + let plug_count = METRICS.plug_count.count(); + let plug_bytes = METRICS.plug_bytes.count(); + let plug_fails = METRICS.plug_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 2); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails); + assert_eq!(METRICS.plug_count.count(), plug_count + 1); + assert_eq!(METRICS.plug_bytes.count(), plug_bytes + (2 << 20)); + assert_eq!(METRICS.plug_fails.count(), plug_fails); + } + + #[test] + fn 
test_plug_request_too_big() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(2); + let addr = th.device().guest_address(); + + let plug_count = METRICS.plug_count.count(); + let plug_bytes = METRICS.plug_bytes.count(); + let plug_fails = METRICS.plug_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 2 }), + ); + assert!(resp.is_error()); + + assert_eq!(METRICS.plug_count.count(), plug_count + 1); + assert_eq!(METRICS.plug_bytes.count(), plug_bytes); + assert_eq!(METRICS.plug_fails.count(), plug_fails + 1); + } + + #[test] + fn test_plug_request_already_plugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + // First plug succeeds + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + + // Second plug fails + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_error()); + } + + #[test] + fn test_unplug_request_success() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let unplug_count = METRICS.unplug_count.count(); + let unplug_bytes = METRICS.unplug_bytes.count(); + let unplug_fails = METRICS.unplug_fails.count(); + + // First plug + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 2); + + // Then unplug + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Unplug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 0); + + assert_eq!(METRICS.unplug_count.count(), unplug_count + 1); + assert_eq!(METRICS.unplug_bytes.count(), unplug_bytes + (2 << 20)); + assert_eq!(METRICS.unplug_fails.count(), unplug_fails); + } + + #[test] + fn test_unplug_request_not_plugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let unplug_count = METRICS.unplug_count.count(); + let unplug_bytes = METRICS.unplug_bytes.count(); + let unplug_fails = METRICS.unplug_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Unplug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_error()); + + assert_eq!(METRICS.unplug_count.count(), unplug_count + 1); + assert_eq!(METRICS.unplug_bytes.count(), unplug_bytes); + assert_eq!(METRICS.unplug_fails.count(), unplug_fails + 1); + } + + #[test] + fn test_unplug_all_request() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let unplug_all_count = METRICS.unplug_all_count.count(); + let unplug_all_fails = 
METRICS.unplug_all_fails.count(); + + // Plug some blocks + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 2 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 4); + + // Unplug all + let resp = emulate_request(&mut th, &guest_mem, Request::UnplugAll); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 0); + + assert_eq!(METRICS.unplug_all_count.count(), unplug_all_count + 1); + assert_eq!(METRICS.unplug_all_fails.count(), unplug_all_fails); + } + + #[test] + fn test_state_request_unplugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let state_count = METRICS.state_count.count(); + let state_fails = METRICS.state_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 1 }), + ); + assert_eq!(resp, Response::ack_with_state(BlockRangeState::Unplugged)); + + assert_eq!(METRICS.state_count.count(), state_count + 1); + assert_eq!(METRICS.state_fails.count(), state_fails); + } + + #[test] + fn test_state_request_plugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + // Plug first + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + + // Check state + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 1 }), + ); + assert_eq!(resp, Response::ack_with_state(BlockRangeState::Plugged)); + } + + #[test] + fn test_state_request_mixed() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + // Plug first block only + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + + // Check state of 2 blocks (one plugged, one unplugged) + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 2 }), + ); + assert_eq!(resp, Response::ack_with_state(BlockRangeState::Mixed)); + } + + #[test] + fn test_invalid_range_unaligned() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address().unchecked_add(1); + + let state_count = METRICS.state_count.count(); + let state_fails = METRICS.state_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_error()); + + assert_eq!(METRICS.state_count.count(), state_count + 1); + assert_eq!(METRICS.state_fails.count(), state_fails + 1); + } + + #[test] + fn test_invalid_range_zero_blocks() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let 
resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 0 }), + ); + assert!(resp.is_error()); + } + + #[test] + fn test_invalid_range_out_of_bounds() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(4); + let addr = th.device().guest_address(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { + addr, + nb_blocks: 1024, + }), + ); + assert!(resp.is_error()); + } + + #[test] + fn test_unsupported_request() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain( + MEM_QUEUE, + 0, + &[(0, REQ_SIZE, 0), (1, RESP_SIZE, VIRTQ_DESC_F_WRITE)], + ); + guest_mem + .write_obj( + virtio_mem::virtio_mem_req::from(Request::Unsupported(999)), + th.desc_address(MEM_QUEUE, 0), + ) + .unwrap(); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } +} diff --git a/src/vmm/src/devices/virtio/mem/event_handler.rs b/src/vmm/src/devices/virtio/mem/event_handler.rs new file mode 100644 index 00000000000..ae81ef12a64 --- /dev/null +++ b/src/vmm/src/devices/virtio/mem/event_handler.rs @@ -0,0 +1,192 @@ +// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use event_manager::{EventOps, Events, MutEventSubscriber}; +use vmm_sys_util::epoll::EventSet; + +use crate::devices::virtio::device::VirtioDevice; +use crate::devices::virtio::mem::MEM_QUEUE; +use crate::devices::virtio::mem::device::VirtioMem; +use crate::logger::{error, warn}; + +impl VirtioMem { + const PROCESS_ACTIVATE: u32 = 0; + const PROCESS_MEM_QUEUE: u32 = 1; + + fn register_runtime_events(&self, ops: &mut EventOps) { + if let Err(err) = ops.add(Events::with_data( + &self.queue_events()[MEM_QUEUE], + Self::PROCESS_MEM_QUEUE, + EventSet::IN, + )) { + error!("virtio-mem: Failed to register queue event: {err}"); + } + } + + fn register_activate_event(&self, ops: &mut EventOps) { + if let Err(err) = ops.add(Events::with_data( + self.activate_event(), + Self::PROCESS_ACTIVATE, + EventSet::IN, + )) { + error!("virtio-mem: Failed to register activate event: {err}"); + } + } + + fn process_activate_event(&self, ops: &mut EventOps) { + if let Err(err) = self.activate_event().read() { + error!("virtio-mem: Failed to consume activate event: {err}"); + } + + // Register runtime events + self.register_runtime_events(ops); + + // Remove activate event + if let Err(err) = ops.remove(Events::with_data( + self.activate_event(), + Self::PROCESS_ACTIVATE, + EventSet::IN, + )) { + error!("virtio-mem: Failed to un-register activate event: {err}"); + } + } +} + +impl MutEventSubscriber for VirtioMem { + fn init(&mut self, ops: &mut event_manager::EventOps) { + // This function can be called during different points in the device lifetime: + // - shortly after device creation, + // - on device activation (is-activated already true at this point), + // - on device restore from snapshot. 
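+        // In the first case we must wait for the activate event; in the other two the
+        // device is already activated, so its queue events can be registered directly.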
+ if self.is_activated() { + self.register_runtime_events(ops); + } else { + self.register_activate_event(ops); + } + } + + fn process(&mut self, events: event_manager::Events, ops: &mut event_manager::EventOps) { + let event_set = events.event_set(); + let source = events.data(); + + if !event_set.contains(EventSet::IN) { + warn!("virtio-mem: Received unknown event: {event_set:?} from source {source}"); + return; + } + + if !self.is_activated() { + warn!("virtio-mem: The device is not activated yet. Spurious event received: {source}"); + return; + } + + match source { + Self::PROCESS_ACTIVATE => self.process_activate_event(ops), + Self::PROCESS_MEM_QUEUE => self.process_mem_queue_event(), + + _ => { + warn!("virtio-mem: Unknown event received: {source}"); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::{Arc, Mutex}; + + use event_manager::{EventManager, SubscriberOps}; + use vmm_sys_util::epoll::EventSet; + + use super::*; + use crate::devices::virtio::ActivateError; + use crate::devices::virtio::generated::virtio_mem::VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE; + use crate::devices::virtio::mem::device::test_utils::default_virtio_mem; + use crate::devices::virtio::test_utils::{VirtQueue, default_interrupt, default_mem}; + use crate::vstate::memory::GuestAddress; + + #[test] + fn test_event_handler_activation() { + let mut event_manager = EventManager::new().unwrap(); + let mut mem_device = default_virtio_mem(); + let mem = default_mem(); + let interrupt = default_interrupt(); + + // Set up queue + let virtq = VirtQueue::new(GuestAddress(0), &mem, 16); + mem_device.queues_mut()[MEM_QUEUE] = virtq.create_queue(); + + let mem_device = Arc::new(Mutex::new(mem_device)); + let _id = event_manager.add_subscriber(mem_device.clone()); + + // Device should register activate event when inactive + assert!(!mem_device.lock().unwrap().is_activated()); + + // Device should prevent activation before features are acked + let err = mem_device + .lock() + .unwrap() + .activate(mem.clone(), interrupt.clone()) + .unwrap_err(); + + assert!(matches!(err, ActivateError::RequiredFeatureNotAcked(_))); + + // Ack the feature and activate the device + mem_device + .lock() + .unwrap() + .set_acked_features(1 << VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE); + + mem_device.lock().unwrap().activate(mem, interrupt).unwrap(); + + // Process activation event + let ev_count = event_manager.run_with_timeout(50).unwrap(); + assert_eq!(ev_count, 1); + assert!(mem_device.lock().unwrap().is_activated()); + } + + #[test] + fn test_process_mem_queue_event() { + let mut event_manager = EventManager::new().unwrap(); + let mut mem_device = default_virtio_mem(); + let mem = default_mem(); + let interrupt = default_interrupt(); + + // Set up queue + let virtq = VirtQueue::new(GuestAddress(0), &mem, 16); + mem_device.queues_mut()[MEM_QUEUE] = virtq.create_queue(); + mem_device.set_acked_features(mem_device.avail_features()); + + let mem_device = Arc::new(Mutex::new(mem_device)); + let _id = event_manager.add_subscriber(mem_device.clone()); + + // Activate device first + mem_device.lock().unwrap().activate(mem, interrupt).unwrap(); + event_manager.run_with_timeout(50).unwrap(); // Process activation + + // Trigger queue event + mem_device.lock().unwrap().queue_events()[MEM_QUEUE] + .write(1) + .unwrap(); + + // Process queue event + let ev_count = event_manager.run_with_timeout(50).unwrap(); + assert_eq!(ev_count, 1); + } + + #[test] + fn test_spurious_event_before_activation() { + let mut event_manager = EventManager::new().unwrap(); + 
let mem_device = default_virtio_mem();
+        let mem_device = Arc::new(Mutex::new(mem_device));
+        let _id = event_manager.add_subscriber(mem_device.clone());
+
+        // Try to trigger queue event before activation
+        mem_device.lock().unwrap().queue_events()[MEM_QUEUE]
+            .write(1)
+            .unwrap();
+
+        // Should not process queue events before activation
+        let ev_count = event_manager.run_with_timeout(50).unwrap();
+        assert_eq!(ev_count, 0);
+    }
+}
diff --git a/src/vmm/src/devices/virtio/mem/metrics.rs b/src/vmm/src/devices/virtio/mem/metrics.rs
new file mode 100644
index 00000000000..d69255d44ec
--- /dev/null
+++ b/src/vmm/src/devices/virtio/mem/metrics.rs
@@ -0,0 +1,118 @@
+// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Defines the metrics system for memory devices.
+//!
+//! # Metrics format
+//! The metrics are flushed in JSON when requested by vmm::logger::metrics::METRICS.write().
+//!
+//! ## JSON example with metrics:
+//! ```json
+//! "memory_hotplug": {
+//!     "activate_fails": "SharedIncMetric",
+//!     "queue_event_fails": "SharedIncMetric",
+//!     "queue_event_count": "SharedIncMetric",
+//!     ...
+//! }
+//! ```
+//! The `memory_hotplug` object above is a serialized `VirtioMemDeviceMetrics` structure
+//! collecting metrics such as `activate_fails`, `queue_event_fails` etc. for the memory
+//! hotplug device.
+//! Since Firecracker only supports one virtio-mem device, there are no per-device metrics
+//! and `memory_hotplug` represents the aggregate memory hotplug metrics.
+
+use serde::ser::SerializeMap;
+use serde::{Serialize, Serializer};
+
+use crate::logger::{LatencyAggregateMetrics, SharedIncMetric};
+
+/// Stores aggregated virtio-mem metrics
+pub(super) static METRICS: VirtioMemDeviceMetrics = VirtioMemDeviceMetrics::new();
+
+/// Called by METRICS.flush(), this function facilitates serialization of virtio-mem device metrics.
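+/// (For illustration: the entry emitted here is the single-key map shown in the module
+/// docs, i.e. `"memory_hotplug": { ... }`, which the caller nests into the global metrics
+/// object.)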
+pub fn flush_metrics<S: Serializer>(serializer: S) -> Result<S::Ok, S::Error> {
+    let mut seq = serializer.serialize_map(Some(1))?;
+    seq.serialize_entry("memory_hotplug", &METRICS)?;
+    seq.end()
+}
+
+#[derive(Debug, Serialize)]
+pub(super) struct VirtioMemDeviceMetrics {
+    /// Number of device activation failures
+    pub activate_fails: SharedIncMetric,
+    /// Number of queue event handling failures
+    pub queue_event_fails: SharedIncMetric,
+    /// Number of queue events handled
+    pub queue_event_count: SharedIncMetric,
+    /// Latency of Plug operations
+    pub plug_agg: LatencyAggregateMetrics,
+    /// Number of Plug operations
+    pub plug_count: SharedIncMetric,
+    /// Number of plugged bytes
+    pub plug_bytes: SharedIncMetric,
+    /// Number of Plug operations failed
+    pub plug_fails: SharedIncMetric,
+    /// Latency of Unplug operations
+    pub unplug_agg: LatencyAggregateMetrics,
+    /// Number of Unplug operations
+    pub unplug_count: SharedIncMetric,
+    /// Number of unplugged bytes
+    pub unplug_bytes: SharedIncMetric,
+    /// Number of Unplug operations failed
+    pub unplug_fails: SharedIncMetric,
+    /// Number of discards failed for an Unplug or UnplugAll operation
+    pub unplug_discard_fails: SharedIncMetric,
+    /// Latency of UnplugAll operations
+    pub unplug_all_agg: LatencyAggregateMetrics,
+    /// Number of UnplugAll operations
+    pub unplug_all_count: SharedIncMetric,
+    /// Number of UnplugAll operations failed
+    pub unplug_all_fails: SharedIncMetric,
+    /// Latency of State operations
+    pub state_agg: LatencyAggregateMetrics,
+    /// Number of State operations
+    pub state_count: SharedIncMetric,
+    /// Number of State operations failed
+    pub state_fails: SharedIncMetric,
+}
+
+impl VirtioMemDeviceMetrics {
+    /// Const default construction.
+    const fn new() -> Self {
+        Self {
+            activate_fails: SharedIncMetric::new(),
+            queue_event_fails: SharedIncMetric::new(),
+            queue_event_count: SharedIncMetric::new(),
+            plug_agg: LatencyAggregateMetrics::new(),
+            plug_count: SharedIncMetric::new(),
+            plug_bytes: SharedIncMetric::new(),
+            plug_fails: SharedIncMetric::new(),
+            unplug_agg: LatencyAggregateMetrics::new(),
+            unplug_count: SharedIncMetric::new(),
+            unplug_bytes: SharedIncMetric::new(),
+            unplug_fails: SharedIncMetric::new(),
+            unplug_discard_fails: SharedIncMetric::new(),
+            unplug_all_agg: LatencyAggregateMetrics::new(),
+            unplug_all_count: SharedIncMetric::new(),
+            unplug_all_fails: SharedIncMetric::new(),
+            state_agg: LatencyAggregateMetrics::new(),
+            state_count: SharedIncMetric::new(),
+            state_fails: SharedIncMetric::new(),
+        }
+    }
+}
+
+#[cfg(test)]
+pub mod tests {
+    use super::*;
+    use crate::logger::IncMetric;
+
+    #[test]
+    fn test_memory_hotplug_metrics() {
+        let mem_metrics: VirtioMemDeviceMetrics = VirtioMemDeviceMetrics::new();
+        mem_metrics.queue_event_count.inc();
+        assert_eq!(mem_metrics.queue_event_count.count(), 1);
+        let _ = serde_json::to_string(&mem_metrics).unwrap();
+    }
+}
diff --git a/src/vmm/src/devices/virtio/mem/mod.rs b/src/vmm/src/devices/virtio/mem/mod.rs
new file mode 100644
index 00000000000..14c399ecd23
--- /dev/null
+++ b/src/vmm/src/devices/virtio/mem/mod.rs
@@ -0,0 +1,22 @@
+// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+mod device;
+mod event_handler;
+pub mod metrics;
+pub mod persist;
+mod request;
+
+use vm_memory::GuestAddress;
+
+pub use self::device::{VirtioMem, VirtioMemError, VirtioMemStatus};
+use crate::arch::FIRST_ADDR_PAST_64BITS_MMIO;
+
+pub(crate) const MEM_NUM_QUEUES: usize = 1;
+
+pub(crate) const MEM_QUEUE: usize = 0;
+
+pub const VIRTIO_MEM_DEFAULT_BLOCK_SIZE_MIB: usize = 2;
+pub const VIRTIO_MEM_DEFAULT_SLOT_SIZE_MIB: usize = 128;
+
+pub const VIRTIO_MEM_DEV_ID: &str = "mem";
diff --git a/src/vmm/src/devices/virtio/mem/persist.rs b/src/vmm/src/devices/virtio/mem/persist.rs
new file mode 100644
index 00000000000..345b464c48d
--- /dev/null
+++ b/src/vmm/src/devices/virtio/mem/persist.rs
@@ -0,0 +1,149 @@
+// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Defines the structures needed for saving/restoring virtio-mem devices.
+
+use std::sync::Arc;
+
+use bitvec::vec::BitVec;
+use serde::{Deserialize, Serialize};
+use vm_memory::Address;
+
+use crate::Vm;
+use crate::devices::virtio::generated::virtio_ids::VIRTIO_ID_MEM;
+use crate::devices::virtio::generated::virtio_mem::virtio_mem_config;
+use crate::devices::virtio::mem::{MEM_NUM_QUEUES, VirtioMem, VirtioMemError};
+use crate::devices::virtio::persist::{PersistError as VirtioStateError, VirtioDeviceState};
+use crate::devices::virtio::queue::FIRECRACKER_MAX_QUEUE_SIZE;
+use crate::snapshot::Persist;
+use crate::utils::usize_to_u64;
+use crate::vstate::memory::{GuestMemoryMmap, GuestRegionMmap};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct VirtioMemState {
+    pub virtio_state: VirtioDeviceState,
+    addr: u64,
+    region_size: u64,
+    block_size: u64,
+    usable_region_size: u64,
+    requested_size: u64,
+    slot_size: usize,
+    plugged_blocks: Vec<bool>,
+}
+
+#[derive(Debug)]
+pub struct VirtioMemConstructorArgs {
+    vm: Arc<Vm>,
+}
+
+impl VirtioMemConstructorArgs {
+    pub fn new(vm: Arc<Vm>) -> Self {
+        Self { vm }
+    }
+}
+
+#[derive(Debug, thiserror::Error, displaydoc::Display)]
+pub enum VirtioMemPersistError {
+    /// Create virtio-mem: {0}
+    CreateVirtioMem(#[from] VirtioMemError),
+    /// Virtio state: {0}
+    VirtioState(#[from] VirtioStateError),
+}
+
+impl Persist<'_> for VirtioMem {
+    type State = VirtioMemState;
+    type ConstructorArgs = VirtioMemConstructorArgs;
+    type Error = VirtioMemPersistError;
+
+    fn save(&self) -> Self::State {
+        VirtioMemState {
+            virtio_state: VirtioDeviceState::from_device(self),
+            addr: self.config.addr,
+            region_size: self.config.region_size,
+            block_size: self.config.block_size,
+            usable_region_size: self.config.usable_region_size,
+            plugged_blocks: self.plugged_blocks.iter().by_vals().collect(),
+            requested_size: self.config.requested_size,
+            slot_size: self.slot_size,
+        }
+    }
+
+    fn restore(
+        constructor_args: Self::ConstructorArgs,
+        state: &Self::State,
+    ) -> Result<Self, Self::Error> {
+        let queues = state.virtio_state.build_queues_checked(
+            constructor_args.vm.guest_memory(),
+            VIRTIO_ID_MEM,
+            MEM_NUM_QUEUES,
+            FIRECRACKER_MAX_QUEUE_SIZE,
+        )?;
+
+        let plugged_blocks = BitVec::from_iter(state.plugged_blocks.iter());
+
+        let config = virtio_mem_config {
+            addr: state.addr,
+            region_size: state.region_size,
+            block_size: state.block_size,
+            usable_region_size: state.usable_region_size,
+            plugged_size: usize_to_u64(plugged_blocks.count_ones()) * state.block_size,
+            requested_size: state.requested_size,
+            ..Default::default()
+        };
+
+        let mut virtio_mem = VirtioMem::from_state(
+            constructor_args.vm,
+            queues,
+            config,
+            state.slot_size,
+            plugged_blocks,
+        )?;
+        virtio_mem.set_avail_features(state.virtio_state.avail_features);
+        virtio_mem.set_acked_features(state.virtio_state.acked_features);
+
+        Ok(virtio_mem)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::devices::virtio::device::VirtioDevice;
+    use crate::devices::virtio::mem::device::test_utils::default_virtio_mem;
+    use crate::vstate::vm::tests::setup_vm_with_memory;
+
+    #[test]
+    fn test_save_state() {
+        let dev = default_virtio_mem();
+        let state = dev.save();
+
+        assert_eq!(state.addr, dev.config.addr);
+        assert_eq!(state.region_size, dev.config.region_size);
+        assert_eq!(state.block_size, dev.config.block_size);
+        assert_eq!(state.usable_region_size, dev.config.usable_region_size);
+        assert_eq!(
+            state.plugged_blocks.iter().collect::<BitVec>(),
+            dev.plugged_blocks
+        );
+        assert_eq!(state.requested_size, dev.config.requested_size);
+        assert_eq!(state.slot_size, dev.slot_size);
+    }
+
+    #[test]
+    fn test_save_restore_state() {
+        let mut original_dev = default_virtio_mem();
+        original_dev.set_acked_features(original_dev.avail_features());
+        let state = original_dev.save();
+
+        // Create a "new" VM for restore
+        let (_, vm) = setup_vm_with_memory(0x1000);
+        let vm = Arc::new(vm);
+        let constructor_args = VirtioMemConstructorArgs::new(vm);
+        let restored_dev = VirtioMem::restore(constructor_args, &state).unwrap();
+
+        assert_eq!(original_dev.config, restored_dev.config);
+        assert_eq!(original_dev.slot_size, restored_dev.slot_size);
+        assert_eq!(original_dev.avail_features(), restored_dev.avail_features());
+        assert_eq!(original_dev.acked_features(), restored_dev.acked_features());
+    }
+}
diff --git a/src/vmm/src/devices/virtio/mem/request.rs b/src/vmm/src/devices/virtio/mem/request.rs
new file mode 100644
index 00000000000..1cc0643392c
--- /dev/null
+++ b/src/vmm/src/devices/virtio/mem/request.rs
@@ -0,0 +1,207 @@
+// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
diff --git a/src/vmm/src/devices/virtio/mem/request.rs b/src/vmm/src/devices/virtio/mem/request.rs
new file mode 100644
index 00000000000..1cc0643392c
--- /dev/null
+++ b/src/vmm/src/devices/virtio/mem/request.rs
@@ -0,0 +1,207 @@
+// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use vm_memory::{Address, ByteValued, GuestAddress};
+
+use crate::devices::virtio::generated::virtio_mem;
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct RequestedRange {
+    pub(crate) addr: GuestAddress,
+    pub(crate) nb_blocks: usize,
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub(crate) enum Request {
+    Plug(RequestedRange),
+    Unplug(RequestedRange),
+    UnplugAll,
+    State(RequestedRange),
+    Unsupported(u32),
+}
+
+// SAFETY: `virtio_mem_req` is a plain-old-data struct generated by bindgen,
+// so any byte pattern is a valid value for it.
+unsafe impl ByteValued for virtio_mem::virtio_mem_req {}
+
+impl From<virtio_mem::virtio_mem_req> for Request {
+    fn from(req: virtio_mem::virtio_mem_req) -> Self {
+        match req.type_.into() {
+            // SAFETY: union type is checked in the match
+            virtio_mem::VIRTIO_MEM_REQ_PLUG => unsafe {
+                Request::Plug(RequestedRange {
+                    addr: GuestAddress(req.u.plug.addr),
+                    nb_blocks: req.u.plug.nb_blocks.into(),
+                })
+            },
+            // SAFETY: union type is checked in the match
+            virtio_mem::VIRTIO_MEM_REQ_UNPLUG => unsafe {
+                Request::Unplug(RequestedRange {
+                    addr: GuestAddress(req.u.unplug.addr),
+                    nb_blocks: req.u.unplug.nb_blocks.into(),
+                })
+            },
+            virtio_mem::VIRTIO_MEM_REQ_UNPLUG_ALL => Request::UnplugAll,
+            // SAFETY: union type is checked in the match
+            virtio_mem::VIRTIO_MEM_REQ_STATE => unsafe {
+                Request::State(RequestedRange {
+                    addr: GuestAddress(req.u.state.addr),
+                    nb_blocks: req.u.state.nb_blocks.into(),
+                })
+            },
+            t => Request::Unsupported(t),
+        }
+    }
+}
+
+#[repr(u16)]
+#[derive(Debug, Clone, Copy, Eq, PartialEq)]
+#[allow(clippy::cast_possible_truncation)]
+pub enum ResponseType {
+    Ack = virtio_mem::VIRTIO_MEM_RESP_ACK as u16,
+    Nack = virtio_mem::VIRTIO_MEM_RESP_NACK as u16,
+    Busy = virtio_mem::VIRTIO_MEM_RESP_BUSY as u16,
+    Error = virtio_mem::VIRTIO_MEM_RESP_ERROR as u16,
+}
+
+#[repr(u16)]
+#[derive(Debug, Clone, Copy, Eq, PartialEq)]
+#[allow(clippy::cast_possible_truncation)]
+pub enum BlockRangeState {
+    Plugged = virtio_mem::VIRTIO_MEM_STATE_PLUGGED as u16,
+    Unplugged = virtio_mem::VIRTIO_MEM_STATE_UNPLUGGED as u16,
+    Mixed = virtio_mem::VIRTIO_MEM_STATE_MIXED as u16,
+}
+
+#[derive(Debug, Clone, Eq, PartialEq)]
+pub struct Response {
+    pub resp_type: ResponseType,
+    // Only for State requests
+    pub state: Option<BlockRangeState>,
+}
+
+impl Response {
+    pub(crate) fn error() -> Self {
+        Response {
+            resp_type: ResponseType::Error,
+            state: None,
+        }
+    }
+
+    pub(crate) fn ack() -> Self {
+        Response {
+            resp_type: ResponseType::Ack,
+            state: None,
+        }
+    }
+
+    pub(crate) fn ack_with_state(state: BlockRangeState) -> Self {
+        Response {
+            resp_type: ResponseType::Ack,
+            state: Some(state),
+        }
+    }
+
+    pub(crate) fn is_ack(&self) -> bool {
+        self.resp_type == ResponseType::Ack
+    }
+
+    pub(crate) fn is_error(&self) -> bool {
+        self.resp_type == ResponseType::Error
+    }
+}
+
+// SAFETY: Plain data structures
+unsafe impl ByteValued for virtio_mem::virtio_mem_resp {}
+
+impl From<Response> for virtio_mem::virtio_mem_resp {
+    fn from(resp: Response) -> Self {
+        let mut out = virtio_mem::virtio_mem_resp {
+            type_: resp.resp_type as u16,
+            ..Default::default()
+        };
+        if let Some(state) = resp.state {
+            out.u.state.state = state as u16;
+        }
+        out
+    }
+}
+
+#[cfg(test)]
+mod test_util {
+    use super::*;
+
+    // Implement the reverse conversions to use in test code.
+
+    impl From<Request> for virtio_mem::virtio_mem_req {
+        fn from(req: Request) -> virtio_mem::virtio_mem_req {
+            match req {
+                Request::Plug(r) => virtio_mem::virtio_mem_req {
+                    type_: virtio_mem::VIRTIO_MEM_REQ_PLUG.try_into().unwrap(),
+                    u: virtio_mem::virtio_mem_req__bindgen_ty_1 {
+                        plug: virtio_mem::virtio_mem_req_plug {
+                            addr: r.addr.raw_value(),
+                            nb_blocks: r.nb_blocks.try_into().unwrap(),
+                            ..Default::default()
+                        },
+                    },
+                    ..Default::default()
+                },
+                Request::Unplug(r) => virtio_mem::virtio_mem_req {
+                    type_: virtio_mem::VIRTIO_MEM_REQ_UNPLUG.try_into().unwrap(),
+                    u: virtio_mem::virtio_mem_req__bindgen_ty_1 {
+                        unplug: virtio_mem::virtio_mem_req_unplug {
+                            addr: r.addr.raw_value(),
+                            nb_blocks: r.nb_blocks.try_into().unwrap(),
+                            ..Default::default()
+                        },
+                    },
+                    ..Default::default()
+                },
+                Request::UnplugAll => virtio_mem::virtio_mem_req {
+                    type_: virtio_mem::VIRTIO_MEM_REQ_UNPLUG_ALL.try_into().unwrap(),
+                    ..Default::default()
+                },
+                Request::State(r) => virtio_mem::virtio_mem_req {
+                    type_: virtio_mem::VIRTIO_MEM_REQ_STATE.try_into().unwrap(),
+                    u: virtio_mem::virtio_mem_req__bindgen_ty_1 {
+                        state: virtio_mem::virtio_mem_req_state {
+                            addr: r.addr.raw_value(),
+                            nb_blocks: r.nb_blocks.try_into().unwrap(),
+                            ..Default::default()
+                        },
+                    },
+                    ..Default::default()
+                },
+                Request::Unsupported(t) => virtio_mem::virtio_mem_req {
+                    type_: t.try_into().unwrap(),
+                    ..Default::default()
+                },
+            }
+        }
+    }
+
+    impl From<virtio_mem::virtio_mem_resp> for Response {
+        fn from(resp: virtio_mem::virtio_mem_resp) -> Self {
+            Response {
+                resp_type: match resp.type_.into() {
+                    virtio_mem::VIRTIO_MEM_RESP_ACK => ResponseType::Ack,
+                    virtio_mem::VIRTIO_MEM_RESP_NACK => ResponseType::Nack,
+                    virtio_mem::VIRTIO_MEM_RESP_BUSY => ResponseType::Busy,
+                    virtio_mem::VIRTIO_MEM_RESP_ERROR => ResponseType::Error,
+                    t => panic!("Invalid response type: {:?}", t),
+                },
+                // There is no way to know whether this is present or not as it depends on the
+                // request type. Callers should ignore this value if the request wasn't STATE.
+                // SAFETY: test code only. Uninitialized values are 0 and recognized as PLUGGED.
+ state: Some(unsafe { + match resp.u.state.state.into() { + virtio_mem::VIRTIO_MEM_STATE_PLUGGED => BlockRangeState::Plugged, + virtio_mem::VIRTIO_MEM_STATE_UNPLUGGED => BlockRangeState::Unplugged, + virtio_mem::VIRTIO_MEM_STATE_MIXED => BlockRangeState::Mixed, + t => panic!("Invalid state: {:?}", t), + } + }), + } + } + } +} diff --git a/src/vmm/src/devices/virtio/mod.rs b/src/vmm/src/devices/virtio/mod.rs index 840dbffdb5e..b17db5b02ac 100644 --- a/src/vmm/src/devices/virtio/mod.rs +++ b/src/vmm/src/devices/virtio/mod.rs @@ -18,6 +18,7 @@ pub mod device; pub mod generated; mod iov_deque; pub mod iovec; +pub mod mem; pub mod net; pub mod persist; pub mod pmem; @@ -64,6 +65,8 @@ pub enum ActivateError { TapSetOffload(TapError), /// Error setting pointers in the queue: (0) QueueMemoryError(QueueError), + /// The driver didn't acknowledge a required feature: {0} + RequiredFeatureNotAcked(&'static str), } /// Trait that helps in upcasting an object to Any diff --git a/src/vmm/src/devices/virtio/pmem/device.rs b/src/vmm/src/devices/virtio/pmem/device.rs index d128225aed8..8ac150fb894 100644 --- a/src/vmm/src/devices/virtio/pmem/device.rs +++ b/src/vmm/src/devices/virtio/pmem/device.rs @@ -26,12 +26,13 @@ use crate::logger::{IncMetric, error}; use crate::utils::{align_up, u64_to_usize}; use crate::vmm_config::pmem::PmemConfig; use crate::vstate::memory::{ByteValued, Bytes, GuestMemoryMmap, GuestMmapRegion}; +use crate::vstate::vm::VmError; use crate::{Vm, impl_device_type}; #[derive(Debug, thiserror::Error, displaydoc::Display)] pub enum PmemError { /// Cannot set the memory regions: {0} - SetUserMemoryRegion(kvm_ioctls::Error), + SetUserMemoryRegion(VmError), /// Unablet to allocate a KVM slot for the device NoKvmSlotAvailable, /// Error accessing backing file: {0} @@ -221,7 +222,7 @@ impl Pmem { /// Set user memory region in KVM pub fn set_mem_region(&mut self, vm: &Vm) -> Result<(), PmemError> { - let next_slot = vm.next_kvm_slot().ok_or(PmemError::NoKvmSlotAvailable)?; + let next_slot = vm.next_kvm_slot(1).ok_or(PmemError::NoKvmSlotAvailable)?; let memory_region = kvm_userspace_memory_region { slot: next_slot, guest_phys_addr: self.config_space.start, @@ -233,14 +234,9 @@ impl Pmem { 0 }, }; - // SAFETY: The fd is a valid VM file descriptor and all fields in the - // `memory_region` struct are valid. 
-        unsafe {
-            vm.fd()
-                .set_user_memory_region(memory_region)
-                .map_err(PmemError::SetUserMemoryRegion)?;
-        }
-        Ok(())
+
+        vm.set_user_memory_region(memory_region)
+            .map_err(PmemError::SetUserMemoryRegion)
     }
 
     fn handle_queue(&mut self) -> Result<(), PmemError> {
diff --git a/src/vmm/src/devices/virtio/test_utils.rs b/src/vmm/src/devices/virtio/test_utils.rs
index 6f1489dd380..0c7978504e7 100644
--- a/src/vmm/src/devices/virtio/test_utils.rs
+++ b/src/vmm/src/devices/virtio/test_utils.rs
@@ -442,6 +442,11 @@ pub(crate) mod test {
         self.virtqueues.last().unwrap().end().raw_value()
     }
 
+    /// Get the address of a descriptor
+    pub fn desc_address(&self, queue: usize, index: usize) -> GuestAddress {
+        GuestAddress(self.virtqueues[queue].dtable[index].addr.get())
+    }
+
     /// Add a new Descriptor in one of the device's queues
     ///
     /// This function adds in one of the queues of the device a DescriptorChain at some offset
diff --git a/src/vmm/src/lib.rs b/src/vmm/src/lib.rs
index 6b88a317605..8059bc76b9c 100644
--- a/src/vmm/src/lib.rs
+++ b/src/vmm/src/lib.rs
@@ -140,6 +140,7 @@ use crate::devices::virtio::balloon::{
 };
 use crate::devices::virtio::block::BlockError;
 use crate::devices::virtio::block::device::Block;
+use crate::devices::virtio::mem::{VIRTIO_MEM_DEV_ID, VirtioMem, VirtioMemError, VirtioMemStatus};
 use crate::devices::virtio::net::Net;
 use crate::logger::{METRICS, MetricsError, error, info, warn};
 use crate::persist::{MicrovmState, MicrovmStateError, VmInfo};
@@ -252,6 +253,8 @@ pub enum VmmError {
     Block(#[from] BlockError),
     /// Balloon: {0}
     Balloon(#[from] BalloonError),
+    /// Failed to create memory hotplug device: {0}
+    VirtioMem(#[from] VirtioMemError),
 }
 
 /// Shorthand type for KVM dirty page bitmap.
@@ -603,6 +606,23 @@ impl Vmm {
         Ok(())
     }
 
+    /// Returns the current state of the memory hotplug device.
+    pub fn memory_hotplug_status(&self) -> Result<VirtioMemStatus, VmmError> {
+        self.device_manager
+            .with_virtio_device(VIRTIO_MEM_DEV_ID, |dev: &mut VirtioMem| dev.status())
+            .map_err(VmmError::FindDeviceError)
+    }
+
+    /// Updates the requested size of the memory hotplug device.
+    pub fn update_memory_hotplug_size(&self, requested_size_mib: usize) -> Result<(), VmmError> {
+        self.device_manager
+            .with_virtio_device(VIRTIO_MEM_DEV_ID, |dev: &mut VirtioMem| {
+                dev.update_requested_size(requested_size_mib)
+            })
+            .map_err(VmmError::FindDeviceError)??;
+        Ok(())
+    }
+
     /// Signals Vmm to stop and exit.
     pub fn stop(&mut self, exit_code: FcExitCode) {
         // To avoid cycles, all teardown paths take the following route:
diff --git a/src/vmm/src/logger/metrics.rs b/src/vmm/src/logger/metrics.rs
index 527ba911461..2bd7330c61a 100644
--- a/src/vmm/src/logger/metrics.rs
+++ b/src/vmm/src/logger/metrics.rs
@@ -74,6 +74,7 @@ use super::FcLineWriter;
 use crate::devices::legacy;
 use crate::devices::virtio::balloon::metrics as balloon_metrics;
 use crate::devices::virtio::block::virtio::metrics as block_metrics;
+use crate::devices::virtio::mem::metrics as virtio_mem_metrics;
 use crate::devices::virtio::net::metrics as net_metrics;
 use crate::devices::virtio::pmem::metrics as pmem_metrics;
 use crate::devices::virtio::rng::metrics as entropy_metrics;
@@ -360,6 +361,8 @@ pub struct GetRequestsMetrics {
     pub mmds_count: SharedIncMetric,
     /// Number of GETs for getting the VMM version.
     pub vmm_version_count: SharedIncMetric,
+    /// Number of GETs for getting hotpluggable memory status.
+    pub hotplug_memory_count: SharedIncMetric,
 }
 impl GetRequestsMetrics {
     /// Const default construction.
@@ -369,6 +372,7 @@ impl GetRequestsMetrics { machine_cfg_count: SharedIncMetric::new(), mmds_count: SharedIncMetric::new(), vmm_version_count: SharedIncMetric::new(), + hotplug_memory_count: SharedIncMetric::new(), } } } @@ -424,6 +428,10 @@ pub struct PutRequestsMetrics { pub serial_count: SharedIncMetric, /// Number of failed PUTs to /serial pub serial_fails: SharedIncMetric, + /// Number of PUTs to /hotplug/memory + pub hotplug_memory_count: SharedIncMetric, + /// Number of failed PUTs to /hotplug/memory + pub hotplug_memory_fails: SharedIncMetric, } impl PutRequestsMetrics { /// Const default construction. @@ -453,6 +461,8 @@ impl PutRequestsMetrics { pmem_fails: SharedIncMetric::new(), serial_count: SharedIncMetric::new(), serial_fails: SharedIncMetric::new(), + hotplug_memory_count: SharedIncMetric::new(), + hotplug_memory_fails: SharedIncMetric::new(), } } } @@ -476,6 +486,10 @@ pub struct PatchRequestsMetrics { pub mmds_count: SharedIncMetric, /// Number of failures in PATCHing an mmds. pub mmds_fails: SharedIncMetric, + /// Number of PATCHes to /hotplug/memory + pub hotplug_memory_count: SharedIncMetric, + /// Number of failed PATCHes to /hotplug/memory + pub hotplug_memory_fails: SharedIncMetric, } impl PatchRequestsMetrics { /// Const default construction. @@ -489,6 +503,8 @@ impl PatchRequestsMetrics { machine_cfg_fails: SharedIncMetric::new(), mmds_count: SharedIncMetric::new(), mmds_fails: SharedIncMetric::new(), + hotplug_memory_count: SharedIncMetric::new(), + hotplug_memory_fails: SharedIncMetric::new(), } } } @@ -876,6 +892,7 @@ create_serialize_proxy!(EntropyMetricsSerializeProxy, entropy_metrics); create_serialize_proxy!(VsockMetricsSerializeProxy, vsock_metrics); create_serialize_proxy!(PmemMetricsSerializeProxy, pmem_metrics); create_serialize_proxy!(LegacyDevMetricsSerializeProxy, legacy); +create_serialize_proxy!(MemoryHotplugSerializeProxy, virtio_mem_metrics); /// Structure storing all metrics while enforcing serialization support on them. #[derive(Debug, Default, Serialize)] @@ -931,6 +948,9 @@ pub struct FirecrackerMetrics { pub vhost_user_ser: VhostUserMetricsSerializeProxy, /// Interrupt related metrics pub interrupts: InterruptMetrics, + #[serde(flatten)] + /// Virtio-mem device related metrics (memory hotplugging) + pub memory_hotplug_ser: MemoryHotplugSerializeProxy, } impl FirecrackerMetrics { /// Const default construction. 
@@ -958,6 +978,7 @@ impl FirecrackerMetrics {
             pmem_ser: PmemMetricsSerializeProxy {},
             vhost_user_ser: VhostUserMetricsSerializeProxy {},
             interrupts: InterruptMetrics::new(),
+            memory_hotplug_ser: MemoryHotplugSerializeProxy {},
         }
     }
 }
diff --git a/src/vmm/src/persist.rs b/src/vmm/src/persist.rs
index cbc4beac95a..405a5fb4b8d 100644
--- a/src/vmm/src/persist.rs
+++ b/src/vmm/src/persist.rs
@@ -698,6 +698,7 @@ mod tests {
             base_address: 0,
             size: 0x20000,
            region_type: GuestRegionType::Dram,
+            plugged: vec![true],
         }],
     };
 
diff --git a/src/vmm/src/resources.rs b/src/vmm/src/resources.rs
index 066fe3524be..4cdf7ac1014 100644
--- a/src/vmm/src/resources.rs
+++ b/src/vmm/src/resources.rs
@@ -25,6 +25,7 @@ use crate::vmm_config::instance_info::InstanceInfo;
 use crate::vmm_config::machine_config::{
     HugePageConfig, MachineConfig, MachineConfigError, MachineConfigUpdate,
 };
+use crate::vmm_config::memory_hotplug::{MemoryHotplugConfig, MemoryHotplugConfigError};
 use crate::vmm_config::metrics::{MetricsConfig, MetricsConfigError, init_metrics};
 use crate::vmm_config::mmds::{MmdsConfig, MmdsConfigError};
 use crate::vmm_config::net::*;
@@ -65,6 +66,8 @@ pub enum ResourcesError {
     EntropyDevice(#[from] EntropyDeviceError),
     /// Pmem device error: {0}
     PmemDevice(#[from] PmemConfigError),
+    /// Memory hotplug config error: {0}
+    MemoryHotplugConfig(#[from] MemoryHotplugConfigError),
 }
 
 #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
@@ -94,6 +97,7 @@ pub struct VmmConfig {
     pmem_devices: Vec<PmemConfig>,
     #[serde(skip)]
     serial_config: Option<SerialConfig>,
+    memory_hotplug: Option<MemoryHotplugConfig>,
 }
 
 /// A data structure that encapsulates the device configurations
@@ -116,6 +120,8 @@ pub struct VmResources {
     pub entropy: EntropyDeviceBuilder,
     /// The pmem devices.
     pub pmem: PmemBuilder,
+    /// The memory hotplug configuration.
+    pub memory_hotplug: Option<MemoryHotplugConfig>,
     /// The optional Mmds data store.
     // This is initialised on demand (if ever used), so that we don't allocate it unless it's
     // actually used.
@@ -213,6 +219,10 @@ impl VmResources {
             resources.serial_out_path = serial_cfg.serial_out_path;
         }
 
+        if let Some(memory_hotplug_config) = vmm_config.memory_hotplug {
+            resources.set_memory_hotplug_config(memory_hotplug_config)?;
+        }
+
         Ok(resources)
     }
 
@@ -373,6 +383,16 @@ impl VmResources {
         self.pmem.build(body, has_block_root)
     }
 
+    /// Sets the memory hotplug configuration.
+    pub fn set_memory_hotplug_config(
+        &mut self,
+        config: MemoryHotplugConfig,
+    ) -> Result<(), MemoryHotplugConfigError> {
+        config.validate()?;
+        self.memory_hotplug = Some(config);
+        Ok(())
+    }
+
     /// Setter for mmds config.
     pub fn set_mmds_config(
         &mut self,
@@ -491,6 +511,18 @@ impl VmResources {
             crate::arch::arch_memory_regions(mib_to_bytes(self.machine_config.mem_size_mib));
         self.allocate_memory_regions(&regions)
     }
+
+    /// Allocates a single guest memory region.
+    pub fn allocate_memory_region(
+        &self,
+        start: GuestAddress,
+        size: usize,
+    ) -> Result<GuestRegionMmap, MemoryError> {
+        Ok(self
+            .allocate_memory_regions(&[(start, size)])?
+            .pop()
+            .unwrap())
+    }
 }
 
 impl From<&VmResources> for VmmConfig {
@@ -510,6 +542,7 @@ impl From<&VmResources> for VmmConfig {
             pmem_devices: resources.pmem.configs(),
             // serial_config is marked serde(skip) so that it doesnt end up in snapshots.
             serial_config: None,
+            memory_hotplug: resources.memory_hotplug.clone(),
         }
     }
 }
@@ -622,6 +655,7 @@ mod tests {
             pmem: Default::default(),
             pci_enabled: false,
             serial_out_path: None,
+            memory_hotplug: Default::default(),
         }
     }
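Before the API plumbing below, it may help to fix the geometry these config values imply. A back-of-the-envelope sketch (the sizes are the test defaults used elsewhere in this change, not mandated values):

```rust
fn main() {
    // Assumed example geometry: 1024 MiB total, 128 MiB KVM slots, 2 MiB blocks.
    let (total_mib, slot_mib, block_mib) = (1024usize, 128usize, 2usize);
    assert_eq!(total_mib / slot_mib, 8); // KVM slots reserved for the region
    assert_eq!(slot_mib / block_mib, 64); // guest-visible blocks per slot
    assert_eq!(total_mib / block_mib, 512); // entries in the plugged bitmap
}
```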
diff --git a/src/vmm/src/rpc_interface.rs b/src/vmm/src/rpc_interface.rs
index 6bae98f3546..1487795ca1c 100644
--- a/src/vmm/src/rpc_interface.rs
+++ b/src/vmm/src/rpc_interface.rs
@@ -14,6 +14,7 @@ use super::{Vmm, VmmError};
 use crate::EventManager;
 use crate::builder::StartMicrovmError;
 use crate::cpu_config::templates::{CustomCpuTemplate, GuestConfigError};
+use crate::devices::virtio::mem::VirtioMemStatus;
 use crate::logger::{LoggerConfig, info, warn, *};
 use crate::mmds::data_store::{self, Mmds};
 use crate::persist::{CreateSnapshotError, RestoreFromSnapshotError, VmInfo};
@@ -28,6 +29,9 @@ use crate::vmm_config::drive::{BlockDeviceConfig, BlockDeviceUpdateConfig, Drive
 use crate::vmm_config::entropy::{EntropyDeviceConfig, EntropyDeviceError};
 use crate::vmm_config::instance_info::InstanceInfo;
 use crate::vmm_config::machine_config::{MachineConfig, MachineConfigError, MachineConfigUpdate};
+use crate::vmm_config::memory_hotplug::{
+    MemoryHotplugConfig, MemoryHotplugConfigError, MemoryHotplugSizeUpdate,
+};
 use crate::vmm_config::metrics::{MetricsConfig, MetricsConfigError};
 use crate::vmm_config::mmds::{MmdsConfig, MmdsConfigError};
 use crate::vmm_config::net::{
@@ -109,6 +113,14 @@ pub enum VmmAction {
     /// Set the entropy device using `EntropyDeviceConfig` as input. This action can only be called
     /// before the microVM has booted.
     SetEntropyDevice(EntropyDeviceConfig),
+    /// Get the memory hotplug device configuration and status.
+    GetMemoryHotplugStatus,
+    /// Set the memory hotplug device using `MemoryHotplugConfig` as input. This action can only be
+    /// called before the microVM has booted.
+    SetMemoryHotplugDevice(MemoryHotplugConfig),
+    /// Updates the memory hotplug device using `MemoryHotplugSizeUpdate` as input. This action
+    /// can only be called after the microVM has booted.
+    UpdateMemoryHotplugSize(MemoryHotplugSizeUpdate),
     /// Launch the microVM. This action can only be called before the microVM has booted.
     StartMicroVm,
     /// Send CTRL+ALT+DEL to the microVM, using the i8042 keyboard function. If an AT-keyboard
@@ -148,6 +160,10 @@ pub enum VmmActionError {
     EntropyDevice(#[from] EntropyDeviceError),
     /// Pmem device error: {0}
     PmemDevice(#[from] PmemConfigError),
+    /// Memory hotplug config error: {0}
+    MemoryHotplugConfig(#[from] MemoryHotplugConfigError),
+    /// Memory hotplug update error: {0}
+    MemoryHotplugUpdate(VmmError),
     /// Internal VMM error: {0}
     InternalVmm(#[from] VmmError),
     /// Load snapshot error: {0}
@@ -201,6 +217,8 @@ pub enum VmmData {
     InstanceInformation(InstanceInfo),
     /// The microVM version.
     VmmVersion(String),
+    /// The status of the memory hotplug device.
+    VirtioMemStatus(VirtioMemStatus),
 }
 
 /// Trait used for deduplicating the MMDS request handling across the two ApiControllers.
@@ -453,15 +471,18 @@ impl<'a> PrebootApiController<'a> {
             StartMicroVm => self.start_microvm(),
             UpdateMachineConfiguration(config) => self.update_machine_config(config),
             SetEntropyDevice(config) => self.set_entropy_device(config),
+            SetMemoryHotplugDevice(config) => self.set_memory_hotplug_device(config),
             // Operations not allowed pre-boot.
             CreateSnapshot(_)
             | FlushMetrics
             | Pause
             | Resume
             | GetBalloonStats
+            | GetMemoryHotplugStatus
             | UpdateBalloon(_)
             | UpdateBalloonStatistics(_)
             | UpdateBlockDevice(_)
+            | UpdateMemoryHotplugSize(_)
             | UpdateNetworkInterface(_) => Err(VmmActionError::OperationNotSupportedPreBoot),
             #[cfg(target_arch = "x86_64")]
             SendCtrlAltDel => Err(VmmActionError::OperationNotSupportedPreBoot),
@@ -560,6 +581,15 @@ impl<'a> PrebootApiController<'a> {
         Ok(VmmData::Empty)
     }
 
+    fn set_memory_hotplug_device(
+        &mut self,
+        cfg: MemoryHotplugConfig,
+    ) -> Result<VmmData, VmmActionError> {
+        self.boot_path = true;
+        self.vm_resources.set_memory_hotplug_config(cfg)?;
+        Ok(VmmData::Empty)
+    }
+
     // On success, this command will end the pre-boot stage and this controller
     // will be replaced by a runtime controller.
     fn start_microvm(&mut self) -> Result<VmmData, VmmActionError> {
@@ -662,6 +692,13 @@ impl RuntimeApiController {
                 .map(VmmData::BalloonStats)
                 .map_err(VmmActionError::InternalVmm),
             GetFullVmConfig => Ok(VmmData::FullVmConfig((&self.vm_resources).into())),
+            GetMemoryHotplugStatus => self
+                .vmm
+                .lock()
+                .expect("Poisoned lock")
+                .memory_hotplug_status()
+                .map(VmmData::VirtioMemStatus)
+                .map_err(VmmActionError::InternalVmm),
             GetMMDS => self.get_mmds(),
             GetVmMachineConfig => Ok(VmmData::MachineConfiguration(
                 self.vm_resources.machine_config.clone(),
@@ -694,7 +731,13 @@ impl RuntimeApiController {
                 .map_err(VmmActionError::BalloonUpdate),
             UpdateBlockDevice(new_cfg) => self.update_block_device(new_cfg),
             UpdateNetworkInterface(netif_update) => self.update_net_rate_limiters(netif_update),
-
+            UpdateMemoryHotplugSize(cfg) => self
+                .vmm
+                .lock()
+                .expect("Poisoned lock")
+                .update_memory_hotplug_size(cfg.requested_size_mib)
+                .map(|_| VmmData::Empty)
+                .map_err(VmmActionError::MemoryHotplugUpdate),
             // Operations not allowed post-boot.
             ConfigureBootSource(_)
             | ConfigureLogger(_)
@@ -709,6 +752,7 @@ impl RuntimeApiController {
             | SetVsockDevice(_)
             | SetMmdsConfiguration(_)
             | SetEntropyDevice(_)
+            | SetMemoryHotplugDevice(_)
             | StartMicroVm
             | UpdateMachineConfiguration(_) => Err(VmmActionError::OperationNotSupportedPostBoot),
         }
@@ -1166,6 +1210,11 @@ mod tests {
         )));
         #[cfg(target_arch = "x86_64")]
         check_unsupported(preboot_request(VmmAction::SendCtrlAltDel));
+        check_unsupported(preboot_request(VmmAction::UpdateMemoryHotplugSize(
+            MemoryHotplugSizeUpdate {
+                requested_size_mib: 0,
+            },
+        )));
     }
 
     fn runtime_request(request: VmmAction) -> Result<VmmData, VmmActionError> {
@@ -1293,5 +1342,8 @@ mod tests {
             root_device: false,
             read_only: false,
         })));
+        check_unsupported(runtime_request(VmmAction::SetMemoryHotplugDevice(
+            MemoryHotplugConfig::default(),
+        )));
     }
 }
diff --git a/src/vmm/src/test_utils/mod.rs b/src/vmm/src/test_utils/mod.rs
index 89b4e238b3b..41809b71b34 100644
--- a/src/vmm/src/test_utils/mod.rs
+++ b/src/vmm/src/test_utils/mod.rs
@@ -15,6 +15,7 @@ use crate::test_utils::mock_resources::{MockBootSourceConfig, MockVmConfig, Mock
 use crate::vmm_config::boot_source::BootSourceConfig;
 use crate::vmm_config::instance_info::InstanceInfo;
 use crate::vmm_config::machine_config::HugePageConfig;
+use crate::vmm_config::memory_hotplug::MemoryHotplugConfig;
 use crate::vstate::memory::{self, GuestMemoryMmap, GuestRegionMmap, GuestRegionMmapExt};
 use crate::{EventManager, Vmm};
 
@@ -72,6 +73,7 @@ pub fn create_vmm(
     is_diff: bool,
     boot_microvm: bool,
     pci_enabled: bool,
+    memory_hotplug_enabled: bool,
 ) -> (Arc<Mutex<Vmm>>, EventManager) {
     let mut event_manager = EventManager::new().unwrap();
     let empty_seccomp_filters = get_empty_filters();
@@ -95,6 +97,14 @@ pub fn create_vmm(
 
     resources.pci_enabled = pci_enabled;
 
+    if memory_hotplug_enabled {
+        resources.memory_hotplug = Some(MemoryHotplugConfig {
+            total_size_mib: 1024,
+            block_size_mib: 2,
+            slot_size_mib: 128,
+        });
+    }
+
     let vmm = build_microvm_for_boot(
         &InstanceInfo::default(),
         &resources,
@@ -111,23 +121,15 @@ pub fn create_vmm(
 }
 
 pub fn default_vmm(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, true, false)
+    create_vmm(kernel_image, false, true, false, false)
 }
 
 pub fn default_vmm_no_boot(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, false, false)
-}
-
-pub fn default_vmm_pci_no_boot(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, false, true)
+    create_vmm(kernel_image, false, false, false, false)
 }
 
 pub fn dirty_tracking_vmm(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, true, true, false)
-}
-
-pub fn default_vmm_pci(kernel_image: Option<&str>) -> (Arc<Mutex<Vmm>>, EventManager) {
-    create_vmm(kernel_image, false, true, false)
+    create_vmm(kernel_image, true, true, false, false)
 }
 
 #[allow(clippy::undocumented_unsafe_blocks)]
diff --git a/src/vmm/src/utils/mod.rs b/src/vmm/src/utils/mod.rs
index b5d5f94c7ff..1288abef0ba 100644
--- a/src/vmm/src/utils/mod.rs
+++ b/src/vmm/src/utils/mod.rs
@@ -59,6 +59,11 @@ pub const fn mib_to_bytes(mib: usize) -> usize {
     mib << MIB_TO_BYTES_SHIFT
 }
 
+/// Converts bytes to MiB, truncating any remainder
+pub const fn bytes_to_mib(bytes: usize) -> usize {
+    bytes >> MIB_TO_BYTES_SHIFT
+}
+
 /// Align address up to the aligment.
 pub const fn align_up(addr: u64, align: u64) -> u64 {
     debug_assert!(align != 0);
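Both helpers are pure shifts, so conversions are cheap but `bytes_to_mib` loses sub-MiB remainders. A quick sanity sketch (assuming `MIB_TO_BYTES_SHIFT` is 20, as the helper names imply):

```rust
const MIB_TO_BYTES_SHIFT: usize = 20; // assumed value

const fn mib_to_bytes(mib: usize) -> usize {
    mib << MIB_TO_BYTES_SHIFT
}

const fn bytes_to_mib(bytes: usize) -> usize {
    bytes >> MIB_TO_BYTES_SHIFT
}

fn main() {
    assert_eq!(mib_to_bytes(2), 2 * 1024 * 1024);
    // Truncating: anything short of a full MiB is dropped.
    assert_eq!(bytes_to_mib(2 * 1024 * 1024 + 123), 2);
}
```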
diff --git a/src/vmm/src/vmm_config/memory_hotplug.rs b/src/vmm/src/vmm_config/memory_hotplug.rs
new file mode 100644
index 00000000000..85cf45ee5e8
--- /dev/null
+++ b/src/vmm/src/vmm_config/memory_hotplug.rs
@@ -0,0 +1,227 @@
+// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use serde::{Deserialize, Serialize};
+
+use crate::devices::virtio::mem::{
+    VIRTIO_MEM_DEFAULT_BLOCK_SIZE_MIB, VIRTIO_MEM_DEFAULT_SLOT_SIZE_MIB,
+};
+
+/// Errors associated with memory hotplug configuration.
+#[derive(Debug, thiserror::Error, displaydoc::Display)]
+pub enum MemoryHotplugConfigError {
+    /// Block size must not be lower than {0} MiB
+    BlockSizeTooSmall(usize),
+    /// Block size must be a power of 2
+    BlockSizeNotPowerOfTwo,
+    /// Slot size must not be lower than {0} MiB
+    SlotSizeTooSmall(usize),
+    /// Slot size must be a multiple of block size ({0} MiB)
+    SlotSizeNotMultipleOfBlockSize(usize),
+    /// Total size must not be lower than slot size ({0} MiB)
+    TotalSizeTooSmall(usize),
+    /// Total size must be a multiple of slot size ({0} MiB)
+    TotalSizeNotMultipleOfSlotSize(usize),
+}
+
+fn default_block_size_mib() -> usize {
+    VIRTIO_MEM_DEFAULT_BLOCK_SIZE_MIB
+}
+
+fn default_slot_size_mib() -> usize {
+    VIRTIO_MEM_DEFAULT_SLOT_SIZE_MIB
+}
+
+/// Configuration for memory hotplug device.
+#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
+#[serde(deny_unknown_fields)]
+pub struct MemoryHotplugConfig {
+    /// Total memory size in MiB that can be hotplugged.
+    pub total_size_mib: usize,
+    /// Block size in MiB. A block is the smallest unit the guest can hot(un)plug.
+    #[serde(default = "default_block_size_mib")]
+    pub block_size_mib: usize,
+    /// Slot size in MiB. A slot is the smallest unit of memory the host can attach or detach.
+    #[serde(default = "default_slot_size_mib")]
+    pub slot_size_mib: usize,
+}
+
+impl MemoryHotplugConfig {
+    /// Validates the configuration.
+    pub fn validate(&self) -> Result<(), MemoryHotplugConfigError> {
+        let min_block_size_mib = VIRTIO_MEM_DEFAULT_BLOCK_SIZE_MIB;
+        if self.block_size_mib < min_block_size_mib {
+            return Err(MemoryHotplugConfigError::BlockSizeTooSmall(
+                min_block_size_mib,
+            ));
+        }
+        if !self.block_size_mib.is_power_of_two() {
+            return Err(MemoryHotplugConfigError::BlockSizeNotPowerOfTwo);
+        }
+
+        let min_slot_size_mib = VIRTIO_MEM_DEFAULT_SLOT_SIZE_MIB;
+        if self.slot_size_mib < min_slot_size_mib {
+            return Err(MemoryHotplugConfigError::SlotSizeTooSmall(
+                min_slot_size_mib,
+            ));
+        }
+        if !self.slot_size_mib.is_multiple_of(self.block_size_mib) {
+            return Err(MemoryHotplugConfigError::SlotSizeNotMultipleOfBlockSize(
+                self.block_size_mib,
+            ));
+        }
+
+        if self.total_size_mib < self.slot_size_mib {
+            return Err(MemoryHotplugConfigError::TotalSizeTooSmall(
+                self.slot_size_mib,
+            ));
+        }
+        if !self.total_size_mib.is_multiple_of(self.slot_size_mib) {
+            return Err(MemoryHotplugConfigError::TotalSizeNotMultipleOfSlotSize(
+                self.slot_size_mib,
+            ));
+        }
+
+        Ok(())
+    }
+}
+
+/// An update to the requested amount of hotpluggable memory.
+#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
+#[serde(deny_unknown_fields)]
+pub struct MemoryHotplugSizeUpdate {
+    /// Requested size in MiB to resize the hotpluggable memory to.
+    pub requested_size_mib: usize,
+}
+
+#[cfg(test)]
+mod tests {
+    use serde_json;
+
+    use super::*;
+
+    #[test]
+    fn test_valid_config() {
+        let config = MemoryHotplugConfig {
+            total_size_mib: 1024,
+            block_size_mib: 2,
+            slot_size_mib: 128,
+        };
+        config.validate().unwrap();
+    }
+
+    #[test]
+    fn test_block_size_too_small() {
+        let config = MemoryHotplugConfig {
+            total_size_mib: 1024,
+            block_size_mib: 1,
+            slot_size_mib: 128,
+        };
+        match config.validate() {
+            Err(MemoryHotplugConfigError::BlockSizeTooSmall(min)) => assert_eq!(min, 2),
+            _ => panic!("Expected BlockSizeTooSmall error"),
+        }
+    }
+
+    #[test]
+    fn test_block_size_not_power_of_two() {
+        let config = MemoryHotplugConfig {
+            total_size_mib: 1024,
+            block_size_mib: 3,
+            slot_size_mib: 128,
+        };
+        match config.validate() {
+            Err(MemoryHotplugConfigError::BlockSizeNotPowerOfTwo) => {}
+            _ => panic!("Expected BlockSizeNotPowerOfTwo error"),
+        }
+    }
+
+    #[test]
+    fn test_slot_size_too_small() {
+        let config = MemoryHotplugConfig {
+            total_size_mib: 1024,
+            block_size_mib: 2,
+            slot_size_mib: 1,
+        };
+        match config.validate() {
+            Err(MemoryHotplugConfigError::SlotSizeTooSmall(min)) => assert_eq!(min, 128),
+            _ => panic!("Expected SlotSizeTooSmall error"),
+        }
+    }
+
+    #[test]
+    fn test_slot_size_not_multiple_of_block_size() {
+        let config = MemoryHotplugConfig {
+            total_size_mib: 1024,
+            block_size_mib: 4,
+            slot_size_mib: 130,
+        };
+        match config.validate() {
+            Err(MemoryHotplugConfigError::SlotSizeNotMultipleOfBlockSize(block_size)) => {
+                assert_eq!(block_size, 4)
+            }
+            _ => panic!("Expected SlotSizeNotMultipleOfBlockSize error"),
+        }
+    }
+
+    #[test]
+    fn test_total_size_too_small() {
+        let config = MemoryHotplugConfig {
+            total_size_mib: 64,
+            block_size_mib: 2,
+            slot_size_mib: 128,
+        };
+        match config.validate() {
+            Err(MemoryHotplugConfigError::TotalSizeTooSmall(slot_size)) => {
+                assert_eq!(slot_size, 128)
+            }
+            _ => panic!("Expected TotalSizeTooSmall error"),
+        }
+    }
+
+    #[test]
+    fn test_total_size_not_multiple_of_slot_size() {
+        let config = MemoryHotplugConfig {
+            total_size_mib: 1000,
+            block_size_mib: 2,
+            slot_size_mib: 128,
+        };
+        match config.validate() {
+            Err(MemoryHotplugConfigError::TotalSizeNotMultipleOfSlotSize(slot_size)) => {
+                assert_eq!(slot_size, 128)
+            }
+            _ => panic!("Expected TotalSizeNotMultipleOfSlotSize error"),
+        }
+    }
+
+    #[test]
+    fn test_defaults() {
+        assert_eq!(default_block_size_mib(), 2);
+        assert_eq!(default_slot_size_mib(), 128);
+
+        let json = r#"{
+            "total_size_mib": 1024
+        }"#;
+        let deserialized: MemoryHotplugConfig = serde_json::from_str(json).unwrap();
+        assert_eq!(
+            deserialized,
+            MemoryHotplugConfig {
+                total_size_mib: 1024,
+                block_size_mib: 2,
+                slot_size_mib: 128,
+            }
+        );
+    }
+
+    #[test]
+    fn test_serde() {
+        let config = MemoryHotplugConfig {
+            total_size_mib: 1024,
+            block_size_mib: 4,
+            slot_size_mib: 256,
+        };
+        let json = serde_json::to_string(&config).unwrap();
+        let deserialized: MemoryHotplugConfig = serde_json::from_str(&json).unwrap();
+        assert_eq!(config, deserialized);
+    }
+}
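The serde defaults and the nesting rules compose end-to-end; a small sketch of driving the config (hypothetical values, reusing `MemoryHotplugConfig` from the file above and the `serde_json` dev-dependency):

```rust
// Sketch only: omitted fields fall back to the virtio-mem defaults (2 MiB
// blocks, 128 MiB slots), and validate() then checks block | slot | total nesting.
let config: MemoryHotplugConfig =
    serde_json::from_str(r#"{ "total_size_mib": 512 }"#).unwrap();
assert_eq!(config.block_size_mib, 2); // default
assert_eq!(config.slot_size_mib, 128); // default
config.validate().unwrap(); // 512 MiB is a whole number of 128 MiB slots
```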
diff --git a/src/vmm/src/vmm_config/mod.rs b/src/vmm/src/vmm_config/mod.rs
index 3f544751142..9a4c104ce3a 100644
--- a/src/vmm/src/vmm_config/mod.rs
+++ b/src/vmm/src/vmm_config/mod.rs
@@ -20,6 +20,8 @@ pub mod entropy;
 pub mod instance_info;
 /// Wrapper for configuring the memory and CPU of the microVM.
 pub mod machine_config;
+/// Wrapper for configuring memory hotplug.
+pub mod memory_hotplug;
 /// Wrapper for configuring the metrics.
 pub mod metrics;
 /// Wrapper for configuring the MMDS.
diff --git a/src/vmm/src/vstate/memory.rs b/src/vmm/src/vstate/memory.rs
index 2d3f6a1b724..de96630dcf5 100644
--- a/src/vmm/src/vstate/memory.rs
+++ b/src/vmm/src/vstate/memory.rs
@@ -8,8 +8,9 @@
 use std::fs::File;
 use std::io::SeekFrom;
 use std::ops::Deref;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
+use bitvec::vec::BitVec;
 use kvm_bindings::{KVM_MEM_LOG_DIRTY_PAGES, kvm_userspace_memory_region};
 use log::error;
 use serde::{Deserialize, Serialize};
@@ -23,9 +24,10 @@ pub use vm_memory::{
 use vm_memory::{GuestMemoryError, GuestMemoryRegionBytes, VolatileSlice, WriteVolatile};
 use vmm_sys_util::errno;
 
-use crate::DirtyBitmap;
 use crate::utils::{get_page_size, u64_to_usize};
 use crate::vmm_config::machine_config::HugePageConfig;
+use crate::vstate::vm::VmError;
+use crate::{DirtyBitmap, Vm};
 
 /// Type of GuestRegionMmap.
 pub type GuestRegionMmap = vm_memory::GuestRegionMmap<Option<AtomicBitmap>>;
@@ -53,6 +55,10 @@ pub enum MemoryError {
     OffsetTooLarge,
     /// Cannot retrieve snapshot file metadata: {0}
     FileMetadata(std::io::Error),
+    /// Memory region state is invalid: {0}
+    InvalidRegionState(&'static str),
+    /// Error protecting memory slot: {0}
+    Mprotect(std::io::Error),
 }
 
 /// Type of the guest region
@@ -60,55 +66,289 @@ pub enum GuestRegionType {
     /// Guest DRAM
     Dram,
+    /// Hotpluggable memory
+    Hotpluggable,
 }
 
-/// An extension to GuestMemoryRegion that stores the type of region, and the KVM slot
-/// number.
+/// An extension to GuestMemoryRegion that can be split into multiple KVM slots of
+/// the same slot_size, and stores the type of region, and the starting KVM slot number.
 #[derive(Debug)]
 pub struct GuestRegionMmapExt {
     /// the wrapped GuestRegionMmap
     pub inner: GuestRegionMmap,
     /// the type of region
     pub region_type: GuestRegionType,
-    /// the KVM slot number assigned to this region
-    pub slot: u32,
+    /// the starting KVM slot number assigned to this region
+    pub slot_from: u32,
+    /// the size of the slots of this region
+    pub slot_size: usize,
+    /// a bitvec indicating whether slot `i` is plugged into KVM (1) or not (0)
+    pub plugged: Mutex<BitVec>,
+}
+
+/// A guest memory slot, which is a slice of a guest memory region
+#[derive(Debug)]
+pub struct GuestMemorySlot<'a> {
+    /// KVM memory slot number
+    pub(crate) slot: u32,
+    /// Start guest address of the slot
+    pub(crate) guest_addr: GuestAddress,
+    /// Corresponding slice in host memory
+    pub(crate) slice: VolatileSlice<'a, BS<'a, Option<AtomicBitmap>>>,
+}
+
+impl From<&GuestMemorySlot<'_>> for kvm_userspace_memory_region {
+    fn from(mem_slot: &GuestMemorySlot) -> Self {
+        let flags = if mem_slot.slice.bitmap().is_some() {
+            KVM_MEM_LOG_DIRTY_PAGES
+        } else {
+            0
+        };
+        kvm_userspace_memory_region {
+            flags,
+            slot: mem_slot.slot,
+            guest_phys_addr: mem_slot.guest_addr.raw_value(),
+            memory_size: mem_slot.slice.len() as u64,
+            userspace_addr: mem_slot.slice.ptr_guard().as_ptr() as u64,
+        }
+    }
+}
+
+impl<'a> GuestMemorySlot<'a> {
+    /// Dumps the dirty pages in this slot onto the writer
+    pub(crate) fn dump_dirty<T: WriteVolatile + std::io::Seek>(
+        &self,
+        writer: &mut T,
+        kvm_bitmap: &[u64],
+        page_size: usize,
+    ) -> Result<(), GuestMemoryError> {
+        let firecracker_bitmap = self.slice.bitmap();
+        let mut write_size = 0;
+        let mut skip_size = 0;
+        let mut dirty_batch_start = 0;
+
+        for (i, v) in kvm_bitmap.iter().enumerate() {
+            for j in 0..64 {
+                let is_kvm_page_dirty = ((v >> j) & 1u64) != 0u64;
+                let page_offset = ((i * 64) + j) * page_size;
+                let is_firecracker_page_dirty = firecracker_bitmap.dirty_at(page_offset);
+
+                if is_kvm_page_dirty || is_firecracker_page_dirty {
+                    // We are at the start of a new batch of dirty pages.
+                    if skip_size > 0 {
+                        // Seek forward over the unmodified pages.
+                        writer
+                            .seek(SeekFrom::Current(skip_size.try_into().unwrap()))
+                            .unwrap();
+                        dirty_batch_start = page_offset;
+                        skip_size = 0;
+                    }
+                    write_size += page_size;
+                } else {
+                    // We are at the end of a batch of dirty pages.
+                    if write_size > 0 {
+                        // Dump the dirty pages.
+                        let slice = &self.slice.subslice(dirty_batch_start, write_size)?;
+                        writer.write_all_volatile(slice)?;
+                        write_size = 0;
+                    }
+                    skip_size += page_size;
+                }
+            }
+        }
+
+        if write_size > 0 {
+            writer.write_all_volatile(&self.slice.subslice(dirty_batch_start, write_size)?)?;
+        }
+
+        Ok(())
+    }
+
+    /// Makes the slot host memory PROT_NONE (true) or PROT_READ|PROT_WRITE (false)
+    pub(crate) fn protect(&self, protected: bool) -> Result<(), MemoryError> {
+        let prot = if protected {
+            libc::PROT_NONE
+        } else {
+            libc::PROT_READ | libc::PROT_WRITE
+        };
+        // SAFETY: Parameters refer to an existing host memory region
+        let ret = unsafe {
+            libc::mprotect(
+                self.slice.ptr_guard_mut().as_ptr().cast(),
+                self.slice.len(),
+                prot,
+            )
+        };
+        if ret != 0 {
+            Err(MemoryError::Mprotect(std::io::Error::last_os_error()))
+        } else {
+            Ok(())
+        }
+    }
+}
+
+fn addr_in_range(addr: GuestAddress, start: GuestAddress, len: usize) -> bool {
+    if let Some(end) = start.checked_add(len as u64) {
+        addr >= start && addr < end
+    } else {
+        false
+    }
 }
 
 impl GuestRegionMmapExt {
+    /// Creates a DRAM region which only contains a single plugged slot
     pub(crate) fn dram_from_mmap_region(region: GuestRegionMmap, slot: u32) -> Self {
+        let slot_size = u64_to_usize(region.len());
         GuestRegionMmapExt {
             inner: region,
             region_type: GuestRegionType::Dram,
-            slot,
+            slot_from: slot,
+            slot_size,
+            plugged: Mutex::new(BitVec::repeat(true, 1)),
+        }
+    }
+
+    /// Creates a hotpluggable region which can contain multiple slots and is initially unplugged
+    pub(crate) fn hotpluggable_from_mmap_region(
+        region: GuestRegionMmap,
+        slot_from: u32,
+        slot_size: usize,
+    ) -> Self {
+        let slot_cnt = (u64_to_usize(region.len())) / slot_size;
+
+        GuestRegionMmapExt {
+            inner: region,
+            region_type: GuestRegionType::Hotpluggable,
+            slot_from,
+            slot_size,
+            plugged: Mutex::new(BitVec::repeat(false, slot_cnt)),
         }
     }
 
     pub(crate) fn from_state(
         region: GuestRegionMmap,
         state: &GuestMemoryRegionState,
-        slot: u32,
+        slot_from: u32,
     ) -> Result<Self, MemoryError> {
+        let slot_cnt = state.plugged.len();
+        let slot_size = u64_to_usize(region.len()).checked_div(slot_cnt).ok_or(
+            MemoryError::InvalidRegionState("memory region should be aligned to the slot size"),
+        )?;
+
+        // validate the region state to avoid spurious crashes when resuming from an invalid state
+        if state.region_type == GuestRegionType::Dram {
+            if slot_cnt != 1 {
+                return Err(MemoryError::InvalidRegionState(
+                    "DRAM region should contain only one slot",
+                ));
+            }
+            if !state.plugged[0] {
+                return Err(MemoryError::InvalidRegionState(
+                    "DRAM region should be plugged",
+                ));
+            }
+        }
+
         Ok(GuestRegionMmapExt {
             inner: region,
+            slot_size,
             region_type: state.region_type,
+            slot_from,
+            plugged: Mutex::new(BitVec::from_iter(state.plugged.iter())),
+        })
+    }
+
+    pub(crate) fn slot_cnt(&self) -> u32 {
+        u32::try_from(u64_to_usize(self.len()) / self.slot_size).unwrap()
+    }
+
+    pub(crate) fn mem_slot(&self, slot: u32) -> GuestMemorySlot<'_> {
+        assert!(slot >= self.slot_from && slot < self.slot_from + self.slot_cnt());
+
+        let offset = ((slot - self.slot_from) as u64) * (self.slot_size as u64);
+
+        GuestMemorySlot {
+            slot,
+            guest_addr: self.start_addr().unchecked_add(offset),
+            slice: self
+                .inner
+                .get_slice(MemoryRegionAddress(offset), self.slot_size)
+                .expect("slot range should be valid"),
+        }
+    }
+
+    /// Returns a snapshot of the slots and their state at the time of calling
+    ///
+    /// Note: to avoid TOCTOU races use only within VMM thread.
+    pub(crate) fn slots(&self) -> impl Iterator<Item = (GuestMemorySlot<'_>, bool)> {
+        self.plugged
+            .lock()
+            .unwrap()
+            .iter()
+            .enumerate()
+            .map(|(i, b)| {
+                (
+                    self.mem_slot(self.slot_from + u32::try_from(i).unwrap()),
+                    *b,
+                )
+            })
+            .collect::<Vec<_>>()
+            .into_iter()
+    }
+
+    /// Returns a snapshot of the plugged slots at the time of calling
+    ///
+    /// Note: to avoid TOCTOU races use only within VMM thread.
+    pub(crate) fn plugged_slots(&self) -> impl Iterator<Item = GuestMemorySlot<'_>> {
+        self.slots()
+            .filter(|(_, plugged)| *plugged)
+            .map(|(slot, _)| slot)
+    }
+
+    pub(crate) fn slots_intersecting_range(
+        &self,
+        from: GuestAddress,
+        len: usize,
+    ) -> impl Iterator<Item = GuestMemorySlot<'_>> {
+        self.slots().map(|(slot, _)| slot).filter(move |slot| {
+            if let Some(slot_end) = slot.guest_addr.checked_add(slot.slice.len() as u64) {
+                addr_in_range(slot.guest_addr, from, len) || addr_in_range(slot_end, from, len)
+            } else {
+                false
+            }
+        })
     }
 
-    pub(crate) fn kvm_userspace_memory_region(&self) -> kvm_userspace_memory_region {
-        let flags = if self.inner.bitmap().is_some() {
-            KVM_MEM_LOG_DIRTY_PAGES
-        } else {
-            0
-        };
+    /// (un)plug a slot from a hotpluggable memory region
+    pub(crate) fn update_slot(
+        &self,
+        vm: &Vm,
+        mem_slot: &GuestMemorySlot<'_>,
+        plug: bool,
+    ) -> Result<(), VmError> {
+        // This function can only be called on hotpluggable regions!
+        assert!(self.region_type == GuestRegionType::Hotpluggable);
+
+        let mut bitmap_guard = self.plugged.lock().unwrap();
+        let prev = bitmap_guard.replace((mem_slot.slot - self.slot_from) as usize, plug);
+        // do not do anything if the state is what we're trying to set
+        if prev == plug {
+            return Ok(());
+        }
 
-        kvm_userspace_memory_region {
-            flags,
-            slot: self.slot,
-            guest_phys_addr: self.inner.start_addr().raw_value(),
-            memory_size: self.inner.len(),
-            userspace_addr: self.inner.as_ptr() as u64,
+        let mut kvm_region = kvm_userspace_memory_region::from(mem_slot);
+        if plug {
+            // make it accessible _before_ adding it to KVM
+            mem_slot.protect(false)?;
+            vm.set_user_memory_region(kvm_region)?;
+        } else {
+            // to remove it we need to pass a size of zero
+            kvm_region.memory_size = 0;
+            vm.set_user_memory_region(kvm_region)?;
+            // make it protected _after_ removing it from KVM
+            mem_slot.protect(true)?;
         }
+        Ok(())
     }
 
     pub(crate) fn discard_range(
@@ -329,7 +569,7 @@ where
     fn mark_dirty(&self, addr: GuestAddress, len: usize);
 
     /// Dumps all contents of GuestMemoryMmap to a writer.
-    fn dump<T: WriteVolatile>(&self, writer: &mut T) -> Result<(), MemoryError>;
+    fn dump<T: WriteVolatile + std::io::Seek>(&self, writer: &mut T) -> Result<(), MemoryError>;
 
     /// Dumps all pages of GuestMemoryMmap present in `dirty_bitmap` to a writer.
     fn dump_dirty<T: WriteVolatile + std::io::Seek>(
@@ -369,6 +609,8 @@ pub struct GuestMemoryRegionState {
     pub size: usize,
     /// Region type
     pub region_type: GuestRegionType,
+    /// Plugged/unplugged status of each slot
+    pub plugged: Vec<bool>,
 }
 
 /// Describes guest memory regions and their snapshot file mappings.
@@ -397,6 +639,7 @@ impl GuestMemoryExtension for GuestMemoryMmap {
                 base_address: region.start_addr().0,
                 size: u64_to_usize(region.len()),
                 region_type: region.region_type,
+                plugged: region.plugged.lock().unwrap().iter().by_vals().collect(),
             });
         });
         guest_memory_state
@@ -411,9 +654,18 @@ impl GuestMemoryExtension for GuestMemoryMmap {
     }
 
     /// Dumps all contents of GuestMemoryMmap to a writer.
-    fn dump<T: WriteVolatile>(&self, writer: &mut T) -> Result<(), MemoryError> {
+    fn dump<T: WriteVolatile + std::io::Seek>(&self, writer: &mut T) -> Result<(), MemoryError> {
         self.iter()
-            .try_for_each(|region| Ok(writer.write_all_volatile(&region.as_volatile_slice()?)?))
+            .flat_map(|region| region.slots())
+            .try_for_each(|(mem_slot, plugged)| {
+                if !plugged {
+                    let ilen = i64::try_from(mem_slot.slice.len()).unwrap();
+                    writer.seek(SeekFrom::Current(ilen)).unwrap();
+                } else {
+                    writer.write_all_volatile(&mem_slot.slice)?;
+                }
+                Ok(())
+            })
             .map_err(MemoryError::WriteMemory)
     }
 
@@ -423,52 +675,21 @@ impl GuestMemoryExtension for GuestMemoryMmap {
         writer: &mut T,
         dirty_bitmap: &DirtyBitmap,
     ) -> Result<(), MemoryError> {
-        let mut writer_offset = 0;
         let page_size = get_page_size().map_err(MemoryError::PageSize)?;
 
-        let write_result = self.iter().try_for_each(|region| {
-            let kvm_bitmap = dirty_bitmap.get(&region.slot).unwrap();
-            let firecracker_bitmap = region.bitmap();
-            let mut write_size = 0;
-            let mut dirty_batch_start: u64 = 0;
-
-            for (i, v) in kvm_bitmap.iter().enumerate() {
-                for j in 0..64 {
-                    let is_kvm_page_dirty = ((v >> j) & 1u64) != 0u64;
-                    let page_offset = ((i * 64) + j) * page_size;
-                    let is_firecracker_page_dirty = firecracker_bitmap.dirty_at(page_offset);
-
-                    if is_kvm_page_dirty || is_firecracker_page_dirty {
-                        // We are at the start of a new batch of dirty pages.
-                        if write_size == 0 {
-                            // Seek forward over the unmodified pages.
-                            writer
-                                .seek(SeekFrom::Start(writer_offset + page_offset as u64))
-                                .unwrap();
-                            dirty_batch_start = page_offset as u64;
-                        }
-                        write_size += page_size;
-                    } else if write_size > 0 {
-                        // We are at the end of a batch of dirty pages.
-                        writer.write_all_volatile(
-                            &region
-                                .get_slice(MemoryRegionAddress(dirty_batch_start), write_size)?,
-                        )?;
-
-                        write_size = 0;
+        let write_result =
+            self.iter()
+                .flat_map(|region| region.slots())
+                .try_for_each(|(mem_slot, plugged)| {
+                    if !plugged {
+                        let ilen = i64::try_from(mem_slot.slice.len()).unwrap();
+                        writer.seek(SeekFrom::Current(ilen)).unwrap();
+                    } else {
+                        let kvm_bitmap = dirty_bitmap.get(&mem_slot.slot).unwrap();
+                        mem_slot.dump_dirty(writer, kvm_bitmap, page_size)?;
                     }
-                }
-            }
-
-            if write_size > 0 {
-                writer.write_all_volatile(
-                    &region.get_slice(MemoryRegionAddress(dirty_batch_start), write_size)?,
-                )?;
-            }
-            writer_offset += region.len();
-
-            Ok(())
-        });
+                    Ok(())
+                });
 
         if write_result.is_err() {
             self.store_dirty_bitmap(dirty_bitmap, page_size);
@@ -490,22 +711,24 @@ impl GuestMemoryExtension for GuestMemoryMmap {
 
     /// Stores the dirty bitmap inside into the internal bitmap
     fn store_dirty_bitmap(&self, dirty_bitmap: &DirtyBitmap, page_size: usize) {
-        self.iter().for_each(|region| {
-            let kvm_bitmap = dirty_bitmap.get(&region.slot).unwrap();
-            let firecracker_bitmap = region.bitmap();
+        self.iter()
+            .flat_map(|region| region.plugged_slots())
+            .for_each(|mem_slot| {
+                let kvm_bitmap = dirty_bitmap.get(&mem_slot.slot).unwrap();
+                let firecracker_bitmap = mem_slot.slice.bitmap();
 
-            for (i, v) in kvm_bitmap.iter().enumerate() {
-                for j in 0..64 {
-                    let is_kvm_page_dirty = ((v >> j) & 1u64) != 0u64;
+                for (i, v) in kvm_bitmap.iter().enumerate() {
+                    for j in 0..64 {
+                        let is_kvm_page_dirty = ((v >> j) & 1u64) != 0u64;
 
-                    if is_kvm_page_dirty {
-                        let page_offset = ((i * 64) + j) * page_size;
+                        if is_kvm_page_dirty {
+                            let page_offset = ((i * 64) + j) * page_size;
 
-                        firecracker_bitmap.mark_dirty(page_offset, 1)
+                            firecracker_bitmap.mark_dirty(page_offset, 1)
+                        }
                     }
                 }
-            }
-        });
+            });
     }
 
     fn try_for_each_region_in_range(
@@ -788,11 +1011,13 @@ mod tests {
                     base_address: 0,
                     size: page_size,
                     region_type: GuestRegionType::Dram,
+                    plugged: vec![true],
                 },
                 GuestMemoryRegionState {
                     base_address: page_size as u64 * 2,
                     size: page_size,
                    region_type: GuestRegionType::Dram,
+                    plugged: vec![true],
                },
            ],
        };
@@ -815,11 +1040,13 @@ mod tests {
                     base_address: 0,
                     size: page_size * 3,
                     region_type: GuestRegionType::Dram,
+                    plugged: vec![true],
                 },
                 GuestMemoryRegionState {
                     base_address: page_size as u64 * 4,
                     size: page_size * 3,
                     region_type: GuestRegionType::Dram,
+                    plugged: vec![true],
                 },
             ],
         };
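The per-slot `dump_dirty` above replaces the old per-region loop, but the batching idea is unchanged: consecutive dirty pages become one volatile write, and clean runs become a single relative seek. A standalone sketch of the run detection (hypothetical 4-page slot; only `std` is used):

```rust
fn main() {
    // Pages 1 and 2 dirty, pages 0 and 3 clean.
    let dirty = [false, true, true, false];
    let page_size = 4096usize;
    let mut runs = Vec::new(); // (offset, len) of each dirty batch
    let mut start = None;
    for (i, d) in dirty.iter().enumerate() {
        match (*d, start) {
            (true, None) => start = Some(i * page_size),
            (false, Some(s)) => {
                runs.push((s, i * page_size - s));
                start = None;
            }
            _ => {}
        }
    }
    if let Some(s) = start {
        runs.push((s, dirty.len() * page_size - s));
    }
    // One write of 8 KiB at offset 4 KiB; everything else is seeked over.
    assert_eq!(runs, vec![(4096, 8192)]);
}
```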
diff --git a/src/vmm/src/vstate/vm.rs b/src/vmm/src/vstate/vm.rs
index 3cc2319b360..83e899eff1d 100644
--- a/src/vmm/src/vstate/vm.rs
+++ b/src/vmm/src/vstate/vm.rs
@@ -16,7 +16,7 @@ use std::sync::{Arc, Mutex, MutexGuard};
 use kvm_bindings::KVM_IRQCHIP_IOAPIC;
 use kvm_bindings::{
     KVM_IRQ_ROUTING_IRQCHIP, KVM_IRQ_ROUTING_MSI, KVM_MSI_VALID_DEVID, KvmIrqRouting,
-    kvm_irq_routing_entry,
+    kvm_irq_routing_entry, kvm_userspace_memory_region,
 };
 use kvm_ioctls::VmFd;
 use log::debug;
@@ -29,7 +29,6 @@ use crate::arch::{GSI_MSI_END, host_page_size};
 use crate::logger::info;
 use crate::pci::{DeviceRelocation, DeviceRelocationError, PciDevice};
 use crate::persist::CreateSnapshotError;
-use crate::utils::u64_to_usize;
 use crate::vmm_config::snapshot::SnapshotType;
 use crate::vstate::bus::Bus;
 use crate::vstate::interrupts::{InterruptError, MsixVector, MsixVectorConfig, MsixVectorGroup};
@@ -165,9 +164,12 @@ impl Vm {
         Ok((vcpus, exit_evt))
     }
 
-    /// Obtain the next free kvm slot id
-    pub fn next_kvm_slot(&self) -> Option<u32> {
-        let next = self.common.next_kvm_slot.fetch_add(1, Ordering::Relaxed);
+    /// Reserves the next `slot_cnt` contiguous kvm slot ids and returns the first one
+    pub fn next_kvm_slot(&self, slot_cnt: u32) -> Option<u32> {
+        let next = self
+            .common
+            .next_kvm_slot
+            .fetch_add(slot_cnt, Ordering::Relaxed);
         if self.common.max_memslots <= next {
             None
         } else {
@@ -175,18 +177,32 @@ impl Vm {
         }
     }
 
+    pub(crate) fn set_user_memory_region(
+        &self,
+        region: kvm_userspace_memory_region,
+    ) -> Result<(), VmError> {
+        // SAFETY: Safe because the fd is a valid KVM file descriptor.
+        unsafe {
+            self.fd()
+                .set_user_memory_region(region)
+                .map_err(VmError::SetUserMemoryRegion)
+        }
+    }
+
     fn register_memory_region(&mut self, region: Arc<GuestRegionMmapExt>) -> Result<(), VmError> {
         let new_guest_memory = self
             .common
             .guest_memory
             .insert_region(Arc::clone(&region))?;
 
-        // SAFETY: Safe because the fd is a valid KVM file descriptor.
-        unsafe {
-            self.fd()
-                .set_user_memory_region(region.kvm_userspace_memory_region())
-                .map_err(VmError::SetUserMemoryRegion)?;
-        }
+        region
+            .slots()
+            .try_for_each(|(ref slot, plugged)| match plugged {
+                // if the slot is plugged, add it to kvm user memory regions
+                true => self.set_user_memory_region(slot.into()),
+                // if the slot is not plugged, protect accesses to it
+                false => slot.protect(true).map_err(VmError::MemoryError),
+            })?;
 
         self.common.guest_memory = new_guest_memory;
 
@@ -200,7 +216,7 @@ impl Vm {
     ) -> Result<(), VmError> {
         for region in regions {
             let next_slot = self
-                .next_kvm_slot()
+                .next_kvm_slot(1)
                 .ok_or(VmError::NotEnoughMemorySlots(self.common.max_memslots))?;
 
             let arcd_region =
@@ -212,6 +228,27 @@ impl Vm {
         Ok(())
     }
 
+    /// Register a new hotpluggable region to this [`Vm`].
+    pub fn register_hotpluggable_memory_region(
+        &mut self,
+        region: GuestRegionMmap,
+        slot_size: usize,
+    ) -> Result<(), VmError> {
+        // caller should ensure the slot size divides the region length.
+        assert!(region.len().is_multiple_of(slot_size as u64));
+        let slot_cnt = (region.len() / (slot_size as u64))
+            .try_into()
+            .map_err(|_| VmError::NotEnoughMemorySlots(self.common.max_memslots))?;
+        let slot_from = self
+            .next_kvm_slot(slot_cnt)
+            .ok_or(VmError::NotEnoughMemorySlots(self.common.max_memslots))?;
+        let arcd_region = Arc::new(GuestRegionMmapExt::hotpluggable_from_mmap_region(
+            region, slot_from, slot_size,
+        ));
+
+        self.register_memory_region(arcd_region)
+    }
+
     /// Register a list of new memory regions to this [`Vm`].
     ///
     /// Note: regions and state.regions need to be in the same order.
@@ -221,8 +258,14 @@ impl Vm {
         state: &GuestMemoryState,
     ) -> Result<(), VmError> {
         for (region, state) in regions.into_iter().zip(state.regions.iter()) {
+            let slot_cnt = state
+                .plugged
+                .len()
+                .try_into()
+                .map_err(|_| VmError::NotEnoughMemorySlots(self.common.max_memslots))?;
+
             let next_slot = self
-                .next_kvm_slot()
+                .next_kvm_slot(slot_cnt)
                 .ok_or(VmError::NotEnoughMemorySlots(self.common.max_memslots))?;
 
             let arcd_region = Arc::new(GuestRegionMmapExt::from_state(region, state, next_slot)?);
@@ -253,26 +296,31 @@ impl Vm {
 
     /// Resets the KVM dirty bitmap for each of the guest's memory regions.
     pub fn reset_dirty_bitmap(&self) {
-        self.guest_memory().iter().for_each(|region| {
-            let _ = self
-                .fd()
-                .get_dirty_log(region.slot, u64_to_usize(region.len()));
-        });
+        self.guest_memory()
+            .iter()
+            .flat_map(|region| region.plugged_slots())
+            .for_each(|mem_slot| {
+                let _ = self.fd().get_dirty_log(mem_slot.slot, mem_slot.slice.len());
+            });
     }
 
     /// Retrieves the KVM dirty bitmap for each of the guest's memory regions.
     pub fn get_dirty_bitmap(&self) -> Result<DirtyBitmap, VmError> {
         self.guest_memory()
             .iter()
-            .map(|region| {
-                let bitmap = match region.bitmap() {
+            .flat_map(|region| region.plugged_slots())
+            .map(|mem_slot| {
+                let bitmap = match mem_slot.slice.bitmap() {
                     Some(_) => self
                         .fd()
-                        .get_dirty_log(region.slot, u64_to_usize(region.len()))
+                        .get_dirty_log(mem_slot.slot, mem_slot.slice.len())
                         .map_err(VmError::GetDirtyLog)?,
-                    None => mincore_bitmap(&region.inner)?,
+                    None => mincore_bitmap(
+                        mem_slot.slice.ptr_guard_mut().as_ptr(),
+                        mem_slot.slice.len(),
+                    )?,
                 };
-                Ok((region.slot, bitmap))
+                Ok((mem_slot.slot, bitmap))
             })
             .collect()
     }
@@ -455,7 +503,7 @@ impl Vm {
 
 /// Use `mincore(2)` to overapproximate the dirty bitmap for the given memslot. To be used
 /// if a diff snapshot is requested, but dirty page tracking wasn't enabled.
-fn mincore_bitmap(region: &GuestRegionMmap) -> Result<Vec<u64>, VmError> {
+fn mincore_bitmap(addr: *mut u8, len: usize) -> Result<Vec<u64>, VmError> {
     // TODO: Once Host 5.10 goes out of support, we can make this more robust and work on
     // swap-enabled systems, by doing mlock2(MLOCK_ONFAULT)/munlock() in this function (to
     // force swapped-out pages to get paged in, so that mincore will consider them incore).
@@ -466,8 +514,8 @@ fn mincore_bitmap(region: &GuestRegionMmap) -> Result<Vec<u64>, VmError> {
     // is a hugetlbfs VMA (e.g. to report a single hugepage as "present", mincore will
     // give us 512 4k markers with the lowest bit set).
     let page_size = host_page_size();
-    let mut mincore_bitmap = vec![0u8; u64_to_usize(region.len()) / page_size];
-    let mut bitmap = vec![0u64; (u64_to_usize(region.len()) / page_size).div_ceil(64)];
+    let mut mincore_bitmap = vec![0u8; len / page_size];
+    let mut bitmap = vec![0u64; (len / page_size).div_ceil(64)];
 
     // SAFETY: The safety invariants of GuestRegionMmap ensure that region.as_ptr() is a valid
     // userspace mapping of size region.len() bytes. The bitmap has exactly one byte for each
@@ -475,13 +523,7 @@ fn mincore_bitmap(region: &GuestRegionMmap) -> Result<Vec<u64>, VmError> {
     // KVM_MEM_LOG_DIRTY_PAGES, but rather it uses 8 bits per page (e.g. 1 byte), setting the
     // least significant bit to 1 if the page corresponding to a byte is in core (available in
     // the page cache and resolvable via just a minor page fault).
-    let r = unsafe {
-        libc::mincore(
-            region.as_ptr().cast::<libc::c_void>(),
-            u64_to_usize(region.len()),
-            mincore_bitmap.as_mut_ptr(),
-        )
-    };
+    let r = unsafe { libc::mincore(addr.cast(), len, mincore_bitmap.as_mut_ptr()) };
 
     if r != 0 {
         return Err(VmError::Mincore(vmm_sys_util::errno::Error::last()));
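One detail worth highlighting from the vm.rs changes: slot ids for a multi-slot region are reserved with a single `fetch_add`, so concurrent callers can never be handed overlapping ranges. A minimal sketch of that reservation scheme (assumed `max_memslots` of 4; not the real `Vm` type):

```rust
use std::sync::atomic::{AtomicU32, Ordering};

fn main() {
    let next_kvm_slot = AtomicU32::new(0);
    let max_memslots = 4u32;
    let reserve = |slot_cnt: u32| -> Option<u32> {
        let next = next_kvm_slot.fetch_add(slot_cnt, Ordering::Relaxed);
        (next < max_memslots).then_some(next)
    };
    assert_eq!(reserve(1), Some(0)); // DRAM region: a single slot
    assert_eq!(reserve(2), Some(1)); // hotpluggable region: slots 1 and 2
    assert_eq!(reserve(1), Some(3)); // next free slot after the pair
}
```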
diff --git a/src/vmm/tests/integration_tests.rs b/src/vmm/tests/integration_tests.rs
index 4abbedc4530..6a5e6a08a14 100644
--- a/src/vmm/tests/integration_tests.rs
+++ b/src/vmm/tests/integration_tests.rs
@@ -18,9 +18,7 @@ use vmm::rpc_interface::{
 use vmm::seccomp::get_empty_filters;
 use vmm::snapshot::Snapshot;
 use vmm::test_utils::mock_resources::{MockVmResources, NOISY_KERNEL_IMAGE};
-use vmm::test_utils::{
-    create_vmm, default_vmm, default_vmm_no_boot, default_vmm_pci, default_vmm_pci_no_boot,
-};
+use vmm::test_utils::{create_vmm, default_vmm, default_vmm_no_boot};
 use vmm::vmm_config::balloon::BalloonDeviceConfig;
 use vmm::vmm_config::boot_source::BootSourceConfig;
 use vmm::vmm_config::drive::BlockDeviceConfig;
@@ -66,13 +64,12 @@ fn test_build_and_boot_microvm() {
         assert_eq!(format!("{:?}", vmm_ret.err()), "Some(MissingKernelConfig)");
     }
 
-    // Success case.
-    let (vmm, evmgr) = default_vmm(None);
-    check_booted_microvm(vmm, evmgr);
-
-    // microVM with PCI
-    let (vmm, evmgr) = default_vmm_pci(None);
-    check_booted_microvm(vmm, evmgr);
+    for pci_enabled in [false, true] {
+        for memory_hotplug in [false, true] {
+            let (vmm, evmgr) = create_vmm(None, false, true, pci_enabled, memory_hotplug);
+            check_booted_microvm(vmm, evmgr);
+        }
+    }
 }
 
 #[allow(unused_mut, unused_variables)]
 fn check_build_microvm(vmm: Arc<Mutex<Vmm>>, mut evmgr: EventManager) {
@@ -96,10 +93,12 @@ fn check_build_microvm(vmm: Arc<Mutex<Vmm>>, mut evmgr: EventManager) {
 
 #[test]
 fn test_build_microvm() {
-    let (vmm, evtmgr) = default_vmm_no_boot(None);
-    check_build_microvm(vmm, evtmgr);
-    let (vmm, evtmgr) = default_vmm_pci_no_boot(None);
-    check_build_microvm(vmm, evtmgr);
+    for pci_enabled in [false, true] {
+        for memory_hotplug in [false, true] {
+            let (vmm, evmgr) = create_vmm(None, false, false, pci_enabled, memory_hotplug);
+            check_build_microvm(vmm, evmgr);
+        }
+    }
 }
 
 fn pause_resume_microvm(vmm: Arc<Mutex<Vmm>>) {
@@ -118,13 +117,14 @@ fn pause_resume_microvm(vmm: Arc<Mutex<Vmm>>) {
 
 #[test]
 fn test_pause_resume_microvm() {
-    // Tests that pausing and resuming a microVM work as expected.
-    let (vmm, _) = default_vmm(None);
+    for pci_enabled in [false, true] {
+        for memory_hotplug in [false, true] {
+            // Tests that pausing and resuming a microVM work as expected.
+            let (vmm, _) = create_vmm(None, false, true, pci_enabled, memory_hotplug);
 
-    pause_resume_microvm(vmm);
-
-    let (vmm, _) = default_vmm_pci(None);
-    pause_resume_microvm(vmm);
+            pause_resume_microvm(vmm);
+        }
+    }
 }
 
 #[test]
@@ -195,11 +195,21 @@ fn test_disallow_dump_cpu_config_without_pausing() {
     vmm.lock().unwrap().stop(FcExitCode::Ok);
 }
 
-fn verify_create_snapshot(is_diff: bool, pci_enabled: bool) -> (TempFile, TempFile) {
+fn verify_create_snapshot(
+    is_diff: bool,
+    pci_enabled: bool,
+    memory_hotplug: bool,
+) -> (TempFile, TempFile) {
     let snapshot_file = TempFile::new().unwrap();
     let memory_file = TempFile::new().unwrap();
 
-    let (vmm, _) = create_vmm(Some(NOISY_KERNEL_IMAGE), is_diff, true, pci_enabled);
+    let (vmm, _) = create_vmm(
+        Some(NOISY_KERNEL_IMAGE),
+        is_diff,
+        true,
+        pci_enabled,
+        memory_hotplug,
+    );
 
     let resources = VmResources {
         machine_config: MachineConfig {
             mem_size_mib: 1,
@@ -303,14 +313,19 @@ fn verify_load_snapshot(snapshot_file: TempFile, memory_file: TempFile) {
 
 #[test]
 fn test_create_and_load_snapshot() {
-    for (diff_snap, pci_enabled) in [(false, false), (false, true), (true, false), (true, true)] {
-        // Create snapshot.
-        let (snapshot_file, memory_file) = verify_create_snapshot(diff_snap, pci_enabled);
-        // Create a new microVm from snapshot. This only tests code-level logic; it verifies
-        // that a microVM can be built with no errors from given snapshot.
-        // It does _not_ verify that the guest is actually restored properly. We're using
-        // python integration tests for that.
-        verify_load_snapshot(snapshot_file, memory_file);
+    for diff_snap in [false, true] {
+        for pci_enabled in [false, true] {
+            for memory_hotplug in [false, true] {
+                // Create snapshot.
+                let (snapshot_file, memory_file) =
+                    verify_create_snapshot(diff_snap, pci_enabled, memory_hotplug);
+                // Create a new microVm from snapshot. This only tests code-level logic; it verifies
+                // that a microVM can be built with no errors from given snapshot.
+                // It does _not_ verify that the guest is actually restored properly. We're using
+                // python integration tests for that.
+ verify_load_snapshot(snapshot_file, memory_file); + } + } } } @@ -338,7 +353,7 @@ fn check_snapshot(mut microvm_state: MicrovmState) { fn get_microvm_state_from_snapshot(pci_enabled: bool) -> MicrovmState { // Create a diff snapshot - let (snapshot_file, _) = verify_create_snapshot(true, pci_enabled); + let (snapshot_file, _) = verify_create_snapshot(true, pci_enabled, false); // Deserialize the microVM state. snapshot_file.as_file().seek(SeekFrom::Start(0)).unwrap(); @@ -346,7 +361,7 @@ fn get_microvm_state_from_snapshot(pci_enabled: bool) -> MicrovmState { } fn verify_load_snap_disallowed_after_boot_resources(res: VmmAction, res_name: &str) { - let (snapshot_file, memory_file) = verify_create_snapshot(false, false); + let (snapshot_file, memory_file) = verify_create_snapshot(false, false, false); let mut event_manager = EventManager::new().unwrap(); let empty_seccomp_filters = get_empty_filters(); diff --git a/tests/framework/guest_stats.py b/tests/framework/guest_stats.py new file mode 100644 index 00000000000..468d7167c44 --- /dev/null +++ b/tests/framework/guest_stats.py @@ -0,0 +1,79 @@ +# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Classes for querying guest stats inside microVMs. +""" + + +class ByteUnit: + """Represents a byte unit that can be converted to other units.""" + + value_bytes: int + + def __init__(self, value_bytes: int): + self.value_bytes = value_bytes + + @classmethod + def from_kib(cls, value_kib: int): + """Creates a ByteUnit from a value in KiB.""" + if value_kib < 0: + raise ValueError("value_kib must be non-negative") + return ByteUnit(value_kib * 1024) + + def bytes(self) -> int: + """Returns the value in B.""" + return self.value_bytes + + def kib(self) -> float: + """Returns the value in KiB as a decimal.""" + return self.value_bytes / 1024 + + def mib(self) -> float: + """Returns the value in MiB as a decimal.""" + return self.value_bytes / (1 << 20) + + def gib(self) -> float: + """Returns the value in GiB as a decimal.""" + return self.value_bytes / (1 << 30) + + +class Meminfo: + """Represents the contents of /proc/meminfo inside the guest""" + + mem_total: ByteUnit + mem_free: ByteUnit + mem_available: ByteUnit + buffers: ByteUnit + cached: ByteUnit + + def __init__(self): + self.mem_total = ByteUnit(0) + self.mem_free = ByteUnit(0) + self.mem_available = ByteUnit(0) + self.buffers = ByteUnit(0) + self.cached = ByteUnit(0) + + +class MeminfoGuest: + """Queries /proc/meminfo inside the guest""" + + def __init__(self, vm): + self.vm = vm + + def get(self) -> Meminfo: + """Returns the contents of /proc/meminfo inside the guest""" + meminfo = Meminfo() + for line in self.vm.ssh.check_output("cat /proc/meminfo").stdout.splitlines(): + parts = line.split() + if parts[0] == "MemTotal:": + meminfo.mem_total = ByteUnit.from_kib(int(parts[1])) + elif parts[0] == "MemFree:": + meminfo.mem_free = ByteUnit.from_kib(int(parts[1])) + elif parts[0] == "MemAvailable:": + meminfo.mem_available = ByteUnit.from_kib(int(parts[1])) + elif parts[0] == "Buffers:": + meminfo.buffers = ByteUnit.from_kib(int(parts[1])) + elif parts[0] == "Cached:": + meminfo.cached = ByteUnit.from_kib(int(parts[1])) + + return meminfo diff --git a/tests/framework/http_api.py b/tests/framework/http_api.py index 0ae2e279571..058bdb00995 100644 --- a/tests/framework/http_api.py +++ b/tests/framework/http_api.py @@ -9,6 +9,8 @@ import requests from requests_unixsocket import DEFAULT_SCHEME, UnixAdapter +from 
framework.swagger_validator import SwaggerValidator, ValidationError + class Session(requests.Session): """An HTTP over UNIX sockets Session @@ -65,6 +67,21 @@ def get(self): self._api.error_callback("GET", self.resource, str(e)) raise assert res.status_code == HTTPStatus.OK, res.json() + + # Validate response against Swagger specification + # only validate successful requests + if self._api.validator and res.status_code == HTTPStatus.OK: + try: + response_body = res.json() + self._api.validator.validate_response( + "GET", self.resource, 200, response_body + ) + except ValidationError as e: + # Re-raise with more context + raise ValidationError( + f"Response validation failed for GET {self.resource}: {e.message}" + ) from e + return res def request(self, method, path, **kwargs): @@ -85,6 +102,32 @@ def request(self, method, path, **kwargs): elif "error" in json: msg = json["error"] raise RuntimeError(msg, json, res) + + # Validate request against Swagger specification + # do this after the actual request as we only want to validate successful + # requests as the tests may be trying to pass bad requests and assert an + # error is raised. + if self._api.validator: + if kwargs: + try: + self._api.validator.validate_request(method, path, kwargs) + except ValidationError as e: + # Re-raise with more context + raise ValidationError( + f"Request validation failed for {method} {path}: {e.message}" + ) from e + + if res.status_code == HTTPStatus.OK: + try: + response_body = res.json() + self._api.validator.validate_response( + method, path, 200, response_body + ) + except ValidationError as e: + # Re-raise with more context + raise ValidationError( + f"Response validation failed for {method} {path}: {e.message}" + ) from e return res def put(self, **kwargs): @@ -105,13 +148,16 @@ def patch(self, **kwargs): class Api: """A simple HTTP client for the Firecracker API""" - def __init__(self, api_usocket_full_name, *, on_error=None): + def __init__(self, api_usocket_full_name, *, validate=True, on_error=None): self.error_callback = on_error self.socket = api_usocket_full_name url_encoded_path = urllib.parse.quote_plus(api_usocket_full_name) self.endpoint = DEFAULT_SCHEME + url_encoded_path self.session = Session() + # Initialize the swagger validator + self.validator = SwaggerValidator() if validate else None + self.describe = Resource(self, "/") self.vm = Resource(self, "/vm") self.vm_config = Resource(self, "/vm/config") @@ -134,3 +180,4 @@ def __init__(self, api_usocket_full_name, *, on_error=None): self.entropy = Resource(self, "/entropy") self.pmem = Resource(self, "/pmem", "id") self.serial = Resource(self, "/serial") + self.memory_hotplug = Resource(self, "/hotplug/memory") diff --git a/tests/framework/microvm.py b/tests/framework/microvm.py index 74ae180950c..ded0b6e38d6 100644 --- a/tests/framework/microvm.py +++ b/tests/framework/microvm.py @@ -23,10 +23,11 @@ from collections import namedtuple from dataclasses import dataclass from enum import Enum, auto -from functools import lru_cache +from functools import cached_property, lru_cache from pathlib import Path from typing import Optional +import psutil from tenacity import Retrying, retry, stop_after_attempt, wait_fixed import host_tools.cargo_build as build_tools @@ -472,7 +473,7 @@ def state(self): """Get the InstanceInfo property and return the state field.""" return self.api.describe.get().json()["state"] - @property + @cached_property def firecracker_pid(self): """Return Firecracker's PID @@ -491,6 +492,11 @@ def firecracker_pid(self): 
with attempt: return int(self.jailer.pid_file.read_text(encoding="ascii")) + @cached_property + def ps(self): + """Returns a handle to the psutil.Process for this VM""" + return psutil.Process(self.firecracker_pid) + @property def dimensions(self): """Gets a default set of cloudwatch dimensions describing the configuration of this microvm""" @@ -634,6 +640,7 @@ def spawn( log_show_origin=False, metrics_path="fc.ndjson", emit_metrics: bool = False, + validate_api: bool = True, ): """Start a microVM as a daemon or in a screen session.""" # pylint: disable=subprocess-run-check @@ -641,6 +648,7 @@ def spawn( self.jailer.setup() self.api = Api( self.jailer.api_socket_path(), + validate=validate_api, on_error=lambda verb, uri, err_msg: self._dump_debug_information( f"Error during {verb} {uri}: {err_msg}" ), @@ -1198,6 +1206,28 @@ def wait_for_ssh_up(self): # run commands. The actual connection retry loop happens in SSHConnection._init_connection _ = self.ssh_iface(0) + def hotplug_memory( + self, requested_size_mib: int, timeout: int = 60, poll: float = 0.1 + ): + """Send a hot(un)plug request and wait up to `timeout` seconds for completion, polling every `poll` seconds. + + Returns: API latency (secs), total latency (secs) + """ + api_start = time.time() + self.api.memory_hotplug.patch(requested_size_mib=requested_size_mib) + api_end = time.time() + # Wait for the hotplug to complete + deadline = time.time() + timeout + while time.time() < deadline: + if ( + self.api.memory_hotplug.get().json()["plugged_size_mib"] + == requested_size_mib + ): + plug_end = time.time() + return api_end - api_start, plug_end - api_start + time.sleep(poll) + raise TimeoutError(f"Hotplug did not complete within {timeout} seconds") + class MicroVMFactory: """MicroVM factory""" @@ -1249,11 +1279,13 @@ def build(self, kernel=None, rootfs=None, **kwargs): vm.ssh_key = ssh_key return vm - def build_from_snapshot(self, snapshot: Snapshot): + def build_from_snapshot(self, snapshot: Snapshot, uffd_handler_name=None): """Build a microvm from a snapshot""" vm = self.build() vm.spawn() - vm.restore_from_snapshot(snapshot, resume=True) + vm.restore_from_snapshot( + snapshot, resume=True, uffd_handler_name=uffd_handler_name + ) return vm def build_n_from_snapshot( diff --git a/tests/framework/swagger_validator.py b/tests/framework/swagger_validator.py new file mode 100644 index 00000000000..0ad9e310268 --- /dev/null +++ b/tests/framework/swagger_validator.py @@ -0,0 +1,186 @@ +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
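A usage sketch for the new `hotplug_memory()` helper above; `vm` stands in for any booted microVM wrapper with a memory-hotplug device configured, and the sizes are illustrative:

```python
# Hedged usage sketch of Microvm.hotplug_memory(); `vm` is illustrative.
api_secs, total_secs = vm.hotplug_memory(1024, timeout=60, poll=0.1)

# The PATCH returns as soon as Firecracker accepts the request; the second
# value additionally covers the guest driver plugging the requested blocks.
assert api_secs <= total_secs
print(f"API: {api_secs:.3f}s, end-to-end: {total_secs:.3f}s")
```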
+# SPDX-License-Identifier: Apache-2.0 + +"""A validator for Firecracker API Swagger schema""" + +from pathlib import Path + +import yaml +from jsonschema import Draft4Validator, ValidationError + + +def _filter_none_recursive(data): + if isinstance(data, dict): + return {k: _filter_none_recursive(v) for k, v in data.items() if v is not None} + if isinstance(data, list): + return [_filter_none_recursive(item) for item in data if item is not None] + return data + + +class SwaggerValidator: + """Validator for API requests against the Swagger/OpenAPI specification""" + + _instance = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + """Initialize the validator with the Swagger specification.""" + if self._initialized: + return + self._initialized = True + + swagger_path = ( + Path(__file__).parent.parent.parent + / "src" + / "firecracker" + / "swagger" + / "firecracker.yaml" + ) + + with open(swagger_path, "r", encoding="utf-8") as f: + self.swagger_spec = yaml.safe_load(f) + + # Cache validators for each endpoint + self._validators = {} + self._build_validators() + + def _build_validators(self): + """Build JSON schema validators for each endpoint.""" + paths = self.swagger_spec.get("paths", {}) + definitions = self.swagger_spec.get("definitions", {}) + + for path, methods in paths.items(): + for method, spec in methods.items(): + if method.upper() not in ["GET", "PUT", "PATCH", "POST", "DELETE"]: + continue + + # Build request body validators + parameters = spec.get("parameters", []) + for param in parameters: + if param.get("in") == "body" and "schema" in param: + schema = self._resolve_schema(param["schema"], definitions) + if method.upper() == "PATCH": + # do not validate required fields on PATCH requests + schema["required"] = [] + key = ("request", method.upper(), path) + self._validators[key] = Draft4Validator(schema) + + # Build response validators for 200/204 responses + responses = spec.get("responses", {}) + for status_code, response_spec in responses.items(): + if str(status_code) in ["200", "204"] and "schema" in response_spec: + schema = self._resolve_schema( + response_spec["schema"], definitions + ) + key = ("response", method.upper(), path, str(status_code)) + self._validators[key] = Draft4Validator(schema) + + def _resolve_schema(self, schema, definitions): + """Resolve $ref references in schema.""" + if "$ref" in schema: + ref_path = schema["$ref"] + if ref_path.startswith("#/definitions/"): + def_name = ref_path.split("/")[-1] + if def_name in definitions: + return self._resolve_schema(definitions[def_name], definitions) + + # Recursively resolve nested schemas + resolved = schema.copy() + if "properties" in resolved: + resolved["properties"] = { + k: self._resolve_schema(v, definitions) + for k, v in resolved["properties"].items() + } + if "items" in resolved and isinstance(resolved["items"], dict): + resolved["items"] = self._resolve_schema(resolved["items"], definitions) + + if "additionalProperties" not in resolved: + resolved["additionalProperties"] = False + + return resolved + + def validate_request(self, method, path, body): + """ + Validate a request body against the Swagger specification. + + Args: + method: HTTP method (GET, PUT, PATCH, etc.)
+ path: API path (e.g., "/drives/{drive_id}") + body: Request body as a dictionary + + Raises: + ValidationError: If the request body doesn't match the schema + """ + # Normalize path - replace specific IDs with parameter placeholders + normalized_path = self._normalize_path(path) + key = ("request", method.upper(), normalized_path) + + if key in self._validators: + validator = self._validators[key] + # Remove None values from body before validation + cleaned_body = _filter_none_recursive(body) + validator.validate(cleaned_body) + else: + raise ValidationError(f"{key} is not in the schema") + + def validate_response(self, method, path, status_code, body): + """ + Validate a response body against the Swagger specification. + + Args: + method: HTTP method (GET, PUT, PATCH, etc.) + path: API path (e.g., "/drives/{drive_id}") + status_code: HTTP status code (e.g., 200, 204) + body: Response body as a dictionary + + Raises: + ValidationError: If the response body doesn't match the schema + """ + # Normalize path - replace specific IDs with parameter placeholders + normalized_path = self._normalize_path(path) + key = ("response", method.upper(), normalized_path, str(status_code)) + + if key in self._validators: + validator = self._validators[key] + # Remove None values from body before validation + cleaned_body = _filter_none_recursive(body) + validator.validate(cleaned_body) + else: + raise ValidationError(f"{key} is not in the schema") + + def _normalize_path(self, path): + """ + Normalize a path by replacing specific IDs with parameter placeholders. + + E.g., "/drives/rootfs" -> "/drives/{drive_id}" + """ + # Match against known patterns in the swagger spec + paths = self.swagger_spec.get("paths", {}) + + # Direct match + if path in paths: + return path + + # Try to match parameterized paths + parts = path.split("/") + for swagger_path in paths.keys(): + swagger_parts = swagger_path.split("/") + if len(parts) == len(swagger_parts): + match = True + for part, swagger_part in zip(parts, swagger_parts): + # Check if it's a parameter placeholder or exact match + if swagger_part.startswith("{") and swagger_part.endswith("}"): + continue # This is a parameter, any value matches + if part != swagger_part: + match = False + break + + if match: + return swagger_path + + return path diff --git a/tests/framework/utils.py b/tests/framework/utils.py index 64bc9526e5c..448b351fd86 100644 --- a/tests/framework/utils.py +++ b/tests/framework/utils.py @@ -14,6 +14,7 @@ import typing from collections import defaultdict, namedtuple from contextlib import contextmanager +from pathlib import Path from typing import Dict import psutil @@ -129,6 +130,19 @@ def track_cpu_utilization( return cpu_utilization +def get_resident_memory(process: psutil.Process): + """Returns current memory utilization in KiB, including used HugeTLBFS""" + + proc_status = Path("/proc", str(process.pid), "status").read_text("utf-8") + for line in proc_status.splitlines(): + if line.startswith("HugetlbPages:"): # entry is in KiB + hugetlbfs_usage = int(line.split()[1]) + break + else: + assert False, f"HugetlbPages not found in {proc_status}" + return hugetlbfs_usage + process.memory_info().rss // 1024 + + @contextmanager def chroot(path): """ @@ -240,25 +254,6 @@ def search_output_from_cmd(cmd: str, find_regex: typing.Pattern) -> typing.Match ) -def get_free_mem_ssh(ssh_connection): - """ - Get how much free memory in kB a guest sees, over ssh.
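Outside the `Api` plumbing, the validator can also be exercised directly. A minimal sketch, assuming `firecracker.yaml` resolves as in `__init__` and using the `/hotplug/memory` endpoint this PR adds; the request body below is illustrative:

```python
# Standalone sketch of SwaggerValidator; the body values are illustrative.
from framework.swagger_validator import SwaggerValidator, ValidationError

validator = SwaggerValidator()  # singleton: later constructions reuse this one
try:
    validator.validate_request("PUT", "/hotplug/memory", {"total_size_mib": 1024})
except ValidationError as err:
    print(f"schema violation: {err.message}")
```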
- - :param ssh_connection: connection to the guest - :return: available mem column output of 'free' - """ - _, stdout, stderr = ssh_connection.run("cat /proc/meminfo | grep MemAvailable") - assert stderr == "" - - # Split "MemAvailable: 123456 kB" and validate it - meminfo_data = stdout.split() - if len(meminfo_data) == 3: - # Return the middle element in the array - return int(meminfo_data[1]) - - raise Exception("Available memory not found in `/proc/meminfo") - - def _format_output_message(proc, stdout, stderr): output_message = f"\n[{proc.pid}] Command:\n{proc.args}" # Append stdout/stderr to the output message diff --git a/tests/framework/vm_config.json b/tests/framework/vm_config.json index ae3b4920444..b2bac4066d5 100644 --- a/tests/framework/vm_config.json +++ b/tests/framework/vm_config.json @@ -32,5 +32,6 @@ "metrics": null, "mmds-config": null, "entropy": null, - "pmem": [] + "pmem": [], + "memory-hotplug": null } diff --git a/tests/host_tools/fcmetrics.py b/tests/host_tools/fcmetrics.py index 5b1343ffab7..cfe79ce7711 100644 --- a/tests/host_tools/fcmetrics.py +++ b/tests/host_tools/fcmetrics.py @@ -150,6 +150,7 @@ def validate_fc_metrics(metrics): "machine_cfg_count", "mmds_count", "vmm_version_count", + "hotplug_memory_count", ], "i8042": [ "error_count", @@ -201,6 +202,8 @@ def validate_fc_metrics(metrics): "machine_cfg_fails", "mmds_count", "mmds_fails", + "hotplug_memory_count", + "hotplug_memory_fails", ], "put_api_requests": [ "actions_count", @@ -227,6 +230,8 @@ def validate_fc_metrics(metrics): "pmem_fails", "serial_count", "serial_fails", + "hotplug_memory_count", + "hotplug_memory_fails", ], "seccomp": [ "num_faults", @@ -301,6 +306,26 @@ def validate_fc_metrics(metrics): "event_fails", "queue_event_count", ], + "memory_hotplug": [ + "activate_fails", + "queue_event_fails", + "queue_event_count", + "plug_count", + "plug_bytes", + "plug_fails", + {"plug_agg": latency_agg_metrics_fields}, + "unplug_count", + "unplug_bytes", + "unplug_fails", + "unplug_discard_fails", + {"unplug_agg": latency_agg_metrics_fields}, + "state_count", + "state_fails", + {"state_agg": latency_agg_metrics_fields}, + "unplug_all_count", + "unplug_all_fails", + {"unplug_all_agg": latency_agg_metrics_fields}, + ], } # validate timestamp before jsonschema validation which some more time diff --git a/tests/host_tools/memory.py b/tests/host_tools/memory.py index 134147724cd..c09dae0d206 100644 --- a/tests/host_tools/memory.py +++ b/tests/host_tools/memory.py @@ -99,7 +99,9 @@ def is_guest_mem_x86(self, size, guest_mem_bytes): Checks if a region is a guest memory region based on x86_64 physical memory layout """ - return size in ( + # the region can be larger than boot memory if hotplugging is enabled + # anything at least guest_mem_bytes in size is treated as guest memory, since FC itself makes no allocations that large + return size >= guest_mem_bytes or size in ( # memory fits before the first gap guest_mem_bytes, # guest memory spans at least two regions & memory fits before the second gap @@ -121,7 +123,9 @@ def is_guest_mem_arch64(self, size, guest_mem_bytes): Checks if a region is a guest memory region based on ARM64 physical memory layout """ - return size in ( + # the region can be larger than boot memory if hotplugging is enabled + # anything at least guest_mem_bytes in size is treated as guest memory, since FC itself makes no allocations that large + return size >= guest_mem_bytes or size in ( # guest memory fits before the gap guest_mem_bytes, # guest memory fills the space before the gap diff --git a/tests/integration_tests/functional/test_api.py index
7dab0e14e6d..ab466c718d1 100644 --- a/tests/integration_tests/functional/test_api.py +++ b/tests/integration_tests/functional/test_api.py @@ -981,6 +981,49 @@ def test_api_entropy(uvm_plain): test_microvm.api.entropy.put() + +def test_api_memory_hotplug(uvm_plain_6_1): + """ + Test memory hotplug-related API commands. + """ + test_microvm = uvm_plain_6_1 + test_microvm.spawn() + test_microvm.basic_config() + test_microvm.add_net_iface() + + # Adding hotplug memory region should be OK. + test_microvm.api.memory_hotplug.put( + total_size_mib=1024, block_size_mib=128, slot_size_mib=1024 + ) + + # Overwriting an existing configuration should be OK. + # Omitting optional values should be OK. + test_microvm.api.memory_hotplug.put(total_size_mib=1024) + + # Get API should be rejected before boot + with pytest.raises(AssertionError): + test_microvm.api.memory_hotplug.get() + + # Patch API should be rejected before boot + with pytest.raises(RuntimeError, match=NOT_SUPPORTED_BEFORE_START): + test_microvm.api.memory_hotplug.patch(requested_size_mib=512) + + # Start the microvm + test_microvm.start() + + # Put API should be rejected after boot + with pytest.raises(RuntimeError, match=NOT_SUPPORTED_AFTER_START): + test_microvm.api.memory_hotplug.put(total_size_mib=1024) + + # Get API should work after boot + status = test_microvm.api.memory_hotplug.get().json() + assert status["total_size_mib"] == 1024 + + # Patch API should work after boot + test_microvm.api.memory_hotplug.patch(requested_size_mib=512) + status = test_microvm.api.memory_hotplug.get().json() + assert status["requested_size_mib"] == 512 + + def test_api_balloon(uvm_nano): """ Test balloon related API commands. @@ -1173,6 +1216,13 @@ def test_get_full_config_after_restoring_snapshot(microvm_factory, uvm_nano): uvm_nano.api.vsock.put(guest_cid=15, uds_path="vsock.sock") setup_cfg["vsock"] = {"guest_cid": 15, "uds_path": "vsock.sock"} + setup_cfg["memory-hotplug"] = { + "total_size_mib": 1024, + "block_size_mib": 128, + "slot_size_mib": 1024, + } + uvm_nano.api.memory_hotplug.put(**setup_cfg["memory-hotplug"]) + setup_cfg["logger"] = None setup_cfg["metrics"] = None setup_cfg["mmds-config"] = { @@ -1299,6 +1349,14 @@ def test_get_full_config(uvm_plain): response = test_microvm.api.vsock.put(guest_cid=15, uds_path="vsock.sock") expected_cfg["vsock"] = {"guest_cid": 15, "uds_path": "vsock.sock"} + # Add hot-pluggable memory. + expected_cfg["memory-hotplug"] = { + "total_size_mib": 1024, + "block_size_mib": 128, + "slot_size_mib": 1024, + } + test_microvm.api.memory_hotplug.put(**expected_cfg["memory-hotplug"]) + + # Add a net device. iface_id = "1" tapname = test_microvm.id[:8] + "tap" + iface_id diff --git a/tests/integration_tests/functional/test_balloon.py b/tests/integration_tests/functional/test_balloon.py index f8960bedb6d..19b1651c72a 100644 --- a/tests/integration_tests/functional/test_balloon.py +++ b/tests/integration_tests/functional/test_balloon.py @@ -9,12 +9,13 @@ import pytest import requests -from framework.utils import check_output, get_free_mem_ssh +from framework.guest_stats import MeminfoGuest +from framework.utils import get_resident_memory STATS_POLLING_INTERVAL_S = 1 -def get_stable_rss_mem_by_pid(pid, percentage_delta=1): +def get_stable_rss_mem(uvm, percentage_delta=1): """ - Get the RSS memory that a guest uses, given the pid of the guest. + Get the RSS memory that a guest uses, given the microvm handle. @@ -22,22 +23,16 @@ def get_stable_rss_mem_by_pid(pid, percentage_delta=1): Or print a warning if this does not happen.
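Condensed, the lifecycle `test_api_memory_hotplug` pins down is: PUT configures the hotplug region and is only accepted pre-boot, while GET and PATCH only work post-boot. A sketch distilled from the assertions above (`uvm` stands in for any microVM wrapper from this framework):

```python
# Lifecycle sketch distilled from test_api_memory_hotplug; `uvm` is illustrative.
uvm.api.memory_hotplug.put(total_size_mib=1024)        # pre-boot only
uvm.start()
uvm.api.memory_hotplug.patch(requested_size_mib=512)   # post-boot only
status = uvm.api.memory_hotplug.get().json()           # post-boot only
assert status["requested_size_mib"] == 512
```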
""" - # All values are reported as KiB - - def get_rss_from_pmap(): - _, output, _ = check_output("pmap -X {}".format(pid)) - return int(output.split("\n")[-2].split()[1], 10) - first_rss = 0 second_rss = 0 for _ in range(5): - first_rss = get_rss_from_pmap() + first_rss = get_resident_memory(uvm.ps) time.sleep(1) - second_rss = get_rss_from_pmap() + second_rss = get_resident_memory(uvm.ps) abs_diff = abs(first_rss - second_rss) abs_delta = abs_diff / first_rss * 100 print( - f"RSS readings: old: {first_rss} new: {second_rss} abs_diff: {abs_diff} abs_delta: {abs_delta}" + f"RSS readings (bytes): old: {first_rss} new: {second_rss} abs_diff: {abs_diff} abs_delta: {abs_delta}" ) if abs_delta < percentage_delta: return second_rss @@ -87,25 +82,24 @@ def make_guest_dirty_memory(ssh_connection, amount_mib=32): def _test_rss_memory_lower(test_microvm): """Check inflating the balloon makes guest use less rss memory.""" # Get the firecracker pid, and open an ssh connection. - firecracker_pid = test_microvm.firecracker_pid ssh_connection = test_microvm.ssh # Using deflate_on_oom, get the RSS as low as possible test_microvm.api.balloon.patch(amount_mib=200) # Get initial rss consumption. - init_rss = get_stable_rss_mem_by_pid(firecracker_pid) + init_rss = get_stable_rss_mem(test_microvm) # Get the balloon back to 0. test_microvm.api.balloon.patch(amount_mib=0) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Dirty memory, then inflate balloon and get ballooned rss consumption. make_guest_dirty_memory(ssh_connection, amount_mib=32) test_microvm.api.balloon.patch(amount_mib=200) - balloon_rss = get_stable_rss_mem_by_pid(firecracker_pid) + balloon_rss = get_stable_rss_mem(test_microvm) # Check that the ballooning reclaimed the memory. assert balloon_rss - init_rss <= 15000 @@ -149,18 +143,18 @@ def test_inflate_reduces_free(uvm_plain_any): # Start the microvm test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid + meminfo = MeminfoGuest(test_microvm) # Get the free memory before ballooning. - available_mem_deflated = get_free_mem_ssh(test_microvm.ssh) + available_mem_deflated = meminfo.get().mem_free.kib() # Inflate 64 MB == 16384 page balloon. test_microvm.api.balloon.patch(amount_mib=64) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get the free memory after ballooning. - available_mem_inflated = get_free_mem_ssh(test_microvm.ssh) + available_mem_inflated = meminfo.get().mem_free.kib() # Assert that ballooning reclaimed about 64 MB of memory. assert available_mem_inflated <= available_mem_deflated - 85 * 64000 / 100 @@ -195,19 +189,18 @@ def test_deflate_on_oom(uvm_plain_any, deflate_on_oom): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # We get an initial reading of the RSS, then calculate the amount # we need to inflate the balloon with by subtracting it from the # VM size and adding an offset of 50 MiB in order to make sure we # get a lower reading than the initial one. - initial_rss = get_stable_rss_mem_by_pid(firecracker_pid) + initial_rss = get_stable_rss_mem(test_microvm) inflate_size = 256 - (int(initial_rss / 1024) + 50) # Inflate the balloon test_microvm.api.balloon.patch(amount_mib=inflate_size) # This call will internally wait for rss to become stable. 
- _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Check that using memory leads to the balloon device automatically # deflate (or not). @@ -250,39 +243,38 @@ def test_reinflate_balloon(uvm_plain_any): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # First inflate the balloon to free up the uncertain amount of memory # used by the kernel at boot and establish a baseline, then give back # the memory. test_microvm.api.balloon.patch(amount_mib=200) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) test_microvm.api.balloon.patch(amount_mib=0) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get the guest to dirty memory. make_guest_dirty_memory(test_microvm.ssh, amount_mib=32) - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(test_microvm) # Now inflate the balloon. test_microvm.api.balloon.patch(amount_mib=200) - second_reading = get_stable_rss_mem_by_pid(firecracker_pid) + second_reading = get_stable_rss_mem(test_microvm) # Now deflate the balloon. test_microvm.api.balloon.patch(amount_mib=0) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Now have the guest dirty memory again. make_guest_dirty_memory(test_microvm.ssh, amount_mib=32) - third_reading = get_stable_rss_mem_by_pid(firecracker_pid) + third_reading = get_stable_rss_mem(test_microvm) # Now inflate the balloon again. test_microvm.api.balloon.patch(amount_mib=200) - fourth_reading = get_stable_rss_mem_by_pid(firecracker_pid) + fourth_reading = get_stable_rss_mem(test_microvm) # Check that the memory used is the same after regardless of the previous # inflate history of the balloon (with the third reading being allowed @@ -309,10 +301,9 @@ def test_size_reduction(uvm_plain_any): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # Check memory usage. - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(test_microvm) # Have the guest drop its caches. test_microvm.ssh.run("sync; echo 3 > /proc/sys/vm/drop_caches") @@ -328,7 +319,7 @@ def test_size_reduction(uvm_plain_any): test_microvm.api.balloon.patch(amount_mib=inflate_size) # Check memory usage again. - second_reading = get_stable_rss_mem_by_pid(firecracker_pid) + second_reading = get_stable_rss_mem(test_microvm) # There should be a reduction of at least 10MB. assert first_reading - second_reading >= 10000 @@ -353,7 +344,6 @@ def test_stats(uvm_plain_any): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # Give Firecracker enough time to poll the stats at least once post-boot time.sleep(STATS_POLLING_INTERVAL_S * 2) @@ -371,7 +361,7 @@ def test_stats(uvm_plain_any): make_guest_dirty_memory(test_microvm.ssh, amount_mib=10) time.sleep(1) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Make sure that the stats catch the page faults. after_workload_stats = test_microvm.api.balloon_stats.get().json() @@ -380,7 +370,7 @@ def test_stats(uvm_plain_any): # Now inflate the balloon with 10MB of pages. 
test_microvm.api.balloon.patch(amount_mib=10) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get another reading of the stats after the polling interval has passed. inflated_stats = test_microvm.api.balloon_stats.get().json() @@ -393,7 +383,7 @@ def test_stats(uvm_plain_any): # available memory. test_microvm.api.balloon.patch(amount_mib=0) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get another reading of the stats after the polling interval has passed. deflated_stats = test_microvm.api.balloon_stats.get().json() @@ -421,13 +411,12 @@ def test_stats_update(uvm_plain_any): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # Dirty 30MB of pages. make_guest_dirty_memory(test_microvm.ssh, amount_mib=30) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get an initial reading of the stats. initial_stats = test_microvm.api.balloon_stats.get().json() @@ -477,17 +466,14 @@ def test_balloon_snapshot(uvm_plain_any, microvm_factory): make_guest_dirty_memory(vm.ssh, amount_mib=60) time.sleep(1) - # Get the firecracker pid, and open an ssh connection. - firecracker_pid = vm.firecracker_pid - # Check memory usage. - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(vm) # Now inflate the balloon with 20MB of pages. vm.api.balloon.patch(amount_mib=20) # Check memory usage again. - second_reading = get_stable_rss_mem_by_pid(firecracker_pid) + second_reading = get_stable_rss_mem(vm) # There should be a reduction in RSS, but it's inconsistent. # We only test that the reduction happens. @@ -496,28 +482,25 @@ def test_balloon_snapshot(uvm_plain_any, microvm_factory): snapshot = vm.snapshot_full() microvm = microvm_factory.build_from_snapshot(snapshot) - # Get the firecracker from snapshot pid, and open an ssh connection. - firecracker_pid = microvm.firecracker_pid - # Wait out the polling interval, then get the updated stats. time.sleep(STATS_POLLING_INTERVAL_S * 2) stats_after_snap = microvm.api.balloon_stats.get().json() # Check memory usage. - third_reading = get_stable_rss_mem_by_pid(firecracker_pid) + third_reading = get_stable_rss_mem(microvm) # Dirty 60MB of pages. make_guest_dirty_memory(microvm.ssh, amount_mib=60) # Check memory usage. - fourth_reading = get_stable_rss_mem_by_pid(firecracker_pid) + fourth_reading = get_stable_rss_mem(microvm) assert fourth_reading > third_reading # Inflate the balloon with another 20MB of pages. microvm.api.balloon.patch(amount_mib=40) - fifth_reading = get_stable_rss_mem_by_pid(firecracker_pid) + fifth_reading = get_stable_rss_mem(microvm) # There should be a reduction in RSS, but it's inconsistent. # We only test that the reduction happens. @@ -557,15 +540,14 @@ def test_memory_scrub(uvm_plain_any): microvm.api.balloon.patch(amount_mib=60) # Get the firecracker pid, and open an ssh connection. - firecracker_pid = microvm.firecracker_pid # Wait for the inflate to complete. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(microvm) # Deflate the balloon completely. microvm.api.balloon.patch(amount_mib=0) # Wait for the deflate to complete. 
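The stats assertions in `test_stats` compare successive `balloon_stats` snapshots before and after a guest workload. Reduced to its core, the delta pattern looks like this (a sketch; `uvm` and the `minor_faults` field name are assumptions here, based on the counters virtio-balloon commonly exposes):

```python
# Sketch of the stats-delta check pattern; `uvm` is illustrative.
before = uvm.api.balloon_stats.get().json()
make_guest_dirty_memory(uvm.ssh, amount_mib=10)   # dirty pages in the guest
after = uvm.api.balloon_stats.get().json()
assert after["minor_faults"] >= before["minor_faults"]
```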
- _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(microvm) microvm.ssh.check_output("/usr/local/bin/readmem {} {}".format(60, 1)) diff --git a/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py b/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py index bfe5316d9e5..253502a2d1f 100644 --- a/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py +++ b/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py @@ -20,7 +20,7 @@ from framework.utils_cpu_templates import get_supported_cpu_templates from framework.utils_vsock import check_vsock_device from integration_tests.functional.test_balloon import ( - get_stable_rss_mem_by_pid, + get_stable_rss_mem, make_guest_dirty_memory, ) @@ -28,21 +28,18 @@ def _test_balloon(microvm): - # Get the firecracker pid. - firecracker_pid = microvm.firecracker_pid - # Check memory usage. - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(microvm) # Dirty 300MB of pages. make_guest_dirty_memory(microvm.ssh, amount_mib=300) # Check memory usage again. - second_reading = get_stable_rss_mem_by_pid(firecracker_pid) + second_reading = get_stable_rss_mem(microvm) assert second_reading > first_reading # Inflate the balloon. Get back 200MB. microvm.api.balloon.patch(amount_mib=200) - third_reading = get_stable_rss_mem_by_pid(firecracker_pid) + third_reading = get_stable_rss_mem(microvm) # Ensure that there is a reduction in RSS. assert second_reading > third_reading diff --git a/tests/integration_tests/performance/test_hotplug_memory.py b/tests/integration_tests/performance/test_hotplug_memory.py new file mode 100644 index 00000000000..bd5c2bd6a9a --- /dev/null +++ b/tests/integration_tests/performance/test_hotplug_memory.py @@ -0,0 +1,480 @@ +# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Tests for verifying that the virtio-mem device is working correctly + +This file also contains functional tests for virtio-mem because they need to be +run on an ag=1 host due to the use of HugePages.
+""" + +import pytest +from packaging import version +from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed + +from framework.guest_stats import MeminfoGuest +from framework.microvm import HugePagesConfig, SnapshotType +from framework.properties import global_props +from framework.utils import get_kernel_version, get_resident_memory + +MEMHP_BOOTARGS = "console=ttyS0 reboot=k panic=1 memhp_default_state=online_movable" +DEFAULT_CONFIG = {"total_size_mib": 1024, "slot_size_mib": 128, "block_size_mib": 2} + + +def uvm_booted_memhp( + uvm, + rootfs, + _microvm_factory, + vhost_user, + memhp_config, + huge_pages, + _uffd_handler, + snapshot_type, +): + """Boots a VM with the given memory hotplugging config""" + + uvm.spawn() + uvm.memory_monitor = None + uvm_config = { + "boot_args": MEMHP_BOOTARGS, + "huge_pages": huge_pages, + # we need enough memory to be able to hotplug up to 16GB + "mem_size_mib": 512, + } + if vhost_user: + # We need to setup ssh keys manually because we did not specify rootfs + # in microvm_factory.build method + ssh_key = rootfs.with_suffix(".id_rsa") + uvm.ssh_key = ssh_key + uvm.basic_config( + **uvm_config, + add_root_device=False, + track_dirty_pages=( + snapshot_type.needs_dirty_page_tracking if snapshot_type else False + ), + ) + uvm.add_vhost_user_drive( + "rootfs", rootfs, is_root_device=True, is_read_only=True + ) + else: + uvm.basic_config(**uvm_config) + + uvm.api.memory_hotplug.put(**memhp_config) + uvm.add_net_iface() + uvm.start() + return uvm + + +def uvm_resumed_memhp( + uvm_plain, + rootfs, + microvm_factory, + vhost_user, + memhp_config, + huge_pages, + uffd_handler, + snapshot_type, +): + """Restores a VM with the given memory hotplugging config after booting and snapshotting""" + if vhost_user: + pytest.skip("vhost-user doesn't support snapshot/restore") + if huge_pages and huge_pages != HugePagesConfig.NONE and not uffd_handler: + pytest.skip("Hugepages requires a UFFD handler") + uvm = uvm_booted_memhp( + uvm_plain, + rootfs, + microvm_factory, + vhost_user, + memhp_config, + huge_pages, + None, + snapshot_type, + ) + snapshot = uvm.make_snapshot(snapshot_type) + return microvm_factory.build_from_snapshot(snapshot, uffd_handler_name=uffd_handler) + + +@pytest.fixture( + params=[ + (uvm_booted_memhp, False, HugePagesConfig.NONE, None, None), + (uvm_booted_memhp, False, HugePagesConfig.HUGETLBFS_2MB, None, None), + (uvm_booted_memhp, True, HugePagesConfig.NONE, None, None), + (uvm_resumed_memhp, False, HugePagesConfig.NONE, None, SnapshotType.FULL), + (uvm_resumed_memhp, False, HugePagesConfig.NONE, None, SnapshotType.DIFF), + ( + uvm_resumed_memhp, + False, + HugePagesConfig.NONE, + None, + SnapshotType.DIFF_MINCORE, + ), + ( + uvm_resumed_memhp, + False, + HugePagesConfig.NONE, + "on_demand", + SnapshotType.FULL, + ), + ( + uvm_resumed_memhp, + False, + HugePagesConfig.HUGETLBFS_2MB, + "on_demand", + SnapshotType.FULL, + ), + ], + ids=[ + "booted", + "booted-huge-pages", + "booted-vhost-user", + "resumed", + "resumed-diff", + "resumed-mincore", + "resumed-uffd", + "resumed-uffd-huge-pages", + ], +) +def uvm_any_memhp(request, uvm_plain_6_1, rootfs, microvm_factory): + """Fixture that yields a booted or resumed VM with memory hotplugging""" + ctor, vhost_user, huge_pages, uffd_handler, snapshot_type = request.param + yield ctor( + uvm_plain_6_1, + rootfs, + microvm_factory, + vhost_user, + DEFAULT_CONFIG, + huge_pages, + uffd_handler, + snapshot_type, + ) + + +def supports_hugetlbfs_discard(): + """Returns True if the 
kernel supports hugetlbfs discard""" + return version.parse(get_kernel_version()) >= version.parse("5.18.0") + + +def validate_metrics(uvm): + """Validates that there are no fails in the metrics""" + metrics_to_check = ["plug_fails", "unplug_fails", "unplug_all_fails", "state_fails"] + if supports_hugetlbfs_discard(): + metrics_to_check.append("unplug_discard_fails") + uvm.flush_metrics() + for metrics in uvm.get_all_metrics(): + for k in metrics_to_check: + assert ( + metrics["memory_hotplug"][k] == 0 + ), f"{k}={metrics['memory_hotplug'][k]} is greater than zero" + + +def check_device_detected(uvm): + """ + Check that the guest kernel has enabled virtio-mem. + """ + hp_config = uvm.api.memory_hotplug.get().json() + _, stdout, _ = uvm.ssh.check_output("dmesg | grep 'virtio_mem'") + for line in stdout.splitlines(): + _, key, value = line.strip().split(":") + key = key.strip() + value = int(value.strip(), base=0) + match key: + case "start address": + assert value >= (512 << 30), "start address isn't past the MMIO64 region" + case "region size": + assert ( + value == hp_config["total_size_mib"] << 20 + ), "region size doesn't match" + case "device block size": + assert ( + value == hp_config["block_size_mib"] << 20 + ), "block size doesn't match" + case "plugged size": + assert value == 0, "plugged size doesn't match" + case "requested size": + assert value == 0, "requested size doesn't match" + case _: + continue + + +def check_memory_usable(uvm): + """Allocates memory to verify it's usable (5% margin to avoid OOM-kill)""" + mem_available = MeminfoGuest(uvm).get().mem_available.mib() + # try to allocate 95% of available memory + amount_mib = int(mem_available * 95 / 100) + + _ = uvm.ssh.check_output(f"/usr/local/bin/fillmem {amount_mib}", timeout=30) + # verify the allocation was successful + _ = uvm.ssh.check_output("cat /tmp/fillmem_output.txt | grep successful") + + +def check_hotplug(uvm, requested_size_mib): + """Verifies memory can be hot(un)plugged""" + meminfo = MeminfoGuest(uvm) + mem_total_fixed = ( + meminfo.get().mem_total.mib() + - uvm.api.memory_hotplug.get().json()["plugged_size_mib"] + ) + uvm.hotplug_memory(requested_size_mib) + + # verify guest driver received the request + _, stdout, _ = uvm.ssh.check_output( + "dmesg | grep 'virtio_mem' | grep 'requested size' | tail -1" + ) + assert ( + int(stdout.strip().split(":")[-1].strip(), base=0) == requested_size_mib << 20 + ) + + for attempt in Retrying( + retry=retry_if_exception_type(AssertionError), + stop=stop_after_delay(5), + wait=wait_fixed(1), + reraise=True, + ): + with attempt: + # verify guest driver executed the request + mem_total_after = meminfo.get().mem_total.mib() + assert mem_total_after == mem_total_fixed + requested_size_mib + + +def check_hotunplug(uvm, requested_size_mib): + """Verifies memory can be hotunplugged and gets released""" + + rss_before = get_resident_memory(uvm.ps) + + check_hotplug(uvm, requested_size_mib) + + rss_after = get_resident_memory(uvm.ps) + + print(f"RSS before: {rss_before}, after: {rss_after}") + + huge_pages = HugePagesConfig(uvm.api.machine_config.get().json()["huge_pages"]) + if huge_pages == HugePagesConfig.HUGETLBFS_2MB and supports_hugetlbfs_discard(): + assert rss_after < rss_before, "RSS didn't decrease" + + +def test_virtio_mem_hotplug_hotunplug(uvm_any_memhp): + """ + Check that memory can be hotplugged into and hotunplugged from the VM.
+ """ + uvm = uvm_any_memhp + check_device_detected(uvm) + + check_hotplug(uvm, 1024) + check_memory_usable(uvm) + + check_hotunplug(uvm, 0) + + # Check it works again + check_hotplug(uvm, 1024) + check_memory_usable(uvm) + + validate_metrics(uvm) + + +@pytest.mark.parametrize( + "memhp_config", + [ + {"total_size_mib": 256, "slot_size_mib": 128, "block_size_mib": 64}, + {"total_size_mib": 256, "slot_size_mib": 128, "block_size_mib": 128}, + {"total_size_mib": 256, "slot_size_mib": 256, "block_size_mib": 64}, + {"total_size_mib": 256, "slot_size_mib": 256, "block_size_mib": 256}, + ], + ids=["all_different", "slot_sized_block", "single_slot", "single_block"], +) +def test_virtio_mem_configs(uvm_plain_6_1, memhp_config): + """ + Check that the virtio mem device is working as expected for different configs + """ + uvm = uvm_booted_memhp( + uvm_plain_6_1, None, None, False, memhp_config, None, None, None + ) + if not uvm.pci_enabled: + pytest.skip( + "Skip tests on MMIO transport to save time as we don't expect any difference." + ) + + check_device_detected(uvm) + + for size in range( + 0, memhp_config["total_size_mib"] + 1, memhp_config["block_size_mib"] + ): + check_hotplug(uvm, size) + + check_memory_usable(uvm) + + for size in range( + memhp_config["total_size_mib"] - memhp_config["block_size_mib"], + -1, + -memhp_config["block_size_mib"], + ): + check_hotunplug(uvm, size) + + validate_metrics(uvm) + + +def test_snapshot_restore_persistence(uvm_plain_6_1, microvm_factory, snapshot_type): + """ + Check that hptplugged memory is persisted across snapshot/restore. + """ + if not uvm_plain_6_1.pci_enabled: + pytest.skip( + "Skip tests on MMIO transport to save time as we don't expect any difference." + ) + uvm = uvm_booted_memhp( + uvm_plain_6_1, + None, + microvm_factory, + False, + DEFAULT_CONFIG, + None, + None, + snapshot_type, + ) + + uvm.hotplug_memory(1024) + + # Increase /dev/shm size as it defaults to half of the boot memory + uvm.ssh.check_output("mount -o remount,size=1024M -t tmpfs tmpfs /dev/shm") + + uvm.ssh.check_output("dd if=/dev/urandom of=/dev/shm/mem_hp_test bs=1M count=1024") + + _, checksum_before, _ = uvm.ssh.check_output("sha256sum /dev/shm/mem_hp_test") + + snapshot = uvm.make_snapshot(snapshot_type) + restored_vm = microvm_factory.build_from_snapshot(snapshot) + + _, checksum_after, _ = restored_vm.ssh.check_output( + "sha256sum /dev/shm/mem_hp_test" + ) + + assert checksum_before == checksum_after, "Checksums didn't match" + + validate_metrics(restored_vm) + + +def test_snapshot_restore_incremental(uvm_plain_6_1, microvm_factory): + """ + Check that hptplugged memory is persisted across snapshot/restore. + """ + if not uvm_plain_6_1.pci_enabled: + pytest.skip( + "Skip tests on MMIO transport to save time as we don't expect any difference." 
+ ) + + uvm = uvm_booted_memhp( + uvm_plain_6_1, None, microvm_factory, False, DEFAULT_CONFIG, None, None, None + ) + + snapshot = uvm.snapshot_full() + + hotplug_count = 16 + hp_mem_mib_per_cycle = 1024 // hotplug_count + checksums = [] + for i, uvm in enumerate( + microvm_factory.build_n_from_snapshot( + snapshot, + hotplug_count + 1, + incremental=True, + use_snapshot_editor=True, + ) + ): + # check checksums of previous cycles + for j in range(i): + _, checksum, _ = uvm.ssh.check_output(f"sha256sum /dev/shm/mem_hp_test_{j}") + assert checksum == checksums[j], f"Checksums didn't match for i={i} j={j}" + + # we run hotplug_count+1 uvms to check all the checksums at the end + if i >= hotplug_count: + continue + + total_hp_mem_mib = hp_mem_mib_per_cycle * (i + 1) + uvm.hotplug_memory(total_hp_mem_mib) + + # Increase /dev/shm size as it defaults to half of the boot memory + uvm.ssh.check_output( + f"mount -o remount,size={total_hp_mem_mib}M -t tmpfs tmpfs /dev/shm" + ) + + uvm.ssh.check_output( + f"dd if=/dev/urandom of=/dev/shm/mem_hp_test_{i} bs=1M count={hp_mem_mib_per_cycle}" + ) + + _, checksum, _ = uvm.ssh.check_output(f"sha256sum /dev/shm/mem_hp_test_{i}") + checksums.append(checksum) + + validate_metrics(uvm) + + +def timed_memory_hotplug(uvm, size, metrics, metric_prefix, fc_metric_name): + """Time a memory hot(un)plug request to `size` MiB and emit latency metrics""" + + uvm.flush_metrics() + + api_time, total_time = uvm.hotplug_memory(size) + + fc_metrics = uvm.flush_metrics() + + metrics.put_metric( + f"{metric_prefix}_api_time", + api_time, + unit="Seconds", + ) + metrics.put_metric( + f"{metric_prefix}_total_time", + total_time, + unit="Seconds", + ) + metrics.put_metric( + f"{metric_prefix}_fc_time", + fc_metrics["memory_hotplug"][fc_metric_name]["sum_us"], + unit="Microseconds", + ) + + +@pytest.mark.nonci +@pytest.mark.parametrize( + "hotplug_size", + [ + 1024, + 2048, + 4096, + 8192, + 16384, + ], +) +@pytest.mark.parametrize( + "huge_pages", + [HugePagesConfig.NONE, HugePagesConfig.HUGETLBFS_2MB], +) +def test_memory_hotplug_latency( + microvm_factory, guest_kernel_linux_6_1, rootfs, hotplug_size, huge_pages, metrics ): + """Test the latency of hotplugging memory""" + + for i in range(20): + config = { + "total_size_mib": hotplug_size, + "slot_size_mib": 128, + "block_size_mib": 2, + } + uvm_plain_6_1 = microvm_factory.build(guest_kernel_linux_6_1, rootfs, pci=True) + uvm = uvm_booted_memhp( + uvm_plain_6_1, None, None, False, config, huge_pages, None, None + ) + + if i == 0: + metrics.set_dimensions( + { + "instance": global_props.instance, + "cpu_model": global_props.cpu_model, + "host_kernel": f"linux-{global_props.host_linux_version}", + "performance_test": "test_memory_hotplug_latency", + "hotplug_size": str(hotplug_size), + "huge_pages": huge_pages, + **uvm.dimensions, + } + ) + + timed_memory_hotplug(uvm, hotplug_size, metrics, "hotplug", "plug_agg") + timed_memory_hotplug(uvm, 0, metrics, "hotunplug", "unplug_agg") + timed_memory_hotplug(uvm, hotplug_size, metrics, "hotplug_2nd", "plug_agg") diff --git a/tools/bindgen.sh b/tools/bindgen.sh index cdab5ef824a..91a5aaf8e71 100755 --- a/tools/bindgen.sh +++ b/tools/bindgen.sh @@ -105,6 +105,12 @@ fc-bindgen \ --allowlist-var "VIRTIO_ID.*" \ "$INCLUDE/linux/virtio_ids.h" >src/vmm/src/devices/virtio/generated/virtio_ids.rs +info "BINDGEN virtio_mem.h" +fc-bindgen \ + --allowlist-var "VIRTIO_MEM.*" \ + --allowlist-type "virtio_mem.*" \ + "$INCLUDE/linux/virtio_mem.h" >src/vmm/src/devices/virtio/generated/virtio_mem.rs + info "BINDGEN prctl.h"
fc-bindgen \ --allowlist-var "PR_.*" \