Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 54 additions & 29 deletions fortanix-vme/fortanix-vme-runner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ use vsock::{self, SockAddr as VsockAddr, Std, Vsock, VsockListener, VsockStream}
mod platforms;
pub use platforms::{Platform, NitroEnclaves, Simulator, SimulatorArgs};

use crate::usercall_ext::Listener;
use crate::usercall_ext::SocketStream;
use crate::usercall_ext::UsercallExtension;
use crate::usercall_ext::UsercallExtensionDefault;

mod usercall_ext;

const MAX_LOG_MESSAGE_LEN: usize = 80;
const PROXY_BUFF_SIZE: usize = 4192;

Expand All @@ -36,9 +43,9 @@ pub trait StreamConnection: Read + Write {
fn peer_port(&self) -> io::Result<u32>;
}

impl StreamConnection for TcpStream {
impl<T: SocketStream> StreamConnection for T {
fn protocol() -> &'static str {
"tcp"
"socket stream"
}

fn local(&self) -> io::Result<String> {
Expand Down Expand Up @@ -84,20 +91,20 @@ impl StreamConnection for VsockStream {
}
}

#[derive(Debug)]
struct Listener {
listener: TcpListener,
}
// #[derive(Debug)]
// struct Listener {
// listener: TcpListener,
// }

impl Listener {
fn new(listener: TcpListener) -> Self {
Listener{ listener }
}
}
// impl Listener {
// fn new(listener: TcpListener) -> Self {
// Listener{ listener }
// }
// }

#[derive(Debug)]
// #[derive(Debug)]
struct Connection {
tcp_stream: TcpStream,
tcp_stream: Box<dyn SocketStream>,
vsock_stream: VsockStream<Std>,
remote_name: String,
}
Expand All @@ -111,7 +118,7 @@ struct ConnectionInfo {
}

impl Connection {
pub fn new(vsock_stream: VsockStream<Std>, tcp_stream: TcpStream, remote_name: String) -> Self {
pub fn new(vsock_stream: VsockStream<Std>, tcp_stream: Box<dyn SocketStream>, remote_name: String) -> Self {
Connection {
tcp_stream,
vsock_stream,
Expand Down Expand Up @@ -344,8 +351,9 @@ pub struct Server<P: Platform> {
/// When the enclave instructs to accept a new connection, the runner accepts a new TCP
/// connection. It then locates the ListenerInfo and finds the information it needs to set up a
/// new vsock connection to the enclave
listeners: RwLock<FnvHashMap<VsockAddr, Arc<Mutex<Listener>>>>,
listeners: RwLock<FnvHashMap<VsockAddr, Arc<Mutex<Box<dyn Listener>>>>>,
connections: RwLock<FnvHashMap<ConnectionKey, ConnectionInfo>>,
usercall_ext: Box<dyn UsercallExtension>,
}

impl<P: Platform + 'static> Server<P> {
Expand Down Expand Up @@ -381,8 +389,13 @@ impl<P: Platform + 'static> Server<P> {
* [3] proxy
*/
fn handle_request_connect(self: Arc<Self>, remote_addr: &String, conn: &mut ClientConnection) -> Result<(), VmeError> {
let remote_stream = if let Some(stream) = self.usercall_ext.connect_stream(remote_addr)? {
stream
} else {
let remote_socket = TcpStream::connect(remote_addr).map_err(|e| VmeError::Command(e.kind().into()))?;
Box::new(remote_socket)
};
// Connect to remote server
let remote_socket = TcpStream::connect(remote_addr).map_err(|e| VmeError::Command(e.kind().into()))?;
let remote_name = remote_addr.split_terminator(":").next().unwrap_or(remote_addr);

// Create listening socket that the enclave can connect to
Expand All @@ -392,8 +405,8 @@ impl<P: Platform + 'static> Server<P> {
// Notify the enclave on which port her proxy is listening on
let response = Response::Connected {
proxy_port: proxy_server_port,
local: remote_socket.local_addr()?.into(),
peer: remote_socket.peer_addr()?.into(),
local: remote_stream.local_addr()?.into(),
peer: remote_stream.peer_addr()?.into(),
};

conn.send(&response)?;
Expand All @@ -404,7 +417,7 @@ impl<P: Platform + 'static> Server<P> {
let accept_connection = move || -> Result<(), VmeError> {
let (proxy, _proxy_addr) = proxy_server.accept()?;
// Store connection info
self.add_connection(proxy, remote_socket, remote_name.to_string())?;
self.add_connection(proxy, remote_stream, remote_name.to_string())?;
Ok(())
};
if let Err(e) = accept_connection() {
Expand All @@ -413,15 +426,15 @@ impl<P: Platform + 'static> Server<P> {
Ok(())
}

fn add_listener(&self, addr: VsockAddr, info: Listener) {
fn add_listener(&self, addr: VsockAddr, info: Box<dyn Listener>) {
self.listeners.write().unwrap().insert(addr, Arc::new(Mutex::new(info)));
}

fn listener(&self, addr: &VsockAddr) -> Option<Arc<Mutex<Listener>>> {
fn listener(&self, addr: &VsockAddr) -> Option<Arc<Mutex<Box<dyn Listener>>>> {
self.listeners.read().unwrap().get(&addr).cloned()
}

fn remove_listener(&self, addr: &VsockAddr) -> Option<Arc<Mutex<Listener>>> {
fn remove_listener(&self, addr: &VsockAddr) -> Option<Arc<Mutex<Box<dyn Listener>>>> {
self.listeners.write().unwrap().remove(&addr)
}

Expand All @@ -447,7 +460,7 @@ impl<P: Platform + 'static> Server<P> {
self.connections.write().unwrap().remove(&k)
}

fn add_connection(self: Arc<Self>, runner_enclave: VsockStream, runner_remote: TcpStream, remote_name: String) -> Result<JoinHandle<()>, IoError> {
fn add_connection(self: Arc<Self>, runner_enclave: VsockStream, runner_remote: Box<dyn SocketStream>, remote_name: String) -> Result<JoinHandle<()>, IoError> {
let k = ConnectionKey::from_vsock_stream(&runner_enclave)?;
let mut connection = Connection::new(runner_enclave, runner_remote, remote_name);
self.connections.write().unwrap().insert(k.clone(), connection.info()?);
Expand Down Expand Up @@ -486,9 +499,15 @@ impl<P: Platform + 'static> Server<P> {
*/
fn handle_request_bind(self: Arc<Self>, addr: &String, enclave_port: u32, conn: &mut ClientConnection) -> Result<(), VmeError> {
let cid: u32 = conn.stream.peer_addr()?.cid();
let listener = TcpListener::bind(addr).map_err(|e| VmeError::Command(e.kind().into()))?;
let local: Addr = listener.local_addr()?.into();
self.add_listener(VsockAddr::new(cid, enclave_port), Listener::new(listener));
let (listener, local_addr) = if let Some((lis, addr)) = self.usercall_ext.bind_stream(addr)? {
(lis, addr)
} else {
let lis = TcpListener::bind(addr).map_err(|e| VmeError::Command(e.kind().into()))?;
let addr = lis.local_addr()?;
(Box::new(lis) as Box<dyn Listener>, addr)
};
let local: Addr = local_addr.into();
self.add_listener(VsockAddr::new(cid, enclave_port), listener);
conn.send(&Response::Bound{ local })?;
Ok(())
}
Expand All @@ -501,8 +520,8 @@ impl<P: Platform + 'static> Server<P> {
.ok_or(IoError::new(IoErrorKind::InvalidInput, "Information about provided file descriptor was not found"))?;

// Accept connection for TCP Listener
let listener = listener.lock().unwrap();
let (conn, peer) = listener.listener.accept().map_err(|e| VmeError::Command(e.kind().into()))?;
let mut listener = listener.lock().unwrap();
let (conn, peer) = listener.accept().map_err(|e| VmeError::Command(e.kind().into()))?;
drop(listener);

// Send enclave info where it should accept new incoming connection
Expand Down Expand Up @@ -563,7 +582,7 @@ impl<P: Platform + 'static> Server<P> {
if let Some(listener) = self.listener(&enclave_addr) {
let listener = listener.lock().unwrap();
conn.send(&Response::Info {
local: listener.listener.local_addr()?.into(),
local: listener.local_addr()?.into(),
peer: None,
})?;
Ok(())
Expand Down Expand Up @@ -607,9 +626,15 @@ impl<P: Platform + 'static> Server<P> {
command_listener: Mutex::new(command_listener),
listeners: RwLock::new(FnvHashMap::default()),
connections: RwLock::new(FnvHashMap::default()),
usercall_ext: Box::new(UsercallExtensionDefault),

})
}

pub fn set_usercall_ext(&mut self, usercall_ext: Box<dyn UsercallExtension>) {
self.usercall_ext = usercall_ext
}

fn start_command_server(self: Arc<Self>) -> Result<JoinHandle<()>, IoError> {
thread::Builder::new().spawn(move || {
let command_listener = self.command_listener.lock().unwrap();
Expand Down
93 changes: 93 additions & 0 deletions fortanix-vme/fortanix-vme-runner/src/usercall_ext.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use core::net::SocketAddr;
use std::{
io::{self, Read, Result as IoResult, Write},
net::{TcpListener, TcpStream},
os::fd::RawFd,
};

pub trait UsercallExtension: 'static + Send + Sync + std::fmt::Debug {
fn connect_stream(&self, addr: &str) -> IoResult<Option<Box<dyn SocketStream>>> {
let _ = addr;
Ok(None)
}
fn bind_stream(&self, addr: &str) -> IoResult<Option<(Box<dyn Listener>, SocketAddr)>> {
let _ = addr;
Ok(None)
}
}

impl<T: UsercallExtension> From<T> for Box<dyn UsercallExtension> {
fn from(value: T) -> Box<dyn UsercallExtension> {
Box::new(value)
}
}

#[derive(Debug)]
pub struct UsercallExtensionDefault;
impl UsercallExtension for UsercallExtensionDefault {}

pub trait SocketStream: Read + Write + 'static + Send + Sync {
fn local_addr(&self) -> IoResult<SocketAddr>;
fn peer_addr(&self) -> IoResult<SocketAddr>;
fn as_raw_fd(&self) -> RawFd;
fn shutdown(&self, how: std::net::Shutdown) -> IoResult<()>;
}

impl SocketStream for TcpStream {
fn local_addr(&self) -> IoResult<SocketAddr> {
self.local_addr()
}

fn peer_addr(&self) -> IoResult<SocketAddr> {
self.peer_addr()
}

fn as_raw_fd(&self) -> RawFd {
std::os::fd::AsRawFd::as_raw_fd(self)
}

fn shutdown(&self, how: std::net::Shutdown) -> IoResult<()> {
self.shutdown(how)
}
}

impl<T: SocketStream + ?Sized> SocketStream for Box<T> {
fn local_addr(&self) -> IoResult<SocketAddr> {
(**self).local_addr()
}

fn peer_addr(&self) -> IoResult<SocketAddr> {
(**self).peer_addr()
}

fn as_raw_fd(&self) -> RawFd {
(**self).as_raw_fd()
}

fn shutdown(&self, how: std::net::Shutdown) -> IoResult<()> {
(**self).shutdown(how)
}
}

/// Listener lets an implementation implement a slightly modified form of `std::net::TcpListener::accept`.
pub trait Listener: 'static + Send {
/// The enclave may optionally request the local or peer addresses
/// be returned in `local_addr` or `peer_addr`, respectively.
/// If `local_addr` and/or `peer_addr` are not `None`, they will point to an empty `String`.
/// On success, user-space can fill in the strings as appropriate.
///
/// The enclave must not make any security decisions based on the local address received.
fn accept(&mut self) -> io::Result<(Box<dyn SocketStream>, SocketAddr)>;
fn local_addr(&self) -> IoResult<SocketAddr>;
}

impl Listener for TcpListener {
fn accept(&mut self) -> io::Result<(Box<dyn SocketStream>, SocketAddr)> {
TcpListener::accept(&self)
.map(|(stream, addr)| (Box::new(stream) as Box<dyn SocketStream>, addr))
}

fn local_addr(&self) -> IoResult<SocketAddr> {
self.local_addr()
}
}