Skip to content

Commit 9cb7f64

Browse files
committed
use an enum for the backend instead of a trait with a generic
to simplify the lifetime of file descriptors, leaving their closing to the unix stream and the owned fd for seqpacket Co-authored-by: aerosouund <aerosound161@gmail.com> Signed-off-by: aerosouund <aerosound161@gmail.com>
1 parent 35a8af4 commit 9cb7f64

File tree

8 files changed

+104
-158
lines changed

8 files changed

+104
-158
lines changed

CHANGELOG.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ and this project adheres to
1212

1313
### Changed
1414

15-
- [#5595](https://github.com/firecracker-microvm/firecracker/pull/5595): Added `vsock_type`
16-
field to the vsock device API to denote the type of the underlying socket. Can be `stream`
17-
or `seqpacket`
15+
- [#5595](https://github.com/firecracker-microvm/firecracker/pull/5595): Added
16+
`vsock_type` field to the vsock device API to denote the type of the
17+
underlying socket. Can be `stream` or `seqpacket`
1818

1919
### Deprecated
2020

src/vmm/src/devices/virtio/vsock/csm/connection.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ pub struct VsockConnection<S: VsockConnectionBackend> {
140140
/// Instant when this connection should be scheduled for immediate termination, due to some
141141
/// timeout condition having been fulfilled.
142142
expiry: Option<Instant>,
143-
/// Vsock type (stream or seqpacket)
143+
/// The type of the underlying socket connection
144144
vsock_type: VsockType,
145145
}
146146

src/vmm/src/devices/virtio/vsock/unix/mod.rs

Lines changed: 33 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ mod muxer;
1111
mod muxer_killq;
1212
mod muxer_rxq;
1313
mod seqpacket;
14-
use std::os::fd::AsRawFd as _;
14+
use std::io::{self, Read, Write};
15+
use std::os::fd::AsRawFd;
1516
use std::os::unix::net::UnixStream;
1617
use std::time::Instant;
1718

@@ -57,85 +58,58 @@ pub enum VsockUnixBackendError {
5758
TooManyConnections,
5859
}
5960

60-
type MuxerStreamConnection = super::csm::VsockConnection<UnixStream>;
61-
type MuxerSeqpacketConnetion = super::csm::VsockConnection<SeqpacketConn>;
62-
6361
#[derive(Debug)]
64-
enum MuxerConn {
65-
Stream(MuxerStreamConnection),
66-
Seqpacket(MuxerSeqpacketConnetion),
62+
pub enum ConnBackend {
63+
Stream(UnixStream),
64+
Seqpacket(SeqpacketConn),
6765
}
68-
66+
// can we make vsockconnection instead of being generic, hold an enum ?
6967
macro_rules! forward_to_inner {
7068
($self:ident, $method:ident $(, $args:expr )* ) => {
7169
match $self {
72-
MuxerConn::Stream(inner) => inner.$method($($args),*),
73-
MuxerConn::Seqpacket(inner) => inner.$method($($args),*),
70+
ConnBackend::Stream(inner) => inner.$method($($args),*),
71+
ConnBackend::Seqpacket(inner) => inner.$method($($args),*),
7472
}
7573
};
7674
}
7775

78-
impl MuxerConn {
79-
fn has_pending_rx(&self) -> bool {
80-
forward_to_inner!(self, has_pending_rx)
76+
impl Read for ConnBackend {
77+
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
78+
forward_to_inner!(self, read, buf)
8179
}
80+
}
8281

82+
impl AsRawFd for ConnBackend {
8383
fn as_raw_fd(&self) -> i32 {
8484
forward_to_inner!(self, as_raw_fd)
8585
}
86+
}
8687

87-
fn kill(&mut self) {
88-
forward_to_inner!(self, kill)
89-
}
90-
91-
fn get_polled_evset(&self) -> EventSet {
92-
forward_to_inner!(self, get_polled_evset)
93-
}
94-
95-
fn will_expire(&self) -> bool {
96-
forward_to_inner!(self, will_expire)
97-
}
98-
99-
fn has_expired(&self) -> bool {
100-
forward_to_inner!(self, has_expired)
101-
}
102-
103-
fn send_bytes_raw(&mut self, buf: &[u8]) -> Result<usize, VsockCsmError> {
104-
forward_to_inner!(self, send_bytes_raw, buf)
105-
}
106-
107-
fn state(&self) -> ConnState {
108-
forward_to_inner!(self, state)
109-
}
110-
111-
fn expiry(&self) -> Option<Instant> {
112-
forward_to_inner!(self, expiry)
113-
}
114-
115-
fn recv_pkt(&mut self, pkt: &mut VsockPacketRx) -> Result<(), VsockError> {
116-
forward_to_inner!(self, recv_pkt, pkt)
117-
}
118-
119-
fn send_pkt(&mut self, pkt: &VsockPacketTx) -> Result<(), VsockError> {
120-
forward_to_inner!(self, send_pkt, pkt)
88+
impl ReadVolatile for ConnBackend {
89+
fn read_volatile<B: vm_memory::bitmap::BitmapSlice>(
90+
&mut self,
91+
buf: &mut vm_memory::VolatileSlice<B>,
92+
) -> Result<usize, vm_memory::VolatileMemoryError> {
93+
forward_to_inner!(self, read_volatile, buf)
12194
}
95+
}
12296

123-
fn notify(&mut self, evset: EventSet) {
124-
forward_to_inner!(self, notify, evset)
97+
impl WriteVolatile for ConnBackend {
98+
fn write_volatile<B: vm_memory::bitmap::BitmapSlice>(
99+
&mut self,
100+
buf: &vm_memory::VolatileSlice<B>,
101+
) -> Result<usize, vm_memory::VolatileMemoryError> {
102+
forward_to_inner!(self, write_volatile, buf)
125103
}
126104
}
127105

128-
#[cfg(test)]
129-
impl MuxerConn {
130-
pub(crate) fn fwd_cnt(&self) -> std::num::Wrapping<u32> {
131-
forward_to_inner!(self, fwd_cnt)
106+
impl Write for ConnBackend {
107+
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
108+
forward_to_inner!(self, write, buf)
132109
}
133-
134-
pub(crate) fn insert_credit_update(&mut self) {
135-
forward_to_inner!(self, insert_credit_update)
110+
fn flush(&mut self) -> io::Result<()> {
111+
Ok(())
136112
}
137113
}
138114

139-
impl VsockConnectionBackend for UnixStream {}
140-
141-
impl VsockConnectionBackend for SeqpacketConn {}
115+
impl VsockConnectionBackend for ConnBackend {}

src/vmm/src/devices/virtio/vsock/unix/muxer.rs

Lines changed: 40 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ use super::super::defs::uapi;
4646
use super::super::{VsockBackend, VsockChannel, VsockEpollListener, VsockError};
4747
use super::muxer_killq::MuxerKillQ;
4848
use super::muxer_rxq::MuxerRxQ;
49-
use super::{MuxerStreamConnection, VsockUnixBackendError, defs};
50-
use crate::devices::virtio::vsock::csm::{VsockConnection, VsockConnectionBackend};
49+
use super::{VsockUnixBackendError, defs};
50+
use crate::devices::virtio::vsock::csm::VsockConnection;
5151
use crate::devices::virtio::vsock::defs::uapi::{VSOCK_TYPE_SEQPACKET, VSOCK_TYPE_STREAM};
5252
use crate::devices::virtio::vsock::metrics::METRICS;
5353
use crate::devices::virtio::vsock::packet::{VsockPacketRx, VsockPacketTx};
54-
use crate::devices::virtio::vsock::unix::MuxerConn;
54+
use crate::devices::virtio::vsock::unix::ConnBackend;
5555
use crate::devices::virtio::vsock::unix::seqpacket::{SeqpacketConn, SeqpacketListener, Socket};
5656
use crate::logger::IncMetric;
5757
use crate::vmm_config::vsock::VsockType;
@@ -84,7 +84,7 @@ enum EpollListener {
8484
HostSock,
8585
/// A listener interested in reading host `connect <port>` commands from a freshly
8686
/// connected host socket.
87-
LocalStream(RawFd),
87+
LocalStream(ConnBackend),
8888
}
8989

9090
/// The vsock connection multiplexer.
@@ -93,7 +93,7 @@ pub struct VsockMuxer {
9393
/// Guest CID.
9494
cid: u64,
9595
/// A hash map used to store the active connections.
96-
conn_map: HashMap<ConnMapKey, MuxerConn>,
96+
conn_map: HashMap<ConnMapKey, VsockConnection<ConnBackend>>,
9797
/// the underlying host socket file descriptor type wrapper
9898
host_sock: Box<dyn Socket>,
9999
/// A hash map used to store epoll event listeners / handlers.
@@ -411,10 +411,7 @@ impl VsockMuxer {
411411
// the guest side, we need to know the destination port. We'll read
412412
// that port from a "connect" command received on this socket, so the
413413
// next step is to ask to be notified the moment we can read from it.
414-
self.add_listener(
415-
stream.as_raw_fd(),
416-
EpollListener::LocalStream(stream.as_raw_fd()),
417-
)
414+
self.add_listener(stream.as_raw_fd(), EpollListener::LocalStream(stream))
418415
})
419416
.unwrap_or_else(|err| {
420417
warn!("vsock: unable to accept local connection: {:?}", err);
@@ -424,62 +421,28 @@ impl VsockMuxer {
424421
// Data is ready to be read from a host-initiated connection. That would be the
425422
// "connect" command that we're expecting.
426423
Some(EpollListener::LocalStream(_)) => {
427-
if let Some(EpollListener::LocalStream(fd)) = self.remove_listener(fd) {
428-
match self.vsock_type {
429-
VsockType::Stream => {
430-
// SAFETY: Safe because the fd is valid and we own it (removed from listener_map).
431-
let mut stream = unsafe { UnixStream::from_raw_fd(fd) };
432-
Self::read_local_stream_port(&mut stream)
433-
.map(|peer_port| (self.allocate_local_port(), peer_port))
434-
.and_then(|(local_port, peer_port)| {
435-
self.add_connection(
436-
ConnMapKey {
437-
local_port,
438-
peer_port,
439-
},
440-
MuxerConn::Stream(
441-
VsockConnection::<UnixStream>::new_local_init(
442-
stream,
443-
uapi::VSOCK_HOST_CID,
444-
self.cid,
445-
local_port,
446-
peer_port,
447-
VsockType::Stream,
448-
),
449-
),
450-
)
451-
})
452-
.unwrap_or_else(|err| {
453-
info!("vsock: error adding local-init connection: {:?}", err);
454-
})
455-
}
456-
VsockType::Seqpacket => {
457-
let mut stream = SeqpacketConn::new(fd);
458-
Self::read_local_stream_port(&mut stream)
459-
.map(|peer_port| (self.allocate_local_port(), peer_port))
460-
.and_then(|(local_port, peer_port)| {
461-
self.add_connection(
462-
ConnMapKey {
463-
local_port,
464-
peer_port,
465-
},
466-
MuxerConn::Seqpacket(
467-
VsockConnection::<SeqpacketConn>::new_local_init(
468-
stream,
469-
uapi::VSOCK_HOST_CID,
470-
self.cid,
471-
local_port,
472-
peer_port,
473-
VsockType::Seqpacket,
474-
),
475-
),
476-
)
477-
})
478-
.unwrap_or_else(|err| {
479-
info!("vsock: error adding local-init connection: {:?}", err);
480-
})
481-
}
482-
};
424+
if let Some(EpollListener::LocalStream(mut stream)) = self.remove_listener(fd) {
425+
Self::read_local_stream_port(&mut stream)
426+
.map(|peer_port| (self.allocate_local_port(), peer_port))
427+
.and_then(|(local_port, peer_port)| {
428+
self.add_connection(
429+
ConnMapKey {
430+
local_port,
431+
peer_port,
432+
},
433+
VsockConnection::new_local_init(
434+
stream,
435+
uapi::VSOCK_HOST_CID,
436+
self.cid,
437+
local_port,
438+
peer_port,
439+
self.vsock_type.clone(),
440+
),
441+
)
442+
})
443+
.unwrap_or_else(|err| {
444+
info!("vsock: error adding local-init connection: {:?}", err);
445+
});
483446
}
484447
}
485448

@@ -547,7 +510,7 @@ impl VsockMuxer {
547510
fn add_connection(
548511
&mut self,
549512
key: ConnMapKey,
550-
conn: MuxerConn,
513+
conn: VsockConnection<ConnBackend>,
551514
) -> Result<(), VsockUnixBackendError> {
552515
// We might need to make room for this new connection, so let's sweep the kill queue
553516
// first. It's fine to do this here because:
@@ -695,15 +658,15 @@ impl VsockMuxer {
695658
local_port: pkt.hdr.dst_port(),
696659
peer_port: pkt.hdr.src_port(),
697660
},
698-
MuxerConn::Stream(VsockConnection::<UnixStream>::new_peer_init(
699-
stream,
661+
VsockConnection::<ConnBackend>::new_peer_init(
662+
ConnBackend::Stream(stream),
700663
uapi::VSOCK_HOST_CID,
701664
self.cid,
702665
pkt.hdr.dst_port(),
703666
pkt.hdr.src_port(),
704667
pkt.hdr.buf_alloc(),
705668
VsockType::Stream,
706-
)),
669+
),
707670
)
708671
})
709672
.unwrap_or_else(|_| self.enq_rst(pkt.hdr.dst_port(), pkt.hdr.src_port()));
@@ -718,15 +681,19 @@ impl VsockMuxer {
718681
local_port: pkt.hdr.dst_port(),
719682
peer_port: pkt.hdr.src_port(),
720683
},
721-
MuxerConn::Seqpacket(VsockConnection::<SeqpacketConn>::new_peer_init(
722-
SeqpacketConn::new(stream.into_raw_fd()),
684+
VsockConnection::<ConnBackend>::new_peer_init(
685+
// SAFETY: There's no way this file descriptor is invalid or closed
686+
// because we only created it in the above line
687+
ConnBackend::Seqpacket(SeqpacketConn::new(unsafe {
688+
OwnedFd::from_raw_fd(stream.into_raw_fd())
689+
})),
723690
uapi::VSOCK_HOST_CID,
724691
self.cid,
725692
pkt.hdr.dst_port(),
726693
pkt.hdr.src_port(),
727694
pkt.hdr.buf_alloc(),
728695
VsockType::Seqpacket,
729-
)),
696+
),
730697
)
731698
})
732699
.unwrap_or_else(|_| self.enq_rst(pkt.hdr.dst_port(), pkt.hdr.src_port()));
@@ -743,7 +710,7 @@ impl VsockMuxer {
743710
/// - kill the connection if an unrecoverable error occurs.
744711
fn apply_conn_mutation<F>(&mut self, key: ConnMapKey, mut_fn: F)
745712
where
746-
F: FnOnce(&mut MuxerConn),
713+
F: FnOnce(&mut VsockConnection<ConnBackend>),
747714
{
748715
if let Some(conn) = self.conn_map.get_mut(&key) {
749716
let had_rx = conn.has_pending_rx();

src/vmm/src/devices/virtio/vsock/unix/muxer_killq.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@
2727
use std::collections::{HashMap, VecDeque};
2828
use std::time::Instant;
2929

30+
use super::defs;
3031
use super::muxer::ConnMapKey;
31-
use super::{MuxerStreamConnection, defs};
32-
use crate::devices::virtio::vsock::csm::{VsockConnection, VsockConnectionBackend};
33-
use crate::devices::virtio::vsock::unix::MuxerConn;
32+
use crate::devices::virtio::vsock::csm::VsockConnection;
33+
use crate::devices::virtio::vsock::unix::ConnBackend;
3434

3535
/// A kill queue item, holding the connection key and the scheduled time for termination.
3636
#[derive(Debug, Clone, Copy)]
@@ -68,7 +68,7 @@ impl MuxerKillQ {
6868
/// set to expire at some point in the future.
6969
/// Note: if more than `Self::SIZE` connections are found, the queue will be created in an
7070
/// out-of-sync state, and will be discarded after it is emptied.
71-
pub fn from_conn_map(conn_map: &HashMap<ConnMapKey, MuxerConn>) -> Self {
71+
pub fn from_conn_map(conn_map: &HashMap<ConnMapKey, VsockConnection<ConnBackend>>) -> Self {
7272
let mut q_buf: Vec<MuxerKillQItem> = Vec::with_capacity(Self::SIZE);
7373
let mut synced = true;
7474
for (key, conn) in conn_map.iter() {

src/vmm/src/devices/virtio/vsock/unix/muxer_rxq.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ use std::collections::{HashMap, VecDeque};
2020
use super::super::VsockChannel;
2121
use super::defs;
2222
use super::muxer::{ConnMapKey, MuxerRx};
23-
use crate::devices::virtio::vsock::csm::{VsockConnection, VsockConnectionBackend};
24-
use crate::devices::virtio::vsock::unix::MuxerConn;
23+
use crate::devices::virtio::vsock::csm::VsockConnection;
24+
use crate::devices::virtio::vsock::unix::ConnBackend;
2525

2626
/// The muxer RX queue.
2727
#[derive(Debug)]
@@ -47,7 +47,7 @@ impl MuxerRxQ {
4747
/// Note: the resulting queue may still be desynchronized, if there are too many connections
4848
/// that have pending RX data. In that case, the muxer will first drain this queue, and
4949
/// then try again to build a synchronized one.
50-
pub fn from_conn_map(conn_map: &HashMap<ConnMapKey, MuxerConn>) -> Self {
50+
pub fn from_conn_map(conn_map: &HashMap<ConnMapKey, VsockConnection<ConnBackend>>) -> Self {
5151
let mut q = VecDeque::new();
5252
let mut synced = true;
5353

0 commit comments

Comments
 (0)