Skip to content

Commit 2247591

Browse files
committed
Address reviewer comments
1 parent ae076b7 commit 2247591

File tree

1 file changed

+67
-97
lines changed
  • fortanix-vme/fortanix-vme-runner/src

1 file changed

+67
-97
lines changed

fortanix-vme/fortanix-vme-runner/src/lib.rs

Lines changed: 67 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -96,50 +96,57 @@ struct Connection {
9696
remote_name: String,
9797
}
9898

99+
#[derive(Clone, Debug)]
100+
struct ConnectionInfo {
101+
}
102+
99103
impl Connection {
100-
fn new(vsock_stream: VsockStream<Std>, tcp_stream: TcpStream, remote_name: String) -> Self {
104+
pub fn new(vsock_stream: VsockStream<Std>, tcp_stream: TcpStream, remote_name: String) -> Self {
101105
Connection {
102106
tcp_stream,
103107
vsock_stream,
104108
remote_name,
105109
}
106110
}
107111

108-
fn close(&self) {
109-
let _ = self.tcp_stream.shutdown(Shutdown::Both);
110-
let _ = self.vsock_stream.shutdown(Shutdown::Both);
112+
pub fn info(&self) -> ConnectionInfo {
113+
ConnectionInfo{}
111114
}
112115

113-
/// Exchanges messages between the remote server and enclave. Returns `true` when the
114-
/// connection should remain active, false otherwise
115-
fn proxy(&mut self) -> bool {
116-
fn exchange<S: StreamConnection, D: StreamConnection>(src: &mut S, src_name: &str, dst: &mut D, dst_name: &str) -> bool {
117-
// According to the `Read` threat documentation, reading 0 bytes
118-
// indicates that the connection has been shutdown correctly. So we
119-
// close the proxy service
120-
// https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read
121-
match Server::transfer_data(src, src_name, dst, dst_name) {
122-
Ok(n) if n == 0 => false,
123-
Ok(_) => true,
124-
Err(_) => false,
125-
}
126-
}
127-
let remote = self.tcp_stream.as_raw_fd();
128-
let proxy = self.vsock_stream.as_raw_fd();
116+
/// Exchanges messages between the remote server and enclave. Returns on error, or when one of
117+
/// the connections terminated
118+
pub fn proxy(&mut self) -> Result<(), IoError> {
119+
let remote = &mut self.tcp_stream;
120+
let enclave = &mut self.vsock_stream;
129121

130-
let mut read_set = FdSet::new();
131-
read_set.insert(remote);
132-
read_set.insert(proxy);
122+
let mut golden_set = FdSet::new();
123+
golden_set.insert(remote.as_raw_fd());
124+
golden_set.insert(enclave.as_raw_fd());
133125

134-
if let Ok(_num) = select(None, Some(&mut read_set), None, None, None) {
135-
if read_set.contains(remote) {
136-
return exchange(&mut self.tcp_stream, &self.remote_name, &mut self.vsock_stream, "proxy");
137-
}
138-
if read_set.contains(proxy) {
139-
return exchange(&mut self.vsock_stream, "proxy", &mut self.tcp_stream, &self.remote_name);
126+
while golden_set != FdSet::new() {
127+
let mut read_set = golden_set.clone();
128+
129+
if let Ok(_num) = select(None, Some(&mut read_set), None, None, None) {
130+
if read_set.contains(remote.as_raw_fd()) {
131+
// According to the `Read` trait documentation, reading 0 bytes
132+
// indicates that the connection has been shutdown (for writes) correctly. We
133+
// - reflect this change on the other connection
134+
// - avoid reading from the socket again
135+
// https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read
136+
if Server::transfer_data(remote, &self.remote_name, enclave, "enclave")? == 0 {
137+
enclave.shutdown(Shutdown::Write)?;
138+
golden_set.remove(remote.as_raw_fd());
139+
}
140+
}
141+
if read_set.contains(enclave.as_raw_fd()) {
142+
if Server::transfer_data(enclave, "enclave", remote, &self.remote_name)? == 0 {
143+
remote.shutdown(Shutdown::Write)?;
144+
golden_set.remove(enclave.as_raw_fd());
145+
}
146+
}
140147
}
141148
}
142-
true
149+
Ok(())
143150
}
144151
}
145152

@@ -150,15 +157,15 @@ struct ConnectionKey {
150157
}
151158

152159
impl ConnectionKey {
153-
fn from_vsock_stream(runner_enclave: &VsockStream<Std>) -> Self {
160+
pub fn from_vsock_stream(runner_enclave: &VsockStream<Std>) -> Self {
154161
let runner_cid = runner_enclave.local_addr().unwrap().cid();
155162
let runner_port = runner_enclave.local_addr().unwrap().port();
156163
let enclave_cid = runner_enclave.peer_addr().unwrap().cid();
157164
let enclave_port = runner_enclave.peer_addr().unwrap().port();
158165
Self::connection_key(enclave_cid, enclave_port, runner_cid, runner_port)
159166
}
160167

161-
fn from_addresses(enclave: VsockAddr, runner: VsockAddr) -> Self {
168+
pub fn from_addresses(enclave: VsockAddr, runner: VsockAddr) -> Self {
162169
ConnectionKey {
163170
enclave,
164171
runner,
@@ -180,7 +187,7 @@ pub struct Server {
180187
/// connection. It then locates the ListenerInfo and finds the information it needs to set up a
181188
/// new vsock connection to the enclave
182189
listeners: RwLock<FnvHashMap<VsockAddr, Arc<Mutex<Listener>>>>,
183-
connections: RwLock<FnvHashMap<ConnectionKey, Arc<Mutex<Connection>>>>,
190+
connections: RwLock<FnvHashMap<ConnectionKey, ConnectionInfo>>,
184191
}
185192

186193
impl Server {
@@ -263,7 +270,7 @@ impl Server {
263270
* [2] remote
264271
* [3] proxy
265272
*/
266-
fn handle_request_connect(server: Arc<Self>, remote_addr: &String, enclave: &mut VsockStream) -> Result<(), IoError> {
273+
fn handle_request_connect(self: Arc<Self>, remote_addr: &String, enclave: &mut VsockStream) -> Result<(), IoError> {
267274
// Connect to remote server
268275
let remote_socket = TcpStream::connect(remote_addr)?;
269276
let remote_name = remote_addr.split_terminator(":").next().unwrap_or(remote_addr);
@@ -293,7 +300,7 @@ impl Server {
293300
let (proxy, _proxy_addr) = proxy_server.accept()?;
294301

295302
// Store connection info
296-
server.add_connection(proxy, remote_socket, remote_name.to_string());
303+
self.add_connection(proxy, remote_socket, remote_name.to_string())?;
297304

298305
Ok(())
299306
}
@@ -308,7 +315,7 @@ impl Server {
308315

309316
// Preliminary work for PLAT-367
310317
#[allow(dead_code)]
311-
fn connection(&self, enclave: VsockAddr, runner: VsockAddr) -> Option<Arc<Mutex<Connection>>> {
318+
fn connection(&self, enclave: VsockAddr, runner: VsockAddr) -> Option<ConnectionInfo> {
312319
let k = ConnectionKey::from_addresses(enclave, runner);
313320
self.connections
314321
.read()
@@ -317,10 +324,17 @@ impl Server {
317324
.cloned()
318325
}
319326

320-
fn add_connection(&self, runner_enclave: VsockStream<Std>, runner_remote: TcpStream, remote_name: String) {
327+
fn add_connection(self: Arc<Self>, runner_enclave: VsockStream<Std>, runner_remote: TcpStream, remote_name: String) -> Result<JoinHandle<()>, IoError> {
321328
let k = ConnectionKey::from_vsock_stream(&runner_enclave);
322-
let info = Connection::new(runner_enclave, runner_remote, remote_name);
323-
self.connections.write().unwrap().insert(k.clone(), Arc::new(Mutex::new(info)));
329+
let mut connection = Connection::new(runner_enclave, runner_remote, remote_name);
330+
self.connections.write().unwrap().insert(k.clone(), connection.info());
331+
332+
thread::Builder::new().spawn(move || {
333+
if let Err(e) = connection.proxy() {
334+
eprintln!("Connection failed: {}", e);
335+
}
336+
self.connections.write().unwrap().remove(&k);
337+
})
324338
}
325339

326340
/*
@@ -346,11 +360,11 @@ impl Server {
346360
* runner
347361
* `enclave`: The runner-enclave vsock connection
348362
*/
349-
fn handle_request_bind(server: Arc<Self>, addr: &String, enclave_port: u32, enclave: &mut VsockStream) -> Result<(), IoError> {
363+
fn handle_request_bind(self: Arc<Self>, addr: &String, enclave_port: u32, enclave: &mut VsockStream) -> Result<(), IoError> {
350364
let cid: u32 = enclave.peer().unwrap().parse().unwrap_or(vsock::VMADDR_CID_HYPERVISOR);
351365
let listener = TcpListener::bind(addr)?;
352366
let local: Addr = listener.local_addr()?.into();
353-
server.add_listener(VsockAddr::new(cid, enclave_port), Listener::new(listener));
367+
self.add_listener(VsockAddr::new(cid, enclave_port), Listener::new(listener));
354368
let response = Response::Bound{ local };
355369
Self::log_communication(
356370
"runner",
@@ -364,10 +378,10 @@ impl Server {
364378
Ok(())
365379
}
366380

367-
fn handle_request_accept(server: Arc<Self>, vsock_listener_port: u32, enclave: &mut VsockStream) -> Result<(), IoError> {
381+
fn handle_request_accept(self: Arc<Self>, vsock_listener_port: u32, enclave: &mut VsockStream) -> Result<(), IoError> {
368382
let enclave_cid: u32 = enclave.peer().unwrap().parse().unwrap_or(vsock::VMADDR_CID_HYPERVISOR);
369383
let enclave_addr = VsockAddr::new(enclave_cid, vsock_listener_port);
370-
let listener = server.listener(&enclave_addr)
384+
let listener = self.listener(&enclave_addr)
371385
.ok_or(IoError::new(IoErrorKind::InvalidInput, "Information about provided file descriptor was not found"))?;
372386
let listener = listener.lock().unwrap();
373387

@@ -391,53 +405,19 @@ impl Server {
391405
enclave.write(&serde_cbor::ser::to_vec(&response).unwrap())?;
392406

393407
let proxy = vsock.connect_with_cid_port(enclave_addr.cid(), enclave_addr.port()).unwrap();
394-
server.add_connection(proxy, conn, "remote".to_string());
408+
self.add_connection(proxy, conn, "remote".to_string())?;
395409

396410
Ok(())
397411
},
398412
Err(e) => Err(e),
399413
}
400414
}
401415

402-
fn proxy_connections(server: Arc<Server>) {
403-
let mut closed_connections = Vec::new();
404-
405-
loop {
406-
// Exchange messages on every proxy connection
407-
// TODO: Store connections as a linked hash map so we don't need to keep a read lock
408-
// over the HashMap while every connection is serviced
409-
if let Ok(connections) = server.connections.read() {
410-
for (key, connection) in connections.iter() {
411-
match connection.try_lock() {
412-
Ok(mut connection) => if !connection.proxy() {
413-
connection.close();
414-
closed_connections.push(key.clone());
415-
}
416-
Err(_) => (),
417-
}
418-
}
419-
}
420-
421-
// Remove closed connections
422-
let mut num_connections = None;
423-
if let Ok(mut connections) = server.connections.try_write() {
424-
while let Some(k) = closed_connections.pop() {
425-
connections.remove(&k);
426-
}
427-
num_connections = Some(connections.len());
428-
}
429-
430-
if num_connections == Some(0) {
431-
thread::yield_now();
432-
}
433-
}
434-
}
435-
436-
fn handle_client(server: Arc<Self>, stream: &mut VsockStream) -> Result<(), IoError> {
416+
fn handle_client(self: Arc<Self>, stream: &mut VsockStream) -> Result<(), IoError> {
437417
match Self::read_request(stream) {
438-
Ok(Request::Connect{ addr }) => Self::handle_request_connect(server, &addr, stream)?,
439-
Ok(Request::Bind{ addr, enclave_port }) => Self::handle_request_bind(server, &addr, enclave_port, stream)?,
440-
Ok(Request::Accept{ enclave_port }) => Self::handle_request_accept(server, enclave_port, stream)?,
418+
Ok(Request::Connect{ addr }) => self.handle_request_connect(&addr, stream)?,
419+
Ok(Request::Bind{ addr, enclave_port }) => self.handle_request_bind(&addr, enclave_port, stream)?,
420+
Ok(Request::Accept{ enclave_port }) => self.handle_request_accept(enclave_port, stream)?,
441421
Err(_e) => return Err(IoError::new(IoErrorKind::InvalidData, "Failed to read request")),
442422
};
443423
Ok(())
@@ -452,22 +432,15 @@ impl Server {
452432
})
453433
}
454434

455-
fn start_proxy_server(server: Arc<Server>) -> Result<JoinHandle<()>, IoError> {
456-
thread::Builder::new()
457-
.spawn(move || {
458-
Self::proxy_connections(server);
459-
})
460-
}
461-
462-
fn start_command_server(server: Arc<Server>) -> Result<JoinHandle<()>, IoError> {
435+
fn start_command_server(self: Arc<Self>) -> Result<JoinHandle<()>, IoError> {
463436
thread::Builder::new().spawn(move || {
464-
let command_listener = server.command_listener.lock().unwrap();
437+
let command_listener = self.command_listener.lock().unwrap();
465438
for stream in command_listener.incoming() {
466-
let server = server.clone();
439+
let server = self.clone();
467440
let _ = thread::Builder::new()
468441
.spawn(move || {
469442
let mut stream = stream.unwrap();
470-
if let Err(e) = Self::handle_client(server, &mut stream) {
443+
if let Err(e) = server.handle_client(&mut stream) {
471444
eprintln!("Error handling connection: {}, shutting connection down", e);
472445
let _ = stream.shutdown(Shutdown::Both);
473446
}
@@ -482,10 +455,7 @@ impl Server {
482455
let port = server.command_listener.lock().unwrap().local_addr()?.port();
483456
println!("Listening on vsock port {}...", port);
484457

485-
Server::start_proxy_server(server.clone())?;
486-
let handle = Server::start_command_server(server.clone())?;
487-
488-
Ok(handle)
458+
server.start_command_server()
489459
}
490460
}
491461

0 commit comments

Comments
 (0)