@@ -16,6 +16,13 @@ mod platforms;
1616pub use platforms:: { Platform , NitroEnclaves , EnclaveSimulator , EnclaveSimulatorArgs } ;
1717pub use platforms:: amdsevsnp:: { AmdSevVm , RunningVm , VmRunArgs , VmSimulator } ;
1818
19+ use crate :: usercall_ext:: Listener ;
20+ use crate :: usercall_ext:: SocketStream ;
21+ use crate :: usercall_ext:: UsercallExtension ;
22+ use crate :: usercall_ext:: UsercallExtensionDefault ;
23+
24+ mod usercall_ext;
25+
1926const MAX_LOG_MESSAGE_LEN : usize = 80 ;
2027const PROXY_BUFF_SIZE : usize = 4192 ;
2128
@@ -36,9 +43,9 @@ pub trait StreamConnection: Read + Write {
3643 fn peer_port ( & self ) -> io:: Result < u32 > ;
3744}
3845
39- impl StreamConnection for TcpStream {
46+ impl < T : SocketStream > StreamConnection for T {
4047 fn protocol ( ) -> & ' static str {
41- "tcp "
48+ "socket stream "
4249 }
4350
4451 fn local ( & self ) -> io:: Result < String > {
@@ -84,20 +91,20 @@ impl StreamConnection for VsockStream {
8491 }
8592}
8693
87- #[ derive( Debug ) ]
88- struct Listener {
89- listener : TcpListener ,
90- }
94+ // #[derive(Debug)]
95+ // struct Listener {
96+ // listener: TcpListener,
97+ // }
9198
92- impl Listener {
93- fn new ( listener : TcpListener ) -> Self {
94- Listener { listener }
95- }
96- }
99+ // impl Listener {
100+ // fn new(listener: TcpListener) -> Self {
101+ // Listener{ listener }
102+ // }
103+ // }
97104
98- #[ derive( Debug ) ]
105+ // #[derive(Debug)]
99106struct Connection {
100- tcp_stream : TcpStream ,
107+ tcp_stream : Box < dyn SocketStream > ,
101108 vsock_stream : VsockStream < Std > ,
102109 remote_name : String ,
103110}
@@ -111,7 +118,7 @@ struct ConnectionInfo {
111118}
112119
113120impl Connection {
114- pub fn new ( vsock_stream : VsockStream < Std > , tcp_stream : TcpStream , remote_name : String ) -> Self {
121+ pub fn new ( vsock_stream : VsockStream < Std > , tcp_stream : Box < dyn SocketStream > , remote_name : String ) -> Self {
115122 Connection {
116123 tcp_stream,
117124 vsock_stream,
@@ -342,8 +349,9 @@ pub struct Server<P: Platform> {
342349 /// When the enclave instructs to accept a new connection, the runner accepts a new TCP
343350 /// connection. It then locates the ListenerInfo and finds the information it needs to set up a
344351 /// new vsock connection to the enclave
345- listeners : RwLock < FnvHashMap < VsockAddr , Arc < Mutex < Listener > > > > ,
352+ listeners : RwLock < FnvHashMap < VsockAddr , Arc < Mutex < Box < dyn Listener > > > > > ,
346353 connections : RwLock < FnvHashMap < ConnectionKey , ConnectionInfo > > ,
354+ usercall_ext : Box < dyn UsercallExtension > ,
347355}
348356
349357impl < P : Platform + ' static > Server < P > {
@@ -379,8 +387,13 @@ impl<P: Platform + 'static> Server<P> {
379387 * [3] proxy
380388 */
381389 fn handle_request_connect ( self : Arc < Self > , remote_addr : & String , conn : & mut ClientConnection ) -> Result < ( ) , VmeError > {
390+ let remote_stream = if let Some ( stream) = self . usercall_ext . connect_stream ( remote_addr) ? {
391+ stream
392+ } else {
393+ let remote_socket = TcpStream :: connect ( remote_addr) . map_err ( |e| VmeError :: Command ( e. kind ( ) . into ( ) ) ) ?;
394+ Box :: new ( remote_socket)
395+ } ;
382396 // Connect to remote server
383- let remote_socket = TcpStream :: connect ( remote_addr) . map_err ( |e| VmeError :: Command ( e. kind ( ) . into ( ) ) ) ?;
384397 let remote_name = remote_addr. split_terminator ( ":" ) . next ( ) . unwrap_or ( remote_addr) ;
385398
386399 // Create listening socket that the enclave can connect to
@@ -390,8 +403,8 @@ impl<P: Platform + 'static> Server<P> {
390403 // Notify the enclave on which port her proxy is listening on
391404 let response = Response :: Connected {
392405 proxy_port : proxy_server_port,
393- local : remote_socket . local_addr ( ) ?. into ( ) ,
394- peer : remote_socket . peer_addr ( ) ?. into ( ) ,
406+ local : remote_stream . local_addr ( ) ?. into ( ) ,
407+ peer : remote_stream . peer_addr ( ) ?. into ( ) ,
395408 } ;
396409
397410 conn. send ( & response) ?;
@@ -402,7 +415,7 @@ impl<P: Platform + 'static> Server<P> {
402415 let accept_connection = move || -> Result < ( ) , VmeError > {
403416 let ( proxy, _proxy_addr) = proxy_server. accept ( ) ?;
404417 // Store connection info
405- self . add_connection ( proxy, remote_socket , remote_name. to_string ( ) ) ?;
418+ self . add_connection ( proxy, remote_stream , remote_name. to_string ( ) ) ?;
406419 Ok ( ( ) )
407420 } ;
408421 if let Err ( e) = accept_connection ( ) {
@@ -411,15 +424,15 @@ impl<P: Platform + 'static> Server<P> {
411424 Ok ( ( ) )
412425 }
413426
414- fn add_listener ( & self , addr : VsockAddr , info : Listener ) {
427+ fn add_listener ( & self , addr : VsockAddr , info : Box < dyn Listener > ) {
415428 self . listeners . write ( ) . unwrap ( ) . insert ( addr, Arc :: new ( Mutex :: new ( info) ) ) ;
416429 }
417430
418- fn listener ( & self , addr : & VsockAddr ) -> Option < Arc < Mutex < Listener > > > {
431+ fn listener ( & self , addr : & VsockAddr ) -> Option < Arc < Mutex < Box < dyn Listener > > > > {
419432 self . listeners . read ( ) . unwrap ( ) . get ( & addr) . cloned ( )
420433 }
421434
422- fn remove_listener ( & self , addr : & VsockAddr ) -> Option < Arc < Mutex < Listener > > > {
435+ fn remove_listener ( & self , addr : & VsockAddr ) -> Option < Arc < Mutex < Box < dyn Listener > > > > {
423436 self . listeners . write ( ) . unwrap ( ) . remove ( & addr)
424437 }
425438
@@ -445,7 +458,7 @@ impl<P: Platform + 'static> Server<P> {
445458 self . connections . write ( ) . unwrap ( ) . remove ( & k)
446459 }
447460
448- fn add_connection ( self : Arc < Self > , runner_enclave : VsockStream , runner_remote : TcpStream , remote_name : String ) -> Result < JoinHandle < ( ) > , IoError > {
461+ fn add_connection ( self : Arc < Self > , runner_enclave : VsockStream , runner_remote : Box < dyn SocketStream > , remote_name : String ) -> Result < JoinHandle < ( ) > , IoError > {
449462 let k = ConnectionKey :: from_vsock_stream ( & runner_enclave) ?;
450463 let mut connection = Connection :: new ( runner_enclave, runner_remote, remote_name) ;
451464 self . connections . write ( ) . unwrap ( ) . insert ( k. clone ( ) , connection. info ( ) ?) ;
@@ -484,9 +497,15 @@ impl<P: Platform + 'static> Server<P> {
484497 */
485498 fn handle_request_bind ( self : Arc < Self > , addr : & String , enclave_port : u32 , conn : & mut ClientConnection ) -> Result < ( ) , VmeError > {
486499 let cid: u32 = conn. stream . peer_addr ( ) ?. cid ( ) ;
487- let listener = TcpListener :: bind ( addr) . map_err ( |e| VmeError :: Command ( e. kind ( ) . into ( ) ) ) ?;
488- let local: Addr = listener. local_addr ( ) ?. into ( ) ;
489- self . add_listener ( VsockAddr :: new ( cid, enclave_port) , Listener :: new ( listener) ) ;
500+ let ( listener, local_addr) = if let Some ( ( lis, addr) ) = self . usercall_ext . bind_stream ( addr) ? {
501+ ( lis, addr)
502+ } else {
503+ let lis = TcpListener :: bind ( addr) . map_err ( |e| VmeError :: Command ( e. kind ( ) . into ( ) ) ) ?;
504+ let addr = lis. local_addr ( ) ?;
505+ ( Box :: new ( lis) as Box < dyn Listener > , addr)
506+ } ;
507+ let local: Addr = local_addr. into ( ) ;
508+ self . add_listener ( VsockAddr :: new ( cid, enclave_port) , listener) ;
490509 conn. send ( & Response :: Bound { local } ) ?;
491510 Ok ( ( ) )
492511 }
@@ -499,8 +518,8 @@ impl<P: Platform + 'static> Server<P> {
499518 . ok_or ( IoError :: new ( IoErrorKind :: InvalidInput , "Information about provided file descriptor was not found" ) ) ?;
500519
501520 // Accept connection for TCP Listener
502- let listener = listener. lock ( ) . unwrap ( ) ;
503- let ( conn, peer) = listener. listener . accept ( ) . map_err ( |e| VmeError :: Command ( e. kind ( ) . into ( ) ) ) ?;
521+ let mut listener = listener. lock ( ) . unwrap ( ) ;
522+ let ( conn, peer) = listener. accept ( ) . map_err ( |e| VmeError :: Command ( e. kind ( ) . into ( ) ) ) ?;
504523 drop ( listener) ;
505524
506525 // Send enclave info where it should accept new incoming connection
@@ -561,7 +580,7 @@ impl<P: Platform + 'static> Server<P> {
561580 if let Some ( listener) = self . listener ( & enclave_addr) {
562581 let listener = listener. lock ( ) . unwrap ( ) ;
563582 conn. send ( & Response :: Info {
564- local : listener. listener . local_addr ( ) ?. into ( ) ,
583+ local : listener. local_addr ( ) ?. into ( ) ,
565584 peer : None ,
566585 } ) ?;
567586 Ok ( ( ) )
@@ -605,9 +624,15 @@ impl<P: Platform + 'static> Server<P> {
605624 command_listener : Mutex :: new ( command_listener) ,
606625 listeners : RwLock :: new ( FnvHashMap :: default ( ) ) ,
607626 connections : RwLock :: new ( FnvHashMap :: default ( ) ) ,
627+ usercall_ext : Box :: new ( UsercallExtensionDefault ) ,
628+
608629 } )
609630 }
610631
632+ pub fn set_usercall_ext ( & mut self , usercall_ext : Box < dyn UsercallExtension > ) {
633+ self . usercall_ext = usercall_ext
634+ }
635+
611636 fn start_command_server ( self : Arc < Self > ) -> Result < JoinHandle < ( ) > , IoError > {
612637 thread:: Builder :: new ( ) . spawn ( move || {
613638 let command_listener = self . command_listener . lock ( ) . unwrap ( ) ;
0 commit comments