Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 159 additions & 9 deletions src/vmm/src/devices/virtio/mem/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::sync::atomic::AtomicU32;
use log::info;
use serde::{Deserialize, Serialize};
use vm_memory::{
Address, GuestAddress, GuestMemory, GuestMemoryError, GuestMemoryRegion, GuestUsize,
Address, Bytes, GuestAddress, GuestMemory, GuestMemoryError, GuestMemoryRegion, GuestUsize,
};
use vmm_sys_util::eventfd::EventFd;

Expand All @@ -20,12 +20,15 @@ use crate::devices::virtio::device::{ActiveState, DeviceState, VirtioDevice};
use crate::devices::virtio::generated::virtio_config::VIRTIO_F_VERSION_1;
use crate::devices::virtio::generated::virtio_ids::VIRTIO_ID_MEM;
use crate::devices::virtio::generated::virtio_mem::{
VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE, virtio_mem_config,
self, VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE, virtio_mem_config,
};
use crate::devices::virtio::iov_deque::IovDequeError;
use crate::devices::virtio::mem::metrics::METRICS;
use crate::devices::virtio::mem::request::{BlockRangeState, Request, RequestedRange, Response};
use crate::devices::virtio::mem::{VIRTIO_MEM_DEV_ID, VIRTIO_MEM_GUEST_ADDRESS};
use crate::devices::virtio::queue::{FIRECRACKER_MAX_QUEUE_SIZE, InvalidAvailIdx, Queue};
use crate::devices::virtio::queue::{
DescriptorChain, FIRECRACKER_MAX_QUEUE_SIZE, InvalidAvailIdx, Queue, QueueError,
};
use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType};
use crate::logger::{IncMetric, debug, error};
use crate::utils::{bytes_to_mib, mib_to_bytes, u64_to_usize, usize_to_u64};
Expand All @@ -47,6 +50,24 @@ pub enum VirtioMemError {
InvalidSize(u64),
/// Device is not active
DeviceNotActive,
/// Descriptor is write-only
UnexpectedWriteOnlyDescriptor,
/// Error reading virtio descriptor
DescriptorWriteFailed,
/// Error writing virtio descriptor
DescriptorReadFailed,
/// Unknown request type: {0:?}
UnknownRequestType(u32),
/// Descriptor chain is too short
DescriptorChainTooShort,
/// Descriptor is too small
DescriptorLengthTooSmall,
/// Descriptor is read-only
UnexpectedReadOnlyDescriptor,
/// Error popping from virtio queue: {0}
InvalidAvailIdx(#[from] InvalidAvailIdx),
/// Error adding used queue: {0}
QueueError(#[from] QueueError),
}

#[derive(Debug)]
Expand Down Expand Up @@ -170,8 +191,139 @@ impl VirtioMem {
.map_err(VirtioMemError::InterruptError)
}

fn guest_memory(&self) -> &GuestMemoryMmap {
&self.device_state.active_state().unwrap().mem
}

fn parse_request(
&self,
avail_desc: &DescriptorChain,
) -> Result<(Request, GuestAddress, u16), VirtioMemError> {
// The head contains the request type which MUST be readable.
if avail_desc.is_write_only() {
return Err(VirtioMemError::UnexpectedWriteOnlyDescriptor);
}

if (avail_desc.len as usize) < size_of::<virtio_mem::virtio_mem_req>() {
return Err(VirtioMemError::DescriptorLengthTooSmall);
}

let request: virtio_mem::virtio_mem_req = self
.guest_memory()
.read_obj(avail_desc.addr)
.map_err(|_| VirtioMemError::DescriptorReadFailed)?;

let resp_desc = avail_desc
.next_descriptor()
.ok_or(VirtioMemError::DescriptorChainTooShort)?;

// The response MUST always be writable.
if !resp_desc.is_write_only() {
return Err(VirtioMemError::UnexpectedReadOnlyDescriptor);
}

if (resp_desc.len as usize) < std::mem::size_of::<virtio_mem::virtio_mem_resp>() {
return Err(VirtioMemError::DescriptorLengthTooSmall);
}

Ok((request.into(), resp_desc.addr, avail_desc.index))
}

fn write_response(
&mut self,
resp: Response,
resp_addr: GuestAddress,
used_idx: u16,
) -> Result<(), VirtioMemError> {
debug!("virtio-mem: Response: {:?}", resp);
self.guest_memory()
.write_obj(virtio_mem::virtio_mem_resp::from(resp), resp_addr)
.map_err(|_| VirtioMemError::DescriptorWriteFailed)
.map(|_| size_of::<virtio_mem::virtio_mem_resp>())?;
self.queues[MEM_QUEUE]
.add_used(
used_idx,
u32::try_from(std::mem::size_of::<virtio_mem::virtio_mem_resp>()).unwrap(),
)
.map_err(VirtioMemError::QueueError)
}

fn handle_plug_request(
&mut self,
range: &RequestedRange,
resp_addr: GuestAddress,
used_idx: u16,
) -> Result<(), VirtioMemError> {
METRICS.plug_count.inc();
let _metric = METRICS.plug_agg.record_latency_metrics();

// TODO: implement PLUG request
let response = Response::ack();
self.write_response(response, resp_addr, used_idx)
}

fn handle_unplug_request(
&mut self,
range: &RequestedRange,
resp_addr: GuestAddress,
used_idx: u16,
) -> Result<(), VirtioMemError> {
METRICS.unplug_count.inc();
let _metric = METRICS.unplug_agg.record_latency_metrics();

// TODO: implement UNPLUG request
let response = Response::ack();
self.write_response(response, resp_addr, used_idx)
}

fn handle_unplug_all_request(
&mut self,
resp_addr: GuestAddress,
used_idx: u16,
) -> Result<(), VirtioMemError> {
METRICS.unplug_all_count.inc();
let _metric = METRICS.unplug_all_agg.record_latency_metrics();

// TODO: implement UNPLUG ALL request
let response = Response::ack();
self.write_response(response, resp_addr, used_idx)
}

fn handle_state_request(
&mut self,
range: &RequestedRange,
resp_addr: GuestAddress,
used_idx: u16,
) -> Result<(), VirtioMemError> {
METRICS.state_count.inc();
let _metric = METRICS.state_agg.record_latency_metrics();

// TODO: implement STATE request
let response = Response::ack_with_state(BlockRangeState::Mixed);
self.write_response(response, resp_addr, used_idx)
}

fn process_mem_queue(&mut self) -> Result<(), VirtioMemError> {
info!("TODO: Received mem queue event, but it's not implemented.");
while let Some(desc) = self.queues[MEM_QUEUE].pop()? {
let index = desc.index;

let (req, resp_addr, used_idx) = self.parse_request(&desc)?;
debug!("virtio-mem: Request: {:?}", req);
// Handle request and write response
match req {
Request::State(ref range) => self.handle_state_request(range, resp_addr, used_idx),
Request::Plug(ref range) => self.handle_plug_request(range, resp_addr, used_idx),
Request::Unplug(ref range) => {
self.handle_unplug_request(range, resp_addr, used_idx)
}
Request::UnplugAll => self.handle_unplug_all_request(resp_addr, used_idx),
Request::Unsupported(t) => Err(VirtioMemError::UnknownRequestType(t)),
}?;
}

self.queues[MEM_QUEUE].advance_used_ring_idx();
self.signal_used_queue()?;

Ok(())
}

Expand Down Expand Up @@ -237,11 +389,9 @@ impl VirtioMem {
"virtio-mem: Updated requested size to {} bytes",
requested_size
);
// TODO(virtio-mem): trigger interrupt once we add handling for the requests
// self.interrupt_trigger()
// .trigger(VirtioInterruptType::Config)
// .map_err(VirtioMemError::InterruptError)
Ok(())
self.interrupt_trigger()
.trigger(VirtioInterruptType::Config)
.map_err(VirtioMemError::InterruptError)
}
}

