Skip to content

Commit c6d9478

Browse files
committed
Refactor HostFuncsWrapper so that the helper functions are private methods
Signed-off-by: Jorge Prendes <[email protected]>
1 parent 2e1efff commit c6d9478

File tree

2 files changed

+84
-92
lines changed

2 files changed

+84
-92
lines changed

src/hyperlight_host/src/sandbox/host_funcs.rs

Lines changed: 81 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ impl HostFuncsWrapper {
4040
name: String,
4141
func: HyperlightFunction,
4242
) -> Result<()> {
43-
register_host_function_helper(self, name, func, None)
43+
self.register_host_function_helper(name, func, None)
4444
}
4545

4646
/// Register a host function with the sandbox, with a list of extra syscalls
@@ -53,7 +53,7 @@ impl HostFuncsWrapper {
5353
func: HyperlightFunction,
5454
extra_allowed_syscalls: Vec<ExtraAllowedSyscall>,
5555
) -> Result<()> {
56-
register_host_function_helper(self, name, func, Some(extra_allowed_syscalls))
56+
self.register_host_function_helper(name, func, Some(extra_allowed_syscalls))
5757
}
5858

5959
/// Assuming a host function called `"HostPrint"` exists, and takes a
@@ -63,11 +63,7 @@ impl HostFuncsWrapper {
6363
/// and `Err` otherwise.
6464
#[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
6565
pub(super) fn host_print(&mut self, msg: String) -> Result<i32> {
66-
let res = call_host_func_impl(
67-
&self.functions_map,
68-
"HostPrint",
69-
vec![ParameterValue::String(msg)],
70-
)?;
66+
let res = self.call_host_func_impl("HostPrint", vec![ParameterValue::String(msg)])?;
7167
res.try_into()
7268
.map_err(|_| HostFunctionNotFound("HostPrint".to_string()))
7369
}
@@ -84,97 +80,40 @@ impl HostFuncsWrapper {
8480
name: &str,
8581
args: Vec<ParameterValue>,
8682
) -> Result<ReturnValue> {
87-
call_host_func_impl(&self.functions_map, name, args)
83+
self.call_host_func_impl(name, args)
8884
}
89-
}
90-
91-
fn register_host_function_helper(
92-
self_: &mut HostFuncsWrapper,
93-
name: String,
94-
func: HyperlightFunction,
95-
extra_allowed_syscalls: Option<Vec<ExtraAllowedSyscall>>,
96-
) -> Result<()> {
97-
if let Some(_syscalls) = extra_allowed_syscalls {
98-
#[cfg(all(feature = "seccomp", target_os = "linux"))]
99-
self_.functions_map.insert(name, (func, Some(_syscalls)));
10085

86+
fn register_host_function_helper(
87+
&mut self,
88+
name: String,
89+
func: HyperlightFunction,
90+
extra_allowed_syscalls: Option<Vec<ExtraAllowedSyscall>>,
91+
) -> Result<()> {
10192
#[cfg(not(all(feature = "seccomp", target_os = "linux")))]
102-
return Err(new_error!(
103-
"Extra syscalls are only supported on Linux with seccomp"
104-
));
105-
} else {
106-
self_.functions_map.insert(name, (func, None));
93+
if extra_allowed_syscalls.is_some() {
94+
return Err(new_error!(
95+
"Extra syscalls are only supported on Linux with seccomp"
96+
));
97+
}
98+
self.functions_map
99+
.insert(name, (func, extra_allowed_syscalls));
100+
Ok(())
107101
}
108102

109-
Ok(())
110-
}
111-
112-
#[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
113-
fn call_host_func_impl(
114-
host_funcs: &HashMap<String, (HyperlightFunction, Option<Vec<ExtraAllowedSyscall>>)>,
115-
name: &str,
116-
args: Vec<ParameterValue>,
117-
) -> Result<ReturnValue> {
118-
// Inner function containing the common logic
119-
fn call_func(
120-
host_funcs: &HashMap<String, (HyperlightFunction, Option<Vec<ExtraAllowedSyscall>>)>,
121-
name: &str,
122-
args: Vec<ParameterValue>,
123-
) -> Result<ReturnValue> {
124-
let func_with_syscalls = host_funcs
103+
#[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
104+
fn call_host_func_impl(&self, name: &str, args: Vec<ParameterValue>) -> Result<ReturnValue> {
105+
let FunctionEntry {
106+
function,
107+
extra_allowed_syscalls,
108+
} = self
109+
.functions_map
125110
.get(name)
126111
.ok_or_else(|| HostFunctionNotFound(name.to_string()))?;
127112

128-
let func = func_with_syscalls.0.clone();
129-
130-
#[cfg(all(feature = "seccomp", target_os = "linux"))]
131-
{
132-
let syscalls = func_with_syscalls.1.clone();
133-
let seccomp_filter =
134-
crate::seccomp::guest::get_seccomp_filter_for_host_function_worker_thread(
135-
syscalls,
136-
)?;
137-
seccompiler::apply_filter(&seccomp_filter)?;
138-
}
139-
140-
crate::metrics::maybe_time_and_emit_host_call(name, || func.call(args))
141-
}
142-
143-
cfg_if::cfg_if! {
144-
if #[cfg(all(feature = "seccomp", target_os = "linux"))] {
145-
// Clone variables for the thread
146-
let host_funcs_cloned = host_funcs.clone();
147-
let name_cloned = name.to_string();
148-
let args_cloned = args.clone();
149-
150-
// Create a new thread when seccomp is enabled on Linux
151-
let join_handle = std::thread::Builder::new()
152-
.name(format!("Host Function Worker Thread for: {:?}", name_cloned))
153-
.spawn(move || {
154-
// We have a `catch_unwind` here because, if a disallowed syscall is issued,
155-
// we handle it by panicking. This is to avoid returning execution to the
156-
// offending host function—for two reasons: (1) if a host function is issuing
157-
// disallowed syscalls, it could be unsafe to return to, and (2) returning
158-
// execution after trapping the disallowed syscall can lead to UB (e.g., try
159-
// running a host function that attempts to sleep without `SYS_clock_nanosleep`,
160-
// you'll block the syscall but panic in the aftermath).
161-
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| call_func(&host_funcs_cloned, &name_cloned, args_cloned))) {
162-
Ok(val) => val,
163-
Err(err) => {
164-
if let Some(crate::HyperlightError::DisallowedSyscall) = err.downcast_ref::<crate::HyperlightError>() {
165-
return Err(crate::HyperlightError::DisallowedSyscall)
166-
}
167-
168-
crate::log_then_return!("Host function {} panicked", name_cloned);
169-
}
170-
}
171-
})?;
172-
173-
join_handle.join().map_err(|_| new_error!("Error joining thread executing host function"))?
174-
} else {
175-
// Directly call the function without creating a new thread
176-
call_func(host_funcs, name, args)
177-
}
113+
// Create a new thread when seccomp is enabled on Linux
114+
maybe_with_seccomp(name, extra_allowed_syscalls.as_deref(), || {
115+
crate::metrics::maybe_time_and_emit_host_call(name, || function.call(args))
116+
})
178117
}
179118
}
180119

