@@ -8,7 +8,7 @@ use std::thread::{self, JoinHandle};
88use std:: io:: { self , Error as IoError , ErrorKind as IoErrorKind , Read , Write } ;
99use std:: net:: { Shutdown , TcpListener , TcpStream } ;
1010use std:: os:: unix:: io:: AsRawFd ;
11- use std:: sync:: { Arc , Mutex } ;
11+ use std:: sync:: { Arc , Mutex , RwLock } ;
1212use fortanix_vme_abi:: { self , Addr , Response , Request } ;
1313use vsock:: { self , SockAddr as VsockAddr , Std , Vsock , VsockListener , VsockStream } ;
1414
@@ -91,23 +91,56 @@ impl Listener {
9191
9292#[ derive( Debug ) ]
9393struct Connection {
94- // Preliminary work for PLAT-367
95- #[ allow( dead_code) ]
96- remote : Addr ,
97- // Preliminary work for PLAT-367
98- #[ allow( dead_code) ]
99- runner : Addr ,
94+ tcp_stream : TcpStream ,
95+ vsock_stream : VsockStream < Std > ,
96+ remote_name : String ,
10097}
10198
10299impl Connection {
103- fn from_tcp_stream ( stream : & TcpStream ) -> Self {
104- let tcp_remote = stream. peer_addr ( ) . unwrap ( ) . into ( ) ;
105- let tcp_runner = stream. local_addr ( ) . unwrap ( ) . into ( ) ;
100+ fn new ( vsock_stream : VsockStream < Std > , tcp_stream : TcpStream , remote_name : String ) -> Self {
106101 Connection {
107- remote : tcp_remote,
108- runner : tcp_runner,
102+ tcp_stream,
103+ vsock_stream,
104+ remote_name,
109105 }
110106 }
107+
108+ fn close ( & self ) {
109+ let _ = self . tcp_stream . shutdown ( Shutdown :: Both ) ;
110+ let _ = self . vsock_stream . shutdown ( Shutdown :: Both ) ;
111+ }
112+
113+ /// Exchanges messages between the remote server and enclave. Returns `true` when the
114+ /// connection should remain active, false otherwise
115+ fn proxy ( & mut self ) -> bool {
116+ fn exchange < S : StreamConnection , D : StreamConnection > ( src : & mut S , src_name : & str , dst : & mut D , dst_name : & str ) -> bool {
117+ // According to the `Read` threat documentation, reading 0 bytes
118+ // indicates that the connection has been shutdown correctly. So we
119+ // close the proxy service
120+ // https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read
121+ match Server :: transfer_data ( src, src_name, dst, dst_name) {
122+ Ok ( n) if n == 0 => false ,
123+ Ok ( _) => true ,
124+ Err ( _) => false ,
125+ }
126+ }
127+ let remote = self . tcp_stream . as_raw_fd ( ) ;
128+ let proxy = self . vsock_stream . as_raw_fd ( ) ;
129+
130+ let mut read_set = FdSet :: new ( ) ;
131+ read_set. insert ( remote) ;
132+ read_set. insert ( proxy) ;
133+
134+ if let Ok ( _num) = select ( None , Some ( & mut read_set) , None , None , None ) {
135+ if read_set. contains ( remote) {
136+ return exchange ( & mut self . tcp_stream , & self . remote_name , & mut self . vsock_stream , "proxy" ) ;
137+ }
138+ if read_set. contains ( proxy) {
139+ return exchange ( & mut self . vsock_stream , "proxy" , & mut self . tcp_stream , & self . remote_name ) ;
140+ }
141+ }
142+ true
143+ }
111144}
112145
113146#[ derive( Clone , Debug , Eq , Hash , PartialEq ) ]
@@ -146,8 +179,8 @@ pub struct Server {
146179 /// When the enclave instructs to accept a new connection, the runner accepts a new TCP
147180 /// connection. It then locates the ListenerInfo and finds the information it needs to set up a
148181 /// new vsock connection to the enclave
149- listeners : Mutex < FnvHashMap < VsockAddr , Arc < Mutex < Listener > > > > ,
150- connections : Mutex < FnvHashMap < ConnectionKey , Arc < Mutex < Connection > > > > ,
182+ listeners : RwLock < FnvHashMap < VsockAddr , Arc < Mutex < Listener > > > > ,
183+ connections : RwLock < FnvHashMap < ConnectionKey , Arc < Mutex < Connection > > > > ,
151184}
152185
153186impl Server {
@@ -230,9 +263,9 @@ impl Server {
230263 * [2] remote
231264 * [3] proxy
232265 */
233- fn handle_request_connect ( & self , remote_addr : & String , enclave : & mut VsockStream ) -> Result < ( ) , IoError > {
266+ fn handle_request_connect ( server : Arc < Self > , remote_addr : & String , enclave : & mut VsockStream ) -> Result < ( ) , IoError > {
234267 // Connect to remote server
235- let mut remote_socket = TcpStream :: connect ( remote_addr) ?;
268+ let remote_socket = TcpStream :: connect ( remote_addr) ?;
236269 let remote_name = remote_addr. split_terminator ( ":" ) . next ( ) . unwrap_or ( remote_addr) ;
237270
238271 // Create listening socket that the enclave can connect to
@@ -257,50 +290,37 @@ impl Server {
257290 Self :: send ( enclave, & response) ?;
258291
259292 // Wait for incoming connection from enclave
260- let ( mut proxy, _proxy_addr) = proxy_server. accept ( ) ?;
293+ let ( proxy, _proxy_addr) = proxy_server. accept ( ) ?;
261294
262295 // Store connection info
263- let k = self . add_connection ( & proxy, & remote_socket) ;
296+ server . add_connection ( proxy, remote_socket, remote_name . to_string ( ) ) ;
264297
265- // Pass messages between remote server <-> enclave
266- Self :: proxy_connection ( ( & mut remote_socket, remote_name) , ( & mut proxy, "proxy" ) ) ;
267-
268- // Remove connection info
269- self . remove_connection ( & k) ;
270298 Ok ( ( ) )
271299 }
272300
273301 fn add_listener ( & self , addr : VsockAddr , info : Listener ) {
274- self . listeners . lock ( ) . unwrap ( ) . insert ( addr, Arc :: new ( Mutex :: new ( info) ) ) ;
302+ self . listeners . write ( ) . unwrap ( ) . insert ( addr, Arc :: new ( Mutex :: new ( info) ) ) ;
275303 }
276304
277305 fn listener ( & self , addr : & VsockAddr ) -> Option < Arc < Mutex < Listener > > > {
278- self . listeners . lock ( ) . unwrap ( ) . get ( & addr) . cloned ( )
306+ self . listeners . read ( ) . unwrap ( ) . get ( & addr) . cloned ( )
279307 }
280308
281309 // Preliminary work for PLAT-367
282310 #[ allow( dead_code) ]
283311 fn connection ( & self , enclave : VsockAddr , runner : VsockAddr ) -> Option < Arc < Mutex < Connection > > > {
284312 let k = ConnectionKey :: from_addresses ( enclave, runner) ;
285313 self . connections
286- . lock ( )
314+ . read ( )
287315 . unwrap ( )
288316 . get ( & k)
289317 . cloned ( )
290318 }
291319
292- fn add_connection ( & self , runner_enclave : & VsockStream < Std > , runner_remote : & TcpStream ) -> ConnectionKey {
293- let k = ConnectionKey :: from_vsock_stream ( runner_enclave) ;
294- let info = Connection :: from_tcp_stream ( runner_remote) ;
295- self . connections . lock ( ) . unwrap ( ) . insert ( k. clone ( ) , Arc :: new ( Mutex :: new ( info) ) ) ;
296- k
297- }
298-
299- fn remove_connection ( & self , k : & ConnectionKey ) {
300- self . connections
301- . lock ( )
302- . unwrap ( )
303- . remove ( & k) ;
320+ fn add_connection ( & self , runner_enclave : VsockStream < Std > , runner_remote : TcpStream , remote_name : String ) {
321+ let k = ConnectionKey :: from_vsock_stream ( & runner_enclave) ;
322+ let info = Connection :: new ( runner_enclave, runner_remote, remote_name) ;
323+ self . connections . write ( ) . unwrap ( ) . insert ( k. clone ( ) , Arc :: new ( Mutex :: new ( info) ) ) ;
304324 }
305325
306326 /*
@@ -326,11 +346,11 @@ impl Server {
326346 * runner
327347 * `enclave`: The runner-enclave vsock connection
328348 */
329- fn handle_request_bind ( & self , addr : & String , enclave_port : u32 , enclave : & mut VsockStream ) -> Result < ( ) , IoError > {
349+ fn handle_request_bind ( server : Arc < Self > , addr : & String , enclave_port : u32 , enclave : & mut VsockStream ) -> Result < ( ) , IoError > {
330350 let cid: u32 = enclave. peer ( ) . unwrap ( ) . parse ( ) . unwrap_or ( vsock:: VMADDR_CID_HYPERVISOR ) ;
331351 let listener = TcpListener :: bind ( addr) ?;
332352 let local: Addr = listener. local_addr ( ) ?. into ( ) ;
333- self . add_listener ( VsockAddr :: new ( cid, enclave_port) , Listener :: new ( listener) ) ;
353+ server . add_listener ( VsockAddr :: new ( cid, enclave_port) , Listener :: new ( listener) ) ;
334354 let response = Response :: Bound { local } ;
335355 Self :: log_communication (
336356 "runner" ,
@@ -344,15 +364,15 @@ impl Server {
344364 Ok ( ( ) )
345365 }
346366
347- fn handle_request_accept ( & self , vsock_listener_port : u32 , enclave : & mut VsockStream ) -> Result < ( ) , IoError > {
367+ fn handle_request_accept ( server : Arc < Self > , vsock_listener_port : u32 , enclave : & mut VsockStream ) -> Result < ( ) , IoError > {
348368 let enclave_cid: u32 = enclave. peer ( ) . unwrap ( ) . parse ( ) . unwrap_or ( vsock:: VMADDR_CID_HYPERVISOR ) ;
349369 let enclave_addr = VsockAddr :: new ( enclave_cid, vsock_listener_port) ;
350- let listener = self . listener ( & enclave_addr)
370+ let listener = server . listener ( & enclave_addr)
351371 . ok_or ( IoError :: new ( IoErrorKind :: InvalidInput , "Information about provided file descriptor was not found" ) ) ?;
352372 let listener = listener. lock ( ) . unwrap ( ) ;
353373
354374 match listener. listener . accept ( ) {
355- Ok ( ( mut conn, peer) ) => {
375+ Ok ( ( conn, peer) ) => {
356376 let vsock = Vsock :: new :: < Std > ( ) ?;
357377 let runner_addr = vsock. addr :: < Std > ( ) ?;
358378 let response = Response :: IncomingConnection {
@@ -369,62 +389,55 @@ impl Server {
369389 Direction :: Right ,
370390 "vsock" ) ;
371391 enclave. write ( & serde_cbor:: ser:: to_vec ( & response) . unwrap ( ) ) ?;
372- let _ = thread:: Builder :: new ( ) . spawn ( move || {
373- let mut proxy = vsock. connect_with_cid_port ( enclave_addr. cid ( ) , enclave_addr. port ( ) ) . unwrap ( ) ;
374- //let k = self.add_connection(&proxy, &conn);
375- Self :: proxy_connection ( ( & mut conn, "remote" ) , ( & mut proxy, "proxy" ) ) ;
376- //self.remove_connection(&k);
377- } ) ;
392+
393+ let proxy = vsock. connect_with_cid_port ( enclave_addr. cid ( ) , enclave_addr. port ( ) ) . unwrap ( ) ;
394+ server. add_connection ( proxy, conn, "remote" . to_string ( ) ) ;
395+
378396 Ok ( ( ) )
379397 } ,
380398 Err ( e) => Err ( e) ,
381399 }
382400 }
383401
384- fn proxy_connection ( remote : ( & mut TcpStream , & str ) , proxy : ( & mut VsockStream , & str ) ) {
402+ fn proxy_connections ( server : Arc < Server > ) {
403+ let mut closed_connections = Vec :: new ( ) ;
404+
385405 loop {
386- let mut read_set = FdSet :: new ( ) ;
387- read_set. insert ( remote. 0 . as_raw_fd ( ) ) ;
388- read_set. insert ( proxy. 0 . as_raw_fd ( ) ) ;
389-
390- if let Ok ( _num) = select ( None , Some ( & mut read_set) , None , None , None ) {
391- if read_set. contains ( remote. 0 . as_raw_fd ( ) ) {
392- match Self :: transfer_data ( remote. 0 , remote. 1 , proxy. 0 , proxy. 1 ) {
393- Ok ( 0 ) => {
394- // According to the `Read` threat documentation, reading 0 bytes
395- // indicates that the connection has been shutdown correctly. So we
396- // close the proxy service
397- // https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read
398- break
399- } ,
400- Ok ( _) => ( ) ,
401- Err ( e) => {
402- eprintln ! ( "transfer from remote failed: {:?}" , e) ;
403- break ;
404- }
406+ // Exchange messages on every proxy connection
407+ // TODO: Store connections as a linked hash map so we don't need to keep a read lock
408+ // over the HashMap while every connection is serviced
409+ if let Ok ( connections) = server. connections . read ( ) {
410+ for ( key, connection) in connections. iter ( ) {
411+ match connection. try_lock ( ) {
412+ Ok ( mut connection) => if !connection. proxy ( ) {
413+ connection. close ( ) ;
414+ closed_connections. push ( key. clone ( ) ) ;
415+ }
416+ Err ( _) => ( ) ,
405417 }
406418 }
407- if read_set. contains ( proxy. 0 . as_raw_fd ( ) ) {
408- match Self :: transfer_data ( proxy. 0 , proxy. 1 , remote. 0 , remote. 1 ) {
409- Ok ( 0 ) => break ,
410- Ok ( _) => ( ) ,
411- Err ( e) => {
412- eprintln ! ( "transfer from proxy failed: {:?}" , e) ;
413- break ;
414- }
415- }
419+ }
420+
421+ // Remove closed connections
422+ let mut num_connections = None ;
423+ if let Ok ( mut connections) = server. connections . try_write ( ) {
424+ while let Some ( k) = closed_connections. pop ( ) {
425+ connections. remove ( & k) ;
416426 }
427+ num_connections = Some ( connections. len ( ) ) ;
428+ }
429+
430+ if num_connections == Some ( 0 ) {
431+ thread:: yield_now ( ) ;
417432 }
418433 }
419- let _ = proxy. 0 . shutdown ( Shutdown :: Both ) ;
420- let _ = remote. 0 . shutdown ( Shutdown :: Both ) ;
421434 }
422435
423- fn handle_client ( & self , stream : & mut VsockStream ) -> Result < ( ) , IoError > {
436+ fn handle_client ( server : Arc < Self > , stream : & mut VsockStream ) -> Result < ( ) , IoError > {
424437 match Self :: read_request ( stream) {
425- Ok ( Request :: Connect { addr } ) => self . handle_request_connect ( & addr, stream) ?,
426- Ok ( Request :: Bind { addr, enclave_port } ) => self . handle_request_bind ( & addr, enclave_port, stream) ?,
427- Ok ( Request :: Accept { enclave_port } ) => self . handle_request_accept ( enclave_port, stream) ?,
438+ Ok ( Request :: Connect { addr } ) => Self :: handle_request_connect ( server , & addr, stream) ?,
439+ Ok ( Request :: Bind { addr, enclave_port } ) => Self :: handle_request_bind ( server , & addr, enclave_port, stream) ?,
440+ Ok ( Request :: Accept { enclave_port } ) => Self :: handle_request_accept ( server , enclave_port, stream) ?,
428441 Err ( _e) => return Err ( IoError :: new ( IoErrorKind :: InvalidData , "Failed to read request" ) ) ,
429442 } ;
430443 Ok ( ( ) )
@@ -434,34 +447,45 @@ impl Server {
434447 let command_listener = VsockListener :: < Std > :: bind_with_cid_port ( vsock:: VMADDR_CID_ANY , port) ?;
435448 Ok ( Server {
436449 command_listener : Mutex :: new ( command_listener) ,
437- listeners : Mutex :: new ( FnvHashMap :: default ( ) ) ,
438- connections : Mutex :: new ( FnvHashMap :: default ( ) ) ,
450+ listeners : RwLock :: new ( FnvHashMap :: default ( ) ) ,
451+ connections : RwLock :: new ( FnvHashMap :: default ( ) ) ,
452+ } )
453+ }
454+
455+ fn start_proxy_server ( server : Arc < Server > ) -> Result < JoinHandle < ( ) > , IoError > {
456+ thread:: Builder :: new ( )
457+ . spawn ( move || {
458+ Self :: proxy_connections ( server) ;
459+ } )
460+ }
461+
462+ fn start_command_server ( server : Arc < Server > ) -> Result < JoinHandle < ( ) > , IoError > {
463+ thread:: Builder :: new ( ) . spawn ( move || {
464+ let command_listener = server. command_listener . lock ( ) . unwrap ( ) ;
465+ for stream in command_listener. incoming ( ) {
466+ let server = server. clone ( ) ;
467+ let _ = thread:: Builder :: new ( )
468+ . spawn ( move || {
469+ let mut stream = stream. unwrap ( ) ;
470+ if let Err ( e) = Self :: handle_client ( server, & mut stream) {
471+ eprintln ! ( "Error handling connection: {}, shutting connection down" , e) ;
472+ let _ = stream. shutdown ( Shutdown :: Both ) ;
473+ }
474+ } ) ;
475+ }
439476 } )
440477 }
441478
442- pub fn run ( port : u32 ) -> std:: io:: Result < ( JoinHandle < ( ) > , u32 ) > {
479+ pub fn run ( port : u32 ) -> std:: io:: Result < JoinHandle < ( ) > > {
443480 println ! ( "Starting enclave runner." ) ;
444481 let server = Arc :: new ( Self :: bind ( port) ?) ;
445482 let port = server. command_listener . lock ( ) . unwrap ( ) . local_addr ( ) ?. port ( ) ;
446483 println ! ( "Listening on vsock port {}..." , port) ;
447484
448- let handle = thread:: Builder :: new ( ) . spawn ( move || {
449- let server = server;
450- let server = server. clone ( ) ;
451- let command_listener = server. command_listener . lock ( ) . unwrap ( ) ;
452- for stream in command_listener. incoming ( ) {
453- let server = server. clone ( ) ;
454- let _ = thread:: Builder :: new ( )
455- . spawn ( move || {
456- let mut stream = stream. unwrap ( ) ;
457- if let Err ( e) = server. handle_client ( & mut stream) {
458- eprintln ! ( "Error handling connection: {}, shutting connection down" , e) ;
459- let _ = stream. shutdown ( Shutdown :: Both ) ;
460- }
461- } ) ;
462- }
463- } ) ?;
464- Ok ( ( handle, port) )
485+ Server :: start_proxy_server ( server. clone ( ) ) ?;
486+ let handle = Server :: start_command_server ( server. clone ( ) ) ?;
487+
488+ Ok ( handle)
465489 }
466490}
467491
0 commit comments