Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ axdriver_display = { path = "axdriver_display", version = "0.1" }
axdriver_net = { path = "axdriver_net", version = "0.1" }
axdriver_pci = { path = "axdriver_pci", version = "0.1" }
axdriver_virtio = { path = "axdriver_virtio", version = "0.1" }
axdriver_vsock = { path = "axdriver_vsock", version = "0.1" }
log = "0.4"
2 changes: 2 additions & 0 deletions axdriver_base/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ pub enum DeviceType {
Net,
/// Graphic display device (e.g., GPU)
Display,
/// Vsock device (e.g., virtio-vsock).
Vsock,
}

/// The error type for device operation failures.
Expand Down
2 changes: 2 additions & 0 deletions axdriver_virtio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ alloc = ["virtio-drivers/alloc"]
block = ["axdriver_block"]
gpu = ["alloc", "axdriver_display"]
net = ["alloc", "axdriver_net"]
socket = ["alloc", "axdriver_vsock"]

[dependencies]
axdriver_base = { workspace = true }
axdriver_block = { workspace = true, optional = true }
axdriver_display = { workspace = true, optional = true }
axdriver_net = { workspace = true, optional = true }
axdriver_vsock = { workspace = true, optional = true }
log = { workspace = true }
virtio-drivers = { version = "0.7.4", default-features = false }
19 changes: 17 additions & 2 deletions axdriver_virtio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ pub use self::blk::VirtIoBlkDev;
pub use self::gpu::VirtIoGpuDev;
#[cfg(feature = "net")]
pub use self::net::VirtIoNetDev;
#[cfg(feature = "socket")]
mod socket;
use self::pci::{DeviceFunction, DeviceFunctionInfo, PciRoot};
#[cfg(feature = "socket")]
pub use self::socket::VirtIoSocketDev;

/// Try to probe a VirtIO MMIO device from the given memory region.
///
Expand Down Expand Up @@ -84,13 +88,14 @@ const fn as_dev_type(t: VirtIoDevType) -> Option<DeviceType> {
Block => Some(DeviceType::Block),
Network => Some(DeviceType::Net),
GPU => Some(DeviceType::Display),
Socket => Some(DeviceType::Vsock),
_ => None,
}
}