@@ -197,3 +136,55 @@ pub(super) fn default_writer_func(s: String) -> Result<i32> {
197136
}
198137
}
199138
}
139+
140+
#[cfg(all(feature = "seccomp", target_os = "linux"))]
141+
fn maybe_with_seccomp<T: Send>(
142+
name: &str,
143+
syscalls: Option<&[ExtraAllowedSyscall]>,
144+
f: impl FnOnce() -> Result<T> + Send,
145+
) -> Result<T> {
146+
use crate::seccomp::guest::get_seccomp_filter_for_host_function_worker_thread;
147+
148+
// Use a scoped thread so that we can pass around references without having to clone them.
149+
crossbeam::thread::scope(|s| {
150+
s.builder()
151+
.name(format!("Host Function Worker Thread for: {name:?}",))
152+
.spawn(move |_| {
153+
let seccomp_filter = get_seccomp_filter_for_host_function_worker_thread(syscalls)?;
154+
seccompiler::apply_filter(&seccomp_filter)?;
155+
156+
// We have a `catch_unwind` here because, if a disallowed syscall is issued,
157+
// we handle it by panicking. This is to avoid returning execution to the
158+
// offending host function—for two reasons: (1) if a host function is issuing
159+
// disallowed syscalls, it could be unsafe to return to, and (2) returning
160+
// execution after trapping the disallowed syscall can lead to UB (e.g., try
161+
// running a host function that attempts to sleep without `SYS_clock_nanosleep`,
162+
// you'll block the syscall but panic in the aftermath).
163+
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
164+
Ok(val) => val,
165+
Err(err) => {
166+
if let Some(crate::HyperlightError::DisallowedSyscall) =
167+
err.downcast_ref::<crate::HyperlightError>()
168+
{
169+
return Err(crate::HyperlightError::DisallowedSyscall);
170+
}
171+
172+
crate::log_then_return!("Host function {} panicked", name);
173+
}
174+
}
175+
})?
176+
.join()
177+
.map_err(|_| new_error!("Error joining thread executing host function"))?
178+
})
179+
.unwrap() // we've already joined the spawned thread, so no error can happen here
180+
}
181+
182+
#[cfg(not(all(feature = "seccomp", target_os = "linux")))]
183+
fn maybe_with_seccomp<T: Send>(
184+
_name: &str,
185+
_syscalls: Option<&[ExtraAllowedSyscall]>,
186+
f: impl FnOnce() -> Result<T> + Send,
187+
) -> Result<T> {
188+
// No seccomp, just call the function
189+
f()
190+
}

src/hyperlight_host/src/seccomp/guest.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,15 @@ fn syscalls_allowlist() -> Result<Vec<(i64, Vec<SeccompRule>)>> {
6969
/// (e.g., `KVM_SET_USER_MEMORY_REGION`, `KVM_GET_API_VERSION`, `KVM_CREATE_VM`,
7070
/// or `KVM_CREATE_VCPU`).
7171
pub(crate) fn get_seccomp_filter_for_host_function_worker_thread(
72-
extra_allowed_syscalls: Option<Vec<ExtraAllowedSyscall>>,
72+
extra_allowed_syscalls: Option<&[ExtraAllowedSyscall]>,
7373
) -> Result<BpfProgram> {
7474
let mut allowed_syscalls = syscalls_allowlist()?;
7575

7676
if let Some(extra_allowed_syscalls) = extra_allowed_syscalls {
7777
allowed_syscalls.extend(
7878
extra_allowed_syscalls
79-
.into_iter()
79+
.iter()
80+
.copied()
8081
.map(|syscall| (syscall, vec![])),
8182
);
8283

0 commit comments

Comments
 (0)