Skip to content

Commit f27fa71

Browse files
committed
Refactor HostFuncsWrapper so that the helper functions are private methods
Signed-off-by: Jorge Prendes <[email protected]>
1 parent 944472d commit f27fa71

File tree

1 file changed

+52
-75
lines changed

1 file changed

+52
-75
lines changed

src/hyperlight_host/src/sandbox/host_funcs.rs

Lines changed: 52 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ impl HostFuncsWrapper {
4040
name: String,
4141
func: HyperlightFunction,
4242
) -> Result<()> {
43-
register_host_function_helper(self, name, func, None)
43+
self.register_host_function_helper(name, func, None)
4444
}
4545

4646
/// Register a host function with the sandbox, with a list of extra syscalls
@@ -53,7 +53,7 @@ impl HostFuncsWrapper {
5353
func: HyperlightFunction,
5454
extra_allowed_syscalls: Vec<ExtraAllowedSyscall>,
5555
) -> Result<()> {
56-
register_host_function_helper(self, name, func, Some(extra_allowed_syscalls))
56+
self.register_host_function_helper(name, func, Some(extra_allowed_syscalls))
5757
}
5858

5959
/// Assuming a host function called `"HostPrint"` exists, and takes a
@@ -63,11 +63,7 @@ impl HostFuncsWrapper {
6363
/// and `Err` otherwise.
6464
#[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
6565
pub(super) fn host_print(&mut self, msg: String) -> Result<i32> {
66-
let res = call_host_func_impl(
67-
&self.functions_map,
68-
"HostPrint",
69-
vec![ParameterValue::String(msg)],
70-
)?;
66+
let res = self.call_host_func_impl("HostPrint", vec![ParameterValue::String(msg)])?;
7167
res.try_into()
7268
.map_err(|_| HostFunctionNotFound("HostPrint".to_string()))
7369
}
@@ -84,96 +80,77 @@ impl HostFuncsWrapper {
8480
name: &str,
8581
args: Vec<ParameterValue>,
8682
) -> Result<ReturnValue> {
87-
call_host_func_impl(&self.functions_map, name, args)
83+
self.call_host_func_impl(name, args)
8884
}
89-
}
90-
91-
fn register_host_function_helper(
92-
self_: &mut HostFuncsWrapper,
93-
name: String,
94-
func: HyperlightFunction,
95-
extra_allowed_syscalls: Option<Vec<ExtraAllowedSyscall>>,
96-
) -> Result<()> {
97-
if let Some(_syscalls) = extra_allowed_syscalls {
98-
#[cfg(all(feature = "seccomp", target_os = "linux"))]
99-
self_.functions_map.insert(name, (func, Some(_syscalls)));
10085

86+
fn register_host_function_helper(
87+
&mut self,
88+
name: String,
89+
func: HyperlightFunction,
90+
extra_allowed_syscalls: Option<Vec<ExtraAllowedSyscall>>,
91+
) -> Result<()> {
10192
#[cfg(not(all(feature = "seccomp", target_os = "linux")))]
102-
return Err(new_error!(
103-
"Extra syscalls are only supported on Linux with seccomp"
104-
));
105-
} else {
106-
self_.functions_map.insert(name, (func, None));
107-
}
108-
109-
Ok(())
110-
}
111-
112-
#[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
113-
fn call_host_func_impl(
114-
host_funcs: &HashMap<String, (HyperlightFunction, Option<Vec<ExtraAllowedSyscall>>)>,
115-
name: &str,
116-
args: Vec<ParameterValue>,
117-
) -> Result<ReturnValue> {
118-
// Inner function containing the common logic
119-
fn call_func(
120-
host_funcs: &HashMap<String, (HyperlightFunction, Option<Vec<ExtraAllowedSyscall>>)>,
121-
name: &str,
122-
args: Vec<ParameterValue>,
123-
) -> Result<ReturnValue> {
124-
let func_with_syscalls = host_funcs
125-
.get(name)
126-
.ok_or_else(|| HostFunctionNotFound(name.to_string()))?;
127-
128-
let func = func_with_syscalls.0.clone();
129-
130-
#[cfg(all(feature = "seccomp", target_os = "linux"))]
131-
{
132-
let syscalls = func_with_syscalls.1.clone();
133-
let seccomp_filter =
134-
crate::seccomp::guest::get_seccomp_filter_for_host_function_worker_thread(
135-
syscalls,
136-
)?;
137-
seccompiler::apply_filter(&seccomp_filter)?;
93+
if extra_allowed_syscalls.is_some() {
94+
return Err(new_error!(
95+
"Extra syscalls are only supported on Linux with seccomp"
96+
));
13897
}
139-
140-
crate::metrics::maybe_time_and_emit_host_call(name, || func.call(args))
98+
self.functions_map
99+
.insert(name, (func, extra_allowed_syscalls));
100+
Ok(())
141101
}
142102

143-
cfg_if::cfg_if! {
144-
if #[cfg(all(feature = "seccomp", target_os = "linux"))] {
145-
// Clone variables for the thread
146-
let host_funcs_cloned = host_funcs.clone();
147-
let name_cloned = name.to_string();
148-
let args_cloned = args.clone();
149-
103+
#[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
104+
fn call_host_func_impl(&self, name: &str, args: Vec<ParameterValue>) -> Result<ReturnValue> {
105+
// Inner function containing the common logic
106+
let do_call = || {
107+
let (func, syscalls) = self
108+
.functions_map
109+
.get(name)
110+
.ok_or_else(|| HostFunctionNotFound(name.to_string()))?;
111+
112+
#[cfg(all(feature = "seccomp", target_os = "linux"))]
113+
{
114+
let seccomp_filter =
115+
crate::seccomp::guest::get_seccomp_filter_for_host_function_worker_thread(
116+
syscalls.clone(),
117+
)?;
118+
seccompiler::apply_filter(&seccomp_filter)?;
119+
}
120+
121+
crate::metrics::maybe_time_and_emit_host_call(name, || func.call(args))
122+
};
123+
124+
if cfg!(all(feature = "seccomp", target_os = "linux")) {
150125
// Create a new thread when seccomp is enabled on Linux
151-
let join_handle = std::thread::Builder::new()
152-
.name(format!("Host Function Worker Thread for: {:?}", name_cloned))
153-
.spawn(move || {
126+
std::thread::scope(|s| {
127+
s.spawn(move || {
154128
// We have a `catch_unwind` here because, if a disallowed syscall is issued,
155129
// we handle it by panicking. This is to avoid returning execution to the
156130
// offending host function—for two reasons: (1) if a host function is issuing
157131
// disallowed syscalls, it could be unsafe to return to, and (2) returning
158132
// execution after trapping the disallowed syscall can lead to UB (e.g., try
159133
// running a host function that attempts to sleep without `SYS_clock_nanosleep`,
160134
// you'll block the syscall but panic in the aftermath).
161-
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| call_func(&host_funcs_cloned, &name_cloned, args_cloned))) {
135+
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(do_call)) {
162136
Ok(val) => val,
163137
Err(err) => {
164-
if let Some(crate::HyperlightError::DisallowedSyscall) = err.downcast_ref::<crate::HyperlightError>() {
165-
return Err(crate::HyperlightError::DisallowedSyscall)
138+
if let Some(crate::HyperlightError::DisallowedSyscall) =
139+
err.downcast_ref::<crate::HyperlightError>()
140+
{
141+
return Err(crate::HyperlightError::DisallowedSyscall);
166142
}
167143

168-
crate::log_then_return!("Host function {} panicked", name_cloned);
144+
crate::log_then_return!("Host function {} panicked", name);
169145
}
170146
}
171-
})?;
172-
173-
join_handle.join().map_err(|_| new_error!("Error joining thread executing host function"))?
147+
})
148+
.join()
149+
.map_err(|_| new_error!("Error joining thread executing host function"))?
150+
})
174151
} else {
175152
// Directly call the function without creating a new thread
176-
call_func(host_funcs, name, args)
153+
do_call()
177154
}
178155
}
179156
}

0 commit comments

Comments
 (0)