Skip to content

Commit 1ca2fbc

Browse files
committed
feat(fortanix-vme-runner): add implement UsercallExtension trait and associated types
1 parent 087721d commit 1ca2fbc

File tree

2 files changed

+147
-29
lines changed

2 files changed

+147
-29
lines changed

fortanix-vme/fortanix-vme-runner/src/lib.rs

Lines changed: 54 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ mod platforms;
1616
pub use platforms::{Platform, NitroEnclaves, EnclaveSimulator, EnclaveSimulatorArgs};
1717
pub use platforms::amdsevsnp::{AmdSevVm, RunningVm, VmRunArgs, VmSimulator};
1818

19+
use crate::usercall_ext::Listener;
20+
use crate::usercall_ext::SocketStream;
21+
use crate::usercall_ext::UsercallExtension;
22+
use crate::usercall_ext::UsercallExtensionDefault;
23+
24+
mod usercall_ext;
25+
1926
const MAX_LOG_MESSAGE_LEN: usize = 80;
2027
const PROXY_BUFF_SIZE: usize = 4192;
2128

@@ -36,9 +43,9 @@ pub trait StreamConnection: Read + Write {
3643
fn peer_port(&self) -> io::Result<u32>;
3744
}
3845

39-
impl StreamConnection for TcpStream {
46+
impl<T: SocketStream> StreamConnection for T {
4047
fn protocol() -> &'static str {
41-
"tcp"
48+
"socket stream"
4249
}
4350

4451
fn local(&self) -> io::Result<String> {
@@ -84,20 +91,20 @@ impl StreamConnection for VsockStream {
8491
}
8592
}
8693

87-
#[derive(Debug)]
88-
struct Listener {
89-
listener: TcpListener,
90-
}
94+
// #[derive(Debug)]
95+
// struct Listener {
96+
// listener: TcpListener,
97+
// }
9198

92-
impl Listener {
93-
fn new(listener: TcpListener) -> Self {
94-
Listener{ listener }
95-
}
96-
}
99+
// impl Listener {
100+
// fn new(listener: TcpListener) -> Self {
101+
// Listener{ listener }
102+
// }
103+
// }
97104

98-
#[derive(Debug)]
105+
// #[derive(Debug)]
99106
struct Connection {
100-
tcp_stream: TcpStream,
107+
tcp_stream: Box<dyn SocketStream>,
101108
vsock_stream: VsockStream<Std>,
102109
remote_name: String,
103110
}
@@ -111,7 +118,7 @@ struct ConnectionInfo {
111118
}
112119

113120
impl Connection {
114-
pub fn new(vsock_stream: VsockStream<Std>, tcp_stream: TcpStream, remote_name: String) -> Self {
121+
pub fn new(vsock_stream: VsockStream<Std>, tcp_stream: Box<dyn SocketStream>, remote_name: String) -> Self {
115122
Connection {
116123
tcp_stream,
117124
vsock_stream,
@@ -342,8 +349,9 @@ pub struct Server<P: Platform> {
342349
/// When the enclave instructs to accept a new connection, the runner accepts a new TCP
343350
/// connection. It then locates the ListenerInfo and finds the information it needs to set up a
344351
/// new vsock connection to the enclave
345-
listeners: RwLock<FnvHashMap<VsockAddr, Arc<Mutex<Listener>>>>,
352+
listeners: RwLock<FnvHashMap<VsockAddr, Arc<Mutex<Box<dyn Listener>>>>>,
346353
connections: RwLock<FnvHashMap<ConnectionKey, ConnectionInfo>>,
354+
usercall_ext: Box<dyn UsercallExtension>,
347355
}
348356

349357
impl<P: Platform + 'static> Server<P> {
@@ -379,8 +387,13 @@ impl<P: Platform + 'static> Server<P> {
379387
* [3] proxy
380388
*/
381389
fn handle_request_connect(self: Arc<Self>, remote_addr: &String, conn: &mut ClientConnection) -> Result<(), VmeError> {
390+
let remote_stream = if let Some(stream) = self.usercall_ext.connect_stream(remote_addr)? {
391+
stream
392+
} else {
393+
let remote_socket = TcpStream::connect(remote_addr).map_err(|e| VmeError::Command(e.kind().into()))?;
394+
Box::new(remote_socket)
395+
};
382396
// Connect to remote server
383-
let remote_socket = TcpStream::connect(remote_addr).map_err(|e| VmeError::Command(e.kind().into()))?;
384397
let remote_name = remote_addr.split_terminator(":").next().unwrap_or(remote_addr);
385398

386399
// Create listening socket that the enclave can connect to
@@ -390,8 +403,8 @@ impl<P: Platform + 'static> Server<P> {
390403
// Notify the enclave on which port her proxy is listening on
391404
let response = Response::Connected {
392405
proxy_port: proxy_server_port,
393-
local: remote_socket.local_addr()?.into(),
394-
peer: remote_socket.peer_addr()?.into(),
406+
local: remote_stream.local_addr()?.into(),
407+
peer: remote_stream.peer_addr()?.into(),
395408
};
396409

397410
conn.send(&response)?;
@@ -402,7 +415,7 @@ impl<P: Platform + 'static> Server<P> {
402415
let accept_connection = move || -> Result<(), VmeError> {
403416
let (proxy, _proxy_addr) = proxy_server.accept()?;
404417
// Store connection info
405-
self.add_connection(proxy, remote_socket, remote_name.to_string())?;
418+
self.add_connection(proxy, remote_stream, remote_name.to_string())?;
406419
Ok(())
407420
};
408421
if let Err(e) = accept_connection() {
@@ -411,15 +424,15 @@ impl<P: Platform + 'static> Server<P> {
411424
Ok(())
412425
}
413426

414-
fn add_listener(&self, addr: VsockAddr, info: Listener) {
427+
fn add_listener(&self, addr: VsockAddr, info: Box<dyn Listener>) {
415428
self.listeners.write().unwrap().insert(addr, Arc::new(Mutex::new(info)));
416429
}
417430

418-
fn listener(&self, addr: &VsockAddr) -> Option<Arc<Mutex<Listener>>> {
431+
fn listener(&self, addr: &VsockAddr) -> Option<Arc<Mutex<Box<dyn Listener>>>> {
419432
self.listeners.read().unwrap().get(&addr).cloned()
420433
}
421434

422-
fn remove_listener(&self, addr: &VsockAddr) -> Option<Arc<Mutex<Listener>>> {
435+
fn remove_listener(&self, addr: &VsockAddr) -> Option<Arc<Mutex<Box<dyn Listener>>>> {
423436
self.listeners.write().unwrap().remove(&addr)
424437
}
425438

@@ -445,7 +458,7 @@ impl<P: Platform + 'static> Server<P> {
445458
self.connections.write().unwrap().remove(&k)
446459
}
447460

448-
fn add_connection(self: Arc<Self>, runner_enclave: VsockStream, runner_remote: TcpStream, remote_name: String) -> Result<JoinHandle<()>, IoError> {
461+
fn add_connection(self: Arc<Self>, runner_enclave: VsockStream, runner_remote: Box<dyn SocketStream>, remote_name: String) -> Result<JoinHandle<()>, IoError> {
449462
let k = ConnectionKey::from_vsock_stream(&runner_enclave)?;
450463
let mut connection = Connection::new(runner_enclave, runner_remote, remote_name);
451464
self.connections.write().unwrap().insert(k.clone(), connection.info()?);
@@ -484,9 +497,15 @@ impl<P: Platform + 'static> Server<P> {
484497
*/
485498
fn handle_request_bind(self: Arc<Self>, addr: &String, enclave_port: u32, conn: &mut ClientConnection) -> Result<(), VmeError> {
486499
let cid: u32 = conn.stream.peer_addr()?.cid();
487-
let listener = TcpListener::bind(addr).map_err(|e| VmeError::Command(e.kind().into()))?;
488-
let local: Addr = listener.local_addr()?.into();
489-
self.add_listener(VsockAddr::new(cid, enclave_port), Listener::new(listener));
500+
let (listener, local_addr) = if let Some((lis, addr)) = self.usercall_ext.bind_stream(addr)? {
501+
(lis, addr)
502+
} else {
503+
let lis = TcpListener::bind(addr).map_err(|e| VmeError::Command(e.kind().into()))?;
504+
let addr = lis.local_addr()?;
505+
(Box::new(lis) as Box<dyn Listener>, addr)
506+
};
507+
let local: Addr = local_addr.into();
508+
self.add_listener(VsockAddr::new(cid, enclave_port), listener);
490509
conn.send(&Response::Bound{ local })?;
491510
Ok(())
492511
}
@@ -499,8 +518,8 @@ impl<P: Platform + 'static> Server<P> {
499518
.ok_or(IoError::new(IoErrorKind::InvalidInput, "Information about provided file descriptor was not found"))?;
500519

501520
// Accept connection for TCP Listener
502-
let listener = listener.lock().unwrap();
503-
let (conn, peer) = listener.listener.accept().map_err(|e| VmeError::Command(e.kind().into()))?;
521+
let mut listener = listener.lock().unwrap();
522+
let (conn, peer) = listener.accept().map_err(|e| VmeError::Command(e.kind().into()))?;
504523
drop(listener);
505524

506525
// Send enclave info where it should accept new incoming connection
@@ -561,7 +580,7 @@ impl<P: Platform + 'static> Server<P> {
561580
if let Some(listener) = self.listener(&enclave_addr) {
562581
let listener = listener.lock().unwrap();
563582
conn.send(&Response::Info {
564-
local: listener.listener.local_addr()?.into(),
583+
local: listener.local_addr()?.into(),
565584
peer: None,
566585
})?;
567586
Ok(())
@@ -605,9 +624,15 @@ impl<P: Platform + 'static> Server<P> {
605624
command_listener: Mutex::new(command_listener),
606625
listeners: RwLock::new(FnvHashMap::default()),
607626
connections: RwLock::new(FnvHashMap::default()),
627+
usercall_ext: Box::new(UsercallExtensionDefault),
628+
608629
})
609630
}
610631

632+
pub fn set_usercall_ext(&mut self, usercall_ext: Box<dyn UsercallExtension>) {
633+
self.usercall_ext = usercall_ext
634+
}
635+
611636
fn start_command_server(self: Arc<Self>) -> Result<JoinHandle<()>, IoError> {
612637
thread::Builder::new().spawn(move || {
613638
let command_listener = self.command_listener.lock().unwrap();
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
use core::net::SocketAddr;
2+
use std::{
3+
io::{self, Read, Result as IoResult, Write},
4+
net::{TcpListener, TcpStream},
5+
os::fd::RawFd,
6+
};
7+
8+
pub trait UsercallExtension: 'static + Send + Sync + std::fmt::Debug {
9+
fn connect_stream(&self, addr: &str) -> IoResult<Option<Box<dyn SocketStream>>> {
10+
let _ = addr;
11+
Ok(None)
12+
}
13+
fn bind_stream(&self, addr: &str) -> IoResult<Option<(Box<dyn Listener>, SocketAddr)>> {
14+
let _ = addr;
15+
Ok(None)
16+
}
17+
}
18+
19+
impl<T: UsercallExtension> From<T> for Box<dyn UsercallExtension> {
20+
fn from(value: T) -> Box<dyn UsercallExtension> {
21+
Box::new(value)
22+
}
23+
}
24+
25+
#[derive(Debug)]
26+
pub struct UsercallExtensionDefault;
27+
impl UsercallExtension for UsercallExtensionDefault {}
28+
29+
pub trait SocketStream: Read + Write + 'static + Send + Sync {
30+
fn local_addr(&self) -> IoResult<SocketAddr>;
31+
fn peer_addr(&self) -> IoResult<SocketAddr>;
32+
fn as_raw_fd(&self) -> RawFd;
33+
fn shutdown(&self, how: std::net::Shutdown) -> IoResult<()>;
34+
}
35+
36+
impl SocketStream for TcpStream {
37+
fn local_addr(&self) -> IoResult<SocketAddr> {
38+
self.local_addr()
39+
}
40+
41+
fn peer_addr(&self) -> IoResult<SocketAddr> {
42+
self.peer_addr()
43+
}
44+
45+
fn as_raw_fd(&self) -> RawFd {
46+
std::os::fd::AsRawFd::as_raw_fd(self)
47+
}
48+
49+
fn shutdown(&self, how: std::net::Shutdown) -> IoResult<()> {
50+
self.shutdown(how)
51+
}
52+
}
53+
54+
impl<T: SocketStream + ?Sized> SocketStream for Box<T> {
55+
fn local_addr(&self) -> IoResult<SocketAddr> {
56+
(**self).local_addr()
57+
}
58+
59+
fn peer_addr(&self) -> IoResult<SocketAddr> {
60+
(**self).peer_addr()
61+
}
62+
63+
fn as_raw_fd(&self) -> RawFd {
64+
(**self).as_raw_fd()
65+
}
66+
67+
fn shutdown(&self, how: std::net::Shutdown) -> IoResult<()> {
68+
(**self).shutdown(how)
69+
}
70+
}
71+
72+
/// Listener lets an implementation implement a slightly modified form of `std::net::TcpListener::accept`.
73+
pub trait Listener: 'static + Send {
74+
/// The enclave may optionally request the local or peer addresses
75+
/// be returned in `local_addr` or `peer_addr`, respectively.
76+
/// If `local_addr` and/or `peer_addr` are not `None`, they will point to an empty `String`.
77+
/// On success, user-space can fill in the strings as appropriate.
78+
///
79+
/// The enclave must not make any security decisions based on the local address received.
80+
fn accept(&mut self) -> io::Result<(Box<dyn SocketStream>, SocketAddr)>;
81+
fn local_addr(&self) -> IoResult<SocketAddr>;
82+
}
83+
84+
impl Listener for TcpListener {
85+
fn accept(&mut self) -> io::Result<(Box<dyn SocketStream>, SocketAddr)> {
86+
TcpListener::accept(&self)
87+
.map(|(stream, addr)| (Box::new(stream) as Box<dyn SocketStream>, addr))
88+
}
89+
90+
fn local_addr(&self) -> IoResult<SocketAddr> {
91+
self.local_addr()
92+
}
93+
}

0 commit comments

Comments
 (0)