Expand Down
24 changes: 24 additions & 0 deletions src/vmm/src/devices/virtio/mem/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,22 @@ pub(super) struct VirtioMemDeviceMetrics {
pub queue_event_fails: SharedIncMetric,
/// Number of queue events handled
pub queue_event_count: SharedIncMetric,
/// Latency of Plug operations
pub plug_agg: LatencyAggregateMetrics,
/// Number of Plug operations
pub plug_count: SharedIncMetric,
/// Latency of Unplug operations
pub unplug_agg: LatencyAggregateMetrics,
/// Number of Unplug operations
pub unplug_count: SharedIncMetric,
/// Latency of UnplugAll operations
pub unplug_all_agg: LatencyAggregateMetrics,
/// Number of UnplugAll operations
pub unplug_all_count: SharedIncMetric,
/// Latency of State operations
pub state_agg: LatencyAggregateMetrics,
/// Number of State operations
pub state_count: SharedIncMetric,
}

impl VirtioMemDeviceMetrics {
Expand All @@ -54,6 +70,14 @@ impl VirtioMemDeviceMetrics {
activate_fails: SharedIncMetric::new(),
queue_event_fails: SharedIncMetric::new(),
queue_event_count: SharedIncMetric::new(),
plug_agg: LatencyAggregateMetrics::new(),
plug_count: SharedIncMetric::new(),
unplug_agg: LatencyAggregateMetrics::new(),
unplug_count: SharedIncMetric::new(),
unplug_all_agg: LatencyAggregateMetrics::new(),
unplug_all_count: SharedIncMetric::new(),
state_agg: LatencyAggregateMetrics::new(),
state_count: SharedIncMetric::new(),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/vmm/src/devices/virtio/mem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod device;
mod event_handler;
pub mod metrics;
pub mod persist;
mod request;

use vm_memory::GuestAddress;

Expand Down
142 changes: 142 additions & 0 deletions src/vmm/src/devices/virtio/mem/request.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use vm_memory::{ByteValued, GuestAddress};

use crate::devices::virtio::generated::virtio_mem;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RequestedRange {
pub(crate) addr: GuestAddress,
pub(crate) nb_blocks: usize,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum Request {
Plug(RequestedRange),
Unplug(RequestedRange),
UnplugAll,
State(RequestedRange),
Unsupported(u32),
}

// SAFETY: this is safe, trust me bro
unsafe impl ByteValued for virtio_mem::virtio_mem_req {}

impl From<virtio_mem::virtio_mem_req> for Request {
fn from(req: virtio_mem::virtio_mem_req) -> Self {
match req.type_.into() {
// SAFETY: union type is checked in the match
virtio_mem::VIRTIO_MEM_REQ_PLUG => unsafe {
Request::Plug(RequestedRange {
addr: GuestAddress(req.u.plug.addr),
nb_blocks: req.u.plug.nb_blocks.into(),
})
},
// SAFETY: union type is checked in the match
virtio_mem::VIRTIO_MEM_REQ_UNPLUG => unsafe {
Request::Unplug(RequestedRange {
addr: GuestAddress(req.u.unplug.addr),
nb_blocks: req.u.unplug.nb_blocks.into(),
})
},
virtio_mem::VIRTIO_MEM_REQ_UNPLUG_ALL => Request::UnplugAll,
// SAFETY: union type is checked in the match
virtio_mem::VIRTIO_MEM_REQ_STATE => unsafe {
Request::State(RequestedRange {
addr: GuestAddress(req.u.state.addr),
nb_blocks: req.u.state.nb_blocks.into(),
})
},
t => Request::Unsupported(t),
}
}
}

#[derive(Debug, Clone, Copy)]
pub enum ResponseType {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really needed? This enum just repeats what VIRTIO_MEM_RESP_ consts are. If you use constants directly, there is no need for this enum or any conversion code.
Or at least you can use repr(u16) to remove to u16 conversion and totally skip VIRTIO_MEM_RESP_ values since thes enum will have same values.

Ack,
Nack,
Busy,
Error,
}

impl From<ResponseType> for u16 {
fn from(code: ResponseType) -> Self {
match code {
ResponseType::Ack => virtio_mem::VIRTIO_MEM_RESP_ACK,
ResponseType::Nack => virtio_mem::VIRTIO_MEM_RESP_NACK,
ResponseType::Busy => virtio_mem::VIRTIO_MEM_RESP_BUSY,
ResponseType::Error => virtio_mem::VIRTIO_MEM_RESP_ERROR,
}
.try_into()
.unwrap()
}
}

#[derive(Debug, Clone, Copy)]
pub enum BlockRangeState {
Plugged,
Unplugged,
Mixed,
}

impl From<BlockRangeState> for virtio_mem::virtio_mem_resp_state {
fn from(code: BlockRangeState) -> Self {
virtio_mem::virtio_mem_resp_state {
state: match code {
BlockRangeState::Plugged => virtio_mem::VIRTIO_MEM_STATE_PLUGGED,
BlockRangeState::Unplugged => virtio_mem::VIRTIO_MEM_STATE_UNPLUGGED,
BlockRangeState::Mixed => virtio_mem::VIRTIO_MEM_STATE_MIXED,
}
.try_into()
.unwrap(),
}
}
}

#[derive(Debug, Clone)]
pub struct Response {
pub resp_type: ResponseType,
// Only for State requests
pub state: Option<BlockRangeState>,
}

impl Response {
pub(crate) fn error() -> Self {
Response {
resp_type: ResponseType::Error,
state: None,
}
}

pub(crate) fn ack() -> Self {
Response {
resp_type: ResponseType::Ack,
state: None,
}
}

pub(crate) fn ack_with_state(state: BlockRangeState) -> Self {
Response {
resp_type: ResponseType::Ack,
state: Some(state),
}
}
}

// SAFETY: Plain data structures
unsafe impl ByteValued for virtio_mem::virtio_mem_resp {}

impl From<Response> for virtio_mem::virtio_mem_resp {
fn from(resp: Response) -> Self {
let mut out = virtio_mem::virtio_mem_resp {
type_: resp.resp_type.into(),
..Default::default()
};
if let Some(state) = resp.state {
out.u.state = state.into();
}
out
}
}