Skip to content

Commit ae076b7

Browse files
committed
Refactoring proxy threads
1 parent 61a6ab1 commit ae076b7

File tree

2 files changed

+129
-105
lines changed

2 files changed

+129
-105
lines changed

fortanix-vme/fortanix-vme-runner/src/lib.rs

Lines changed: 128 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use std::thread::{self, JoinHandle};
88
use std::io::{self, Error as IoError, ErrorKind as IoErrorKind, Read, Write};
99
use std::net::{Shutdown, TcpListener, TcpStream};
1010
use std::os::unix::io::AsRawFd;
11-
use std::sync::{Arc, Mutex};
11+
use std::sync::{Arc, Mutex, RwLock};
1212
use fortanix_vme_abi::{self, Addr, Response, Request};
1313
use vsock::{self, SockAddr as VsockAddr, Std, Vsock, VsockListener, VsockStream};
1414

@@ -91,23 +91,56 @@ impl Listener {
9191

9292
#[derive(Debug)]
9393
struct Connection {
94-
// Preliminary work for PLAT-367
95-
#[allow(dead_code)]
96-
remote: Addr,
97-
// Preliminary work for PLAT-367
98-
#[allow(dead_code)]
99-
runner: Addr,
94+
tcp_stream: TcpStream,
95+
vsock_stream: VsockStream<Std>,
96+
remote_name: String,
10097
}
10198

10299
impl Connection {
103-
fn from_tcp_stream(stream: &TcpStream) -> Self {
104-
let tcp_remote = stream.peer_addr().unwrap().into();
105-
let tcp_runner = stream.local_addr().unwrap().into();
100+
fn new(vsock_stream: VsockStream<Std>, tcp_stream: TcpStream, remote_name: String) -> Self {
106101
Connection {
107-
remote: tcp_remote,
108-
runner: tcp_runner,
102+
tcp_stream,
103+
vsock_stream,
104+
remote_name,
109105
}
110106
}
107+
108+
fn close(&self) {
109+
let _ = self.tcp_stream.shutdown(Shutdown::Both);
110+
let _ = self.vsock_stream.shutdown(Shutdown::Both);
111+
}
112+
113+
/// Exchanges messages between the remote server and enclave. Returns `true` when the
114+
/// connection should remain active, false otherwise
115+
fn proxy(&mut self) -> bool {
116+
fn exchange<S: StreamConnection, D: StreamConnection>(src: &mut S, src_name: &str, dst: &mut D, dst_name: &str) -> bool {
117+
// According to the `Read` threat documentation, reading 0 bytes
118+
// indicates that the connection has been shutdown correctly. So we
119+
// close the proxy service
120+
// https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read
121+
match Server::transfer_data(src, src_name, dst, dst_name) {
122+
Ok(n) if n == 0 => false,
123+
Ok(_) => true,
124+
Err(_) => false,
125+
}
126+
}
127+
let remote = self.tcp_stream.as_raw_fd();
128+
let proxy = self.vsock_stream.as_raw_fd();
129+
130+
let mut read_set = FdSet::new();
131+
read_set.insert(remote);
132+
read_set.insert(proxy);
133+
134+
if let Ok(_num) = select(None, Some(&mut read_set), None, None, None) {
135+
if read_set.contains(remote) {
136+
return exchange(&mut self.tcp_stream, &self.remote_name, &mut self.vsock_stream, "proxy");
137+
}
138+
if read_set.contains(proxy) {
139+
return exchange(&mut self.vsock_stream, "proxy", &mut self.tcp_stream, &self.remote_name);
140+
}
141+
}
142+
true
143+
}
111144
}
112145

113146
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
@@ -146,8 +179,8 @@ pub struct Server {
146179
/// When the enclave instructs to accept a new connection, the runner accepts a new TCP
147180
/// connection. It then locates the ListenerInfo and finds the information it needs to set up a
148181
/// new vsock connection to the enclave
149-
listeners: Mutex<FnvHashMap<VsockAddr, Arc<Mutex<Listener>>>>,
150-
connections: Mutex<FnvHashMap<ConnectionKey, Arc<Mutex<Connection>>>>,
182+
listeners: RwLock<FnvHashMap<VsockAddr, Arc<Mutex<Listener>>>>,
183+
connections: RwLock<FnvHashMap<ConnectionKey, Arc<Mutex<Connection>>>>,
151184
}
152185

153186
impl Server {
@@ -230,9 +263,9 @@ impl Server {
230263
* [2] remote
231264
* [3] proxy
232265
*/
233-
fn handle_request_connect(&self, remote_addr: &String, enclave: &mut VsockStream) -> Result<(), IoError> {
266+
fn handle_request_connect(server: Arc<Self>, remote_addr: &String, enclave: &mut VsockStream) -> Result<(), IoError> {
234267
// Connect to remote server
235-
let mut remote_socket = TcpStream::connect(remote_addr)?;
268+
let remote_socket = TcpStream::connect(remote_addr)?;
236269
let remote_name = remote_addr.split_terminator(":").next().unwrap_or(remote_addr);
237270

238271
// Create listening socket that the enclave can connect to
@@ -257,50 +290,37 @@ impl Server {
257290
Self::send(enclave, &response)?;
258291

259292
// Wait for incoming connection from enclave
260-
let (mut proxy, _proxy_addr) = proxy_server.accept()?;
293+
let (proxy, _proxy_addr) = proxy_server.accept()?;
261294

262295
// Store connection info
263-
let k = self.add_connection(&proxy, &remote_socket);
296+
server.add_connection(proxy, remote_socket, remote_name.to_string());
264297

265-
// Pass messages between remote server <-> enclave
266-
Self::proxy_connection((&mut remote_socket, remote_name), (&mut proxy, "proxy"));
267-
268-
// Remove connection info
269-
self.remove_connection(&k);
270298
Ok(())
271299
}
272300

273301
fn add_listener(&self, addr: VsockAddr, info: Listener) {
274-
self.listeners.lock().unwrap().insert(addr, Arc::new(Mutex::new(info)));
302+
self.listeners.write().unwrap().insert(addr, Arc::new(Mutex::new(info)));
275303
}
276304

277305
fn listener(&self, addr: &VsockAddr) -> Option<Arc<Mutex<Listener>>> {
278-
self.listeners.lock().unwrap().get(&addr).cloned()
306+
self.listeners.read().unwrap().get(&addr).cloned()
279307
}
280308

281309
// Preliminary work for PLAT-367
282310
#[allow(dead_code)]
283311
fn connection(&self, enclave: VsockAddr, runner: VsockAddr) -> Option<Arc<Mutex<Connection>>> {
284312
let k = ConnectionKey::from_addresses(enclave, runner);
285313
self.connections
286-
.lock()
314+
.read()
287315
.unwrap()
288316
.get(&k)
289317
.cloned()
290318
}
291319

292-
fn add_connection(&self, runner_enclave: &VsockStream<Std>, runner_remote: &TcpStream) -> ConnectionKey {
293-
let k = ConnectionKey::from_vsock_stream(runner_enclave);
294-
let info = Connection::from_tcp_stream(runner_remote);
295-
self.connections.lock().unwrap().insert(k.clone(), Arc::new(Mutex::new(info)));
296-
k
297-
}
298-
299-
fn remove_connection(&self, k: &ConnectionKey) {
300-
self.connections
301-
.lock()
302-
.unwrap()
303-
.remove(&k);
320+
fn add_connection(&self, runner_enclave: VsockStream<Std>, runner_remote: TcpStream, remote_name: String) {
321+
let k = ConnectionKey::from_vsock_stream(&runner_enclave);
322+
let info = Connection::new(runner_enclave, runner_remote, remote_name);
323+
self.connections.write().unwrap().insert(k.clone(), Arc::new(Mutex::new(info)));
304324
}
305325

306326
/*
@@ -326,11 +346,11 @@ impl Server {
326346
* runner
327347
* `enclave`: The runner-enclave vsock connection
328348
*/
329-
fn handle_request_bind(&self, addr: &String, enclave_port: u32, enclave: &mut VsockStream) -> Result<(), IoError> {
349+
fn handle_request_bind(server: Arc<Self>, addr: &String, enclave_port: u32, enclave: &mut VsockStream) -> Result<(), IoError> {
330350
let cid: u32 = enclave.peer().unwrap().parse().unwrap_or(vsock::VMADDR_CID_HYPERVISOR);
331351
let listener = TcpListener::bind(addr)?;
332352
let local: Addr = listener.local_addr()?.into();
333-
self.add_listener(VsockAddr::new(cid, enclave_port), Listener::new(listener));
353+
server.add_listener(VsockAddr::new(cid, enclave_port), Listener::new(listener));
334354
let response = Response::Bound{ local };
335355
Self::log_communication(
336356
"runner",
@@ -344,15 +364,15 @@ impl Server {
344364
Ok(())
345365
}
346366

347-
fn handle_request_accept(&self, vsock_listener_port: u32, enclave: &mut VsockStream) -> Result<(), IoError> {
367+
fn handle_request_accept(server: Arc<Self>, vsock_listener_port: u32, enclave: &mut VsockStream) -> Result<(), IoError> {
348368
let enclave_cid: u32 = enclave.peer().unwrap().parse().unwrap_or(vsock::VMADDR_CID_HYPERVISOR);
349369
let enclave_addr = VsockAddr::new(enclave_cid, vsock_listener_port);
350-
let listener = self.listener(&enclave_addr)
370+
let listener = server.listener(&enclave_addr)
351371
.ok_or(IoError::new(IoErrorKind::InvalidInput, "Information about provided file descriptor was not found"))?;
352372
let listener = listener.lock().unwrap();
353373

354374
match listener.listener.accept() {
355-
Ok((mut conn, peer)) => {
375+
Ok((conn, peer)) => {
356376
let vsock = Vsock::new::<Std>()?;
357377
let runner_addr = vsock.addr::<Std>()?;
358378
let response = Response::IncomingConnection{
@@ -369,62 +389,55 @@ impl Server {
369389
Direction::Right,
370390
"vsock");
371391
enclave.write(&serde_cbor::ser::to_vec(&response).unwrap())?;
372-
let _ = thread::Builder::new().spawn(move || {
373-
let mut proxy = vsock.connect_with_cid_port(enclave_addr.cid(), enclave_addr.port()).unwrap();
374-
//let k = self.add_connection(&proxy, &conn);
375-
Self::proxy_connection((&mut conn, "remote"), (&mut proxy, "proxy"));
376-
//self.remove_connection(&k);
377-
});
392+
393+
let proxy = vsock.connect_with_cid_port(enclave_addr.cid(), enclave_addr.port()).unwrap();
394+
server.add_connection(proxy, conn, "remote".to_string());
395+
378396
Ok(())
379397
},
380398
Err(e) => Err(e),
381399
}
382400
}
383401

384-
fn proxy_connection(remote: (&mut TcpStream, &str), proxy: (&mut VsockStream, &str)) {
402+
fn proxy_connections(server: Arc<Server>) {
403+
let mut closed_connections = Vec::new();
404+
385405
loop {
386-
let mut read_set = FdSet::new();
387-
read_set.insert(remote.0.as_raw_fd());
388-
read_set.insert(proxy.0.as_raw_fd());
389-
390-
if let Ok(_num) = select(None, Some(&mut read_set), None, None, None) {
391-
if read_set.contains(remote.0.as_raw_fd()) {
392-
match Self::transfer_data(remote.0, remote.1, proxy.0, proxy.1) {
393-
Ok(0) => {
394-
// According to the `Read` threat documentation, reading 0 bytes
395-
// indicates that the connection has been shutdown correctly. So we
396-
// close the proxy service
397-
// https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read
398-
break
399-
},
400-
Ok(_) => (),
401-
Err(e) => {
402-
eprintln!("transfer from remote failed: {:?}", e);
403-
break;
404-
}
406+
// Exchange messages on every proxy connection
407+
// TODO: Store connections as a linked hash map so we don't need to keep a read lock
408+
// over the HashMap while every connection is serviced
409+
if let Ok(connections) = server.connections.read() {
410+
for (key, connection) in connections.iter() {
411+
match connection.try_lock() {
412+
Ok(mut connection) => if !connection.proxy() {
413+
connection.close();
414+
closed_connections.push(key.clone());
415+
}
416+
Err(_) => (),
405417
}
406418
}
407-
if read_set.contains(proxy.0.as_raw_fd()) {
408-
match Self::transfer_data(proxy.0, proxy.1, remote.0, remote.1) {
409-
Ok(0) => break,
410-
Ok(_) => (),
411-
Err(e) => {
412-
eprintln!("transfer from proxy failed: {:?}", e);
413-
break;
414-
}
415-
}
419+
}
420+
421+
// Remove closed connections
422+
let mut num_connections = None;
423+
if let Ok(mut connections) = server.connections.try_write() {
424+
while let Some(k) = closed_connections.pop() {
425+
connections.remove(&k);
416426
}
427+
num_connections = Some(connections.len());
428+
}
429+
430+
if num_connections == Some(0) {
431+
thread::yield_now();
417432
}
418433
}
419-
let _ = proxy.0.shutdown(Shutdown::Both);
420-
let _ = remote.0.shutdown(Shutdown::Both);
421434
}
422435

423-
fn handle_client(&self, stream: &mut VsockStream) -> Result<(), IoError> {
436+
fn handle_client(server: Arc<Self>, stream: &mut VsockStream) -> Result<(), IoError> {
424437
match Self::read_request(stream) {
425-
Ok(Request::Connect{ addr }) => self.handle_request_connect(&addr, stream)?,
426-
Ok(Request::Bind{ addr, enclave_port }) => self.handle_request_bind(&addr, enclave_port, stream)?,
427-
Ok(Request::Accept{ enclave_port }) => self.handle_request_accept(enclave_port, stream)?,
438+
Ok(Request::Connect{ addr }) => Self::handle_request_connect(server, &addr, stream)?,
439+
Ok(Request::Bind{ addr, enclave_port }) => Self::handle_request_bind(server, &addr, enclave_port, stream)?,
440+
Ok(Request::Accept{ enclave_port }) => Self::handle_request_accept(server, enclave_port, stream)?,
428441
Err(_e) => return Err(IoError::new(IoErrorKind::InvalidData, "Failed to read request")),
429442
};
430443
Ok(())
@@ -434,34 +447,45 @@ impl Server {
434447
let command_listener = VsockListener::<Std>::bind_with_cid_port(vsock::VMADDR_CID_ANY, port)?;
435448
Ok(Server {
436449
command_listener: Mutex::new(command_listener),
437-
listeners: Mutex::new(FnvHashMap::default()),
438-
connections: Mutex::new(FnvHashMap::default()),
450+
listeners: RwLock::new(FnvHashMap::default()),
451+
connections: RwLock::new(FnvHashMap::default()),
452+
})
453+
}
454+
455+
fn start_proxy_server(server: Arc<Server>) -> Result<JoinHandle<()>, IoError> {
456+
thread::Builder::new()
457+
.spawn(move || {
458+
Self::proxy_connections(server);
459+
})
460+
}
461+
462+
fn start_command_server(server: Arc<Server>) -> Result<JoinHandle<()>, IoError> {
463+
thread::Builder::new().spawn(move || {
464+
let command_listener = server.command_listener.lock().unwrap();
465+
for stream in command_listener.incoming() {
466+
let server = server.clone();
467+
let _ = thread::Builder::new()
468+
.spawn(move || {
469+
let mut stream = stream.unwrap();
470+
if let Err(e) = Self::handle_client(server, &mut stream) {
471+
eprintln!("Error handling connection: {}, shutting connection down", e);
472+
let _ = stream.shutdown(Shutdown::Both);
473+
}
474+
});
475+
}
439476
})
440477
}
441478

442-
pub fn run(port: u32) -> std::io::Result<(JoinHandle<()>, u32)> {
479+
pub fn run(port: u32) -> std::io::Result<JoinHandle<()>> {
443480
println!("Starting enclave runner.");
444481
let server = Arc::new(Self::bind(port)?);
445482
let port = server.command_listener.lock().unwrap().local_addr()?.port();
446483
println!("Listening on vsock port {}...", port);
447484

448-
let handle = thread::Builder::new().spawn(move || {
449-
let server = server;
450-
let server = server.clone();
451-
let command_listener = server.command_listener.lock().unwrap();
452-
for stream in command_listener.incoming() {
453-
let server = server.clone();
454-
let _ = thread::Builder::new()
455-
.spawn(move || {
456-
let mut stream = stream.unwrap();
457-
if let Err(e) = server.handle_client(&mut stream) {
458-
eprintln!("Error handling connection: {}, shutting connection down", e);
459-
let _ = stream.shutdown(Shutdown::Both);
460-
}
461-
});
462-
}
463-
})?;
464-
Ok((handle, port))
485+
Server::start_proxy_server(server.clone())?;
486+
let handle = Server::start_command_server(server.clone())?;
487+
488+
Ok(handle)
465489
}
466490
}
467491

fortanix-vme/fortanix-vme-runner/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::io::ErrorKind;
44

55
fn main() {
66
match Server::run(SERVER_PORT) {
7-
Ok((server_thread, _port)) => server_thread.join().expect("Server panicked"),
7+
Ok(handle) => { handle.join().unwrap(); },
88
Err(e) if e.kind() == ErrorKind::AddrInUse => println!("Server failed. Do you already have a runner running on vsock port {}? (Error: {:?})", SERVER_PORT, e),
99
Err(e) => println!("Server failed. Error: {:?}", e),
1010
}

0 commit comments

Comments
 (0)