#[allow(dead_code)]
const fn as_dev_err(e: virtio_drivers::Error) -> DevError {
use virtio_drivers::Error::*;
use virtio_drivers::{Error::*, device::socket::SocketError::*};
Copy link

Copilot AI Dec 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The import of SocketError::* may cause namespace pollution since it brings all socket error variants into scope. This could lead to naming conflicts if other error types have similar variant names. Consider using qualified paths like device::socket::SocketError::ConnectionExists in the match arms instead, or limiting the wildcard import scope.

Copilot uses AI. Check for mistakes.
match e {
QueueFull => DevError::BadState,
NotReady => DevError::Again,
Expand All @@ -102,6 +107,16 @@ const fn as_dev_err(e: virtio_drivers::Error) -> DevError {
Unsupported => DevError::Unsupported,
ConfigSpaceTooSmall => DevError::BadState,
ConfigSpaceMissing => DevError::BadState,
_ => DevError::BadState,
SocketDeviceError(e) => match e {
ConnectionExists => DevError::AlreadyExists,
NotConnected => DevError::BadState,
InvalidOperation | InvalidNumber | UnknownOperation(_) => DevError::InvalidParam,
OutputBufferTooShort(_) | BufferTooShort | BufferTooLong(..) => DevError::InvalidParam,
UnexpectedDataInPacket | PeerSocketShutdown | NoResponseReceived | ConnectionFailed => {
DevError::Io
}
InsufficientBufferSpaceInPeer => DevError::Again,
RecycledWrongBuffer => DevError::BadState,
},
}
}
144 changes: 144 additions & 0 deletions axdriver_virtio/src/socket.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
use axdriver_base::{BaseDriverOps, DevResult, DeviceType};
use axdriver_vsock::{VsockConnId, VsockDriverEvent, VsockDriverOps};
use virtio_drivers::{
Hal,
device::socket::{
VirtIOSocket, VsockAddr, VsockConnectionManager as InnerDev, VsockEvent, VsockEventType,
},
transport::Transport,
};

use crate::as_dev_err;

/// The VirtIO socket device driver.
pub struct VirtIoSocketDev<H: Hal, T: Transport> {
inner: InnerDev<H, T>,
}

unsafe impl<H: Hal, T: Transport> Send for VirtIoSocketDev<H, T> {}
unsafe impl<H: Hal, T: Transport> Sync for VirtIoSocketDev<H, T> {}

impl<H: Hal, T: Transport> VirtIoSocketDev<H, T> {
/// Creates a new driver instance and initializes the device, or returns
/// an error if any step fails.
pub fn try_new(transport: T) -> DevResult<Self> {
let virtio_socket = VirtIOSocket::<H, _>::new(transport).map_err(as_dev_err)?;
Ok(Self {
inner: InnerDev::new(virtio_socket),
})
}
}

impl<H: Hal, T: Transport> BaseDriverOps for VirtIoSocketDev<H, T> {
fn device_name(&self) -> &str {
"virtio-socket"
}

fn device_type(&self) -> DeviceType {
DeviceType::Vsock
}
}

fn map_conn_id(cid: VsockConnId) -> (VsockAddr, u32) {
(
VsockAddr {
cid: cid.peer_addr.cid as _,
port: cid.peer_addr.port as _,
},
cid.local_port,
)
}

impl<H: Hal, T: Transport> VsockDriverOps for VirtIoSocketDev<H, T> {
fn guest_cid(&self) -> u64 {
self.inner.guest_cid()
}

fn listen(&mut self, src_port: u32) {
self.inner.listen(src_port)
}

fn connect(&mut self, cid: VsockConnId) -> DevResult<()> {
let (peer_addr, src_port) = map_conn_id(cid);
self.inner.connect(peer_addr, src_port).map_err(as_dev_err)
}

fn send(&mut self, cid: VsockConnId, buf: &[u8]) -> DevResult<usize> {
let (peer_addr, src_port) = map_conn_id(cid);
match self.inner.send(peer_addr, src_port, buf) {
Ok(()) => Ok(buf.len()),
Err(e) => Err(as_dev_err(e)),
}
}

fn recv(&mut self, cid: VsockConnId, buf: &mut [u8]) -> DevResult<usize> {
let (peer_addr, src_port) = map_conn_id(cid);
self.inner
.recv(peer_addr, src_port, buf)
.map_err(as_dev_err)
}

fn recv_avail(&mut self, cid: VsockConnId) -> DevResult<usize> {
let (peer_addr, src_port) = map_conn_id(cid);
self.inner
.recv_buffer_available_bytes(peer_addr, src_port)
.map_err(as_dev_err)
}

fn disconnect(&mut self, cid: VsockConnId) -> DevResult<()> {
let (peer_addr, src_port) = map_conn_id(cid);
self.inner.shutdown(peer_addr, src_port).map_err(as_dev_err)
}

fn abort(&mut self, cid: VsockConnId) -> DevResult<()> {
let (peer_addr, src_port) = map_conn_id(cid);
self.inner
.force_close(peer_addr, src_port)
.map_err(as_dev_err)
}

fn poll_event(&mut self, buf: &mut [u8]) -> DevResult<Option<VsockDriverEvent>> {
match self.inner.poll() {
Ok(None) => {
// no event
Ok(None)
}
Ok(Some(event)) => {
// translate event
let result = convert_vsock_event(event, &mut self.inner, buf)?;
Ok(Some(result))
}
Err(e) => {
// error
Err(as_dev_err(e))
}
}
}
}

fn convert_vsock_event<H: Hal, T: Transport>(
event: VsockEvent,
inner: &mut InnerDev<H, T>,
buf: &mut [u8],
) -> DevResult<VsockDriverEvent> {
let cid = VsockConnId {
peer_addr: axdriver_vsock::VsockAddr {
cid: event.source.cid as _,
port: event.source.port as _,
},
local_port: event.destination.port,
};

match event.event_type {
VsockEventType::ConnectionRequest => Ok(VsockDriverEvent::ConnectionRequest(cid)),
VsockEventType::Connected => Ok(VsockDriverEvent::Connected(cid)),
VsockEventType::Received { length } => {
let read = inner
.recv(event.source, event.destination.port, &mut buf[..length])
.map_err(as_dev_err)?;
Ok(VsockDriverEvent::Received(cid, read))
}
VsockEventType::Disconnected { reason: _ } => Ok(VsockDriverEvent::Disconnected(cid)),
_ => Ok(VsockDriverEvent::Unknown),
Copy link

Copilot AI Dec 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The wildcard pattern _ in the match statement catches all unknown event types and returns VsockDriverEvent::Unknown. While this provides a fallback, it may silently ignore newly added event types in future versions of the virtio-drivers crate. Consider logging a warning when an unknown event type is encountered to aid debugging.

Copilot uses AI. Check for mistakes.
}
}
19 changes: 19 additions & 0 deletions axdriver_vsock/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[package]
name = "axdriver_vsock"
edition.workspace = true
description = "Common traits and types for vsock drivers"
documentation = "https://arceos-org.github.io/axdriver_crates/axdriver_vsock"
keywords = ["arceos", "driver", "vsock"]
version.workspace = true
authors = ["Weikang Guo <guoweikang@kylinos.cn>"]
license.workspace = true
homepage.workspace = true
repository.workspace = true
categories.workspace = true

[features]
default = []

[dependencies]
axdriver_base = { workspace = true }
log = "0.4"
84 changes: 84 additions & 0 deletions axdriver_vsock/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
//! Common traits and types for socket communite device drivers (i.e. disk).

#![no_std]
#![cfg_attr(doc, feature(doc_cfg))]

#[doc(no_inline)]
pub use axdriver_base::{BaseDriverOps, DevError, DevResult, DeviceType};

/// Vsock address.
#[derive(Copy, Clone, Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct VsockAddr {
/// Context Identifier.
pub cid: u64,
/// Port number.
pub port: u32,
}

/// Vsock connection id.
#[derive(Copy, Clone, Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct VsockConnId {
/// Peer address.
pub peer_addr: VsockAddr,
/// Local port.
pub local_port: u32,
}

impl VsockConnId {
/// Create a new [`VsockConnId`] for listening socket
pub fn listening(local_port: u32) -> Self {
Self {
peer_addr: VsockAddr { cid: 0, port: 0 },
local_port,
}
}
}

/// VsockDriverEvent
#[derive(Debug)]
pub enum VsockDriverEvent {
/// ConnectionRequest
ConnectionRequest(VsockConnId),
/// Connected
Connected(VsockConnId),
/// Received
Received(VsockConnId, usize),
/// Disconnected
Disconnected(VsockConnId),
/// unknown event
Unknown,
}

/// Operations that require a block storage device driver to implement.
pub trait VsockDriverOps: BaseDriverOps {
/// guest cid
fn guest_cid(&self) -> u64;

/// Listen on a specific port.
fn listen(&mut self, src_port: u32);

/// Connect to a peer socket.
fn connect(&mut self, cid: VsockConnId) -> DevResult<()>;

/// Send data to the connected peer socket. need addr for DGRAM mode
fn send(&mut self, cid: VsockConnId, buf: &[u8]) -> DevResult<usize>;

/// Receive data from the connected peer socket.
fn recv(&mut self, cid: VsockConnId, buf: &mut [u8]) -> DevResult<usize>;

/// Returns the number of bytes in the receive buffer available to be read
/// by recv.
fn recv_avail(&mut self, cid: VsockConnId) -> DevResult<usize>;

/// Disconnect from the connected peer socket.
///
/// Requests to shut down the connection cleanly, telling the peer that we
/// won't send or receive any more data.
fn disconnect(&mut self, cid: VsockConnId) -> DevResult<()>;

/// Forcibly closes the connection without waiting for the peer.
fn abort(&mut self, cid: VsockConnId) -> DevResult<()>;

/// poll event from driver
Copy link

Copilot AI Dec 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The buf parameter is passed to poll_event but is unused in the Ok(None) branch and the error path. The buffer is only used when converting events in the Ok(Some(event)) case. Consider documenting the purpose of this buffer parameter in the trait definition, as it's not clear why callers need to provide it.

Suggested change
/// poll event from driver
/// Poll a single event from the driver.
///
/// The caller provides `buf` as a scratch buffer that implementations may use
/// to decode or temporarily store driver-specific data associated with the
/// returned [`VsockDriverEvent`]. The buffer is typically only used when an
/// event is actually available (`Ok(Some(event))`) and may be ignored when
/// there is no event (`Ok(None)`) or when an error is returned.
///
/// Implementations must not rely on the contents of `buf` on entry, and
/// callers should ensure it is large enough for any driver-specific payload
/// they expect to handle.

Copilot uses AI. Check for mistakes.
fn poll_event(&mut self, buf: &mut [u8]) -> DevResult<Option<VsockDriverEvent>>;
}