@@ -46,12 +46,12 @@ use super::super::defs::uapi;
4646use super :: super :: { VsockBackend , VsockChannel , VsockEpollListener , VsockError } ;
4747use super :: muxer_killq:: MuxerKillQ ;
4848use super :: muxer_rxq:: MuxerRxQ ;
49- use super :: { MuxerStreamConnection , VsockUnixBackendError , defs} ;
50- use crate :: devices:: virtio:: vsock:: csm:: { VsockConnection , VsockConnectionBackend } ;
49+ use super :: { VsockUnixBackendError , defs} ;
50+ use crate :: devices:: virtio:: vsock:: csm:: VsockConnection ;
5151use crate :: devices:: virtio:: vsock:: defs:: uapi:: { VSOCK_TYPE_SEQPACKET , VSOCK_TYPE_STREAM } ;
5252use crate :: devices:: virtio:: vsock:: metrics:: METRICS ;
5353use crate :: devices:: virtio:: vsock:: packet:: { VsockPacketRx , VsockPacketTx } ;
54- use crate :: devices:: virtio:: vsock:: unix:: MuxerConn ;
54+ use crate :: devices:: virtio:: vsock:: unix:: ConnBackend ;
5555use crate :: devices:: virtio:: vsock:: unix:: seqpacket:: { SeqpacketConn , SeqpacketListener , Socket } ;
5656use crate :: logger:: IncMetric ;
5757use crate :: vmm_config:: vsock:: VsockType ;
@@ -84,7 +84,7 @@ enum EpollListener {
8484 HostSock ,
8585 /// A listener interested in reading host `connect <port>` commands from a freshly
8686 /// connected host socket.
87- LocalStream ( RawFd ) ,
87+ LocalStream ( ConnBackend ) ,
8888}
8989
9090/// The vsock connection multiplexer.
@@ -93,7 +93,7 @@ pub struct VsockMuxer {
9393 /// Guest CID.
9494 cid : u64 ,
9595 /// A hash map used to store the active connections.
96- conn_map : HashMap < ConnMapKey , MuxerConn > ,
96+ conn_map : HashMap < ConnMapKey , VsockConnection < ConnBackend > > ,
9797 /// the underlying host socket file descriptor type wrapper
9898 host_sock : Box < dyn Socket > ,
9999 /// A hash map used to store epoll event listeners / handlers.
@@ -411,10 +411,7 @@ impl VsockMuxer {
411411 // the guest side, we need to know the destination port. We'll read
412412 // that port from a "connect" command received on this socket, so the
413413 // next step is to ask to be notified the moment we can read from it.
414- self . add_listener (
415- stream. as_raw_fd ( ) ,
416- EpollListener :: LocalStream ( stream. as_raw_fd ( ) ) ,
417- )
414+ self . add_listener ( stream. as_raw_fd ( ) , EpollListener :: LocalStream ( stream) )
418415 } )
419416 . unwrap_or_else ( |err| {
420417 warn ! ( "vsock: unable to accept local connection: {:?}" , err) ;
@@ -424,62 +421,28 @@ impl VsockMuxer {
424421 // Data is ready to be read from a host-initiated connection. That would be the
425422 // "connect" command that we're expecting.
426423 Some ( EpollListener :: LocalStream ( _) ) => {
427- if let Some ( EpollListener :: LocalStream ( fd) ) = self . remove_listener ( fd) {
428- match self . vsock_type {
429- VsockType :: Stream => {
430- // SAFETY: Safe because the fd is valid and we own it (removed from listener_map).
431- let mut stream = unsafe { UnixStream :: from_raw_fd ( fd) } ;
432- Self :: read_local_stream_port ( & mut stream)
433- . map ( |peer_port| ( self . allocate_local_port ( ) , peer_port) )
434- . and_then ( |( local_port, peer_port) | {
435- self . add_connection (
436- ConnMapKey {
437- local_port,
438- peer_port,
439- } ,
440- MuxerConn :: Stream (
441- VsockConnection :: < UnixStream > :: new_local_init (
442- stream,
443- uapi:: VSOCK_HOST_CID ,
444- self . cid ,
445- local_port,
446- peer_port,
447- VsockType :: Stream ,
448- ) ,
449- ) ,
450- )
451- } )
452- . unwrap_or_else ( |err| {
453- info ! ( "vsock: error adding local-init connection: {:?}" , err) ;
454- } )
455- }
456- VsockType :: Seqpacket => {
457- let mut stream = SeqpacketConn :: new ( fd) ;
458- Self :: read_local_stream_port ( & mut stream)
459- . map ( |peer_port| ( self . allocate_local_port ( ) , peer_port) )
460- . and_then ( |( local_port, peer_port) | {
461- self . add_connection (
462- ConnMapKey {
463- local_port,
464- peer_port,
465- } ,
466- MuxerConn :: Seqpacket (
467- VsockConnection :: < SeqpacketConn > :: new_local_init (
468- stream,
469- uapi:: VSOCK_HOST_CID ,
470- self . cid ,
471- local_port,
472- peer_port,
473- VsockType :: Seqpacket ,
474- ) ,
475- ) ,
476- )
477- } )
478- . unwrap_or_else ( |err| {
479- info ! ( "vsock: error adding local-init connection: {:?}" , err) ;
480- } )
481- }
482- } ;
424+ if let Some ( EpollListener :: LocalStream ( mut stream) ) = self . remove_listener ( fd) {
425+ Self :: read_local_stream_port ( & mut stream)
426+ . map ( |peer_port| ( self . allocate_local_port ( ) , peer_port) )
427+ . and_then ( |( local_port, peer_port) | {
428+ self . add_connection (
429+ ConnMapKey {
430+ local_port,
431+ peer_port,
432+ } ,
433+ VsockConnection :: new_local_init (
434+ stream,
435+ uapi:: VSOCK_HOST_CID ,
436+ self . cid ,
437+ local_port,
438+ peer_port,
439+ self . vsock_type . clone ( ) ,
440+ ) ,
441+ )
442+ } )
443+ . unwrap_or_else ( |err| {
444+ info ! ( "vsock: error adding local-init connection: {:?}" , err) ;
445+ } ) ;
483446 }
484447 }
485448
@@ -547,7 +510,7 @@ impl VsockMuxer {
547510 fn add_connection (
548511 & mut self ,
549512 key : ConnMapKey ,
550- conn : MuxerConn ,
513+ conn : VsockConnection < ConnBackend > ,
551514 ) -> Result < ( ) , VsockUnixBackendError > {
552515 // We might need to make room for this new connection, so let's sweep the kill queue
553516 // first. It's fine to do this here because:
@@ -695,15 +658,15 @@ impl VsockMuxer {
695658 local_port : pkt. hdr . dst_port ( ) ,
696659 peer_port : pkt. hdr . src_port ( ) ,
697660 } ,
698- MuxerConn :: Stream ( VsockConnection :: < UnixStream > :: new_peer_init (
699- stream,
661+ VsockConnection :: < ConnBackend > :: new_peer_init (
662+ ConnBackend :: Stream ( stream) ,
700663 uapi:: VSOCK_HOST_CID ,
701664 self . cid ,
702665 pkt. hdr . dst_port ( ) ,
703666 pkt. hdr . src_port ( ) ,
704667 pkt. hdr . buf_alloc ( ) ,
705668 VsockType :: Stream ,
706- ) ) ,
669+ ) ,
707670 )
708671 } )
709672 . unwrap_or_else ( |_| self . enq_rst ( pkt. hdr . dst_port ( ) , pkt. hdr . src_port ( ) ) ) ;
@@ -718,15 +681,19 @@ impl VsockMuxer {
718681 local_port : pkt. hdr . dst_port ( ) ,
719682 peer_port : pkt. hdr . src_port ( ) ,
720683 } ,
721- MuxerConn :: Seqpacket ( VsockConnection :: < SeqpacketConn > :: new_peer_init (
722- SeqpacketConn :: new ( stream. into_raw_fd ( ) ) ,
684+ VsockConnection :: < ConnBackend > :: new_peer_init (
685+ // SAFETY: There's no way this file descriptor is invalid or closed
686+ // because we only created it in the above line
687+ ConnBackend :: Seqpacket ( SeqpacketConn :: new ( unsafe {
688+ OwnedFd :: from_raw_fd ( stream. into_raw_fd ( ) )
689+ } ) ) ,
723690 uapi:: VSOCK_HOST_CID ,
724691 self . cid ,
725692 pkt. hdr . dst_port ( ) ,
726693 pkt. hdr . src_port ( ) ,
727694 pkt. hdr . buf_alloc ( ) ,
728695 VsockType :: Seqpacket ,
729- ) ) ,
696+ ) ,
730697 )
731698 } )
732699 . unwrap_or_else ( |_| self . enq_rst ( pkt. hdr . dst_port ( ) , pkt. hdr . src_port ( ) ) ) ;
@@ -743,7 +710,7 @@ impl VsockMuxer {
743710 /// - kill the connection if an unrecoverable error occurs.
744711 fn apply_conn_mutation < F > ( & mut self , key : ConnMapKey , mut_fn : F )
745712 where
746- F : FnOnce ( & mut MuxerConn ) ,
713+ F : FnOnce ( & mut VsockConnection < ConnBackend > ) ,
747714 {
748715 if let Some ( conn) = self . conn_map . get_mut ( & key) {
749716 let had_rx = conn. has_pending_rx ( ) ;
0 commit comments