@@ -40,7 +40,7 @@ impl HostFuncsWrapper {
4040 name : String ,
4141 func : HyperlightFunction ,
4242 ) -> Result < ( ) > {
43- register_host_function_helper ( self , name, func, None )
43+ self . register_host_function_helper ( name, func, None )
4444 }
4545
4646 /// Register a host function with the sandbox, with a list of extra syscalls
@@ -53,7 +53,7 @@ impl HostFuncsWrapper {
5353 func : HyperlightFunction ,
5454 extra_allowed_syscalls : Vec < ExtraAllowedSyscall > ,
5555 ) -> Result < ( ) > {
56- register_host_function_helper ( self , name, func, Some ( extra_allowed_syscalls) )
56+ self . register_host_function_helper ( name, func, Some ( extra_allowed_syscalls) )
5757 }
5858
5959 /// Assuming a host function called `"HostPrint"` exists, and takes a
@@ -63,11 +63,7 @@ impl HostFuncsWrapper {
6363 /// and `Err` otherwise.
6464 #[ instrument( err( Debug ) , skip_all, parent = Span :: current( ) , level = "Trace" ) ]
6565 pub ( super ) fn host_print ( & mut self , msg : String ) -> Result < i32 > {
66- let res = call_host_func_impl (
67- & self . functions_map ,
68- "HostPrint" ,
69- vec ! [ ParameterValue :: String ( msg) ] ,
70- ) ?;
66+ let res = self . call_host_func_impl ( "HostPrint" , vec ! [ ParameterValue :: String ( msg) ] ) ?;
7167 res. try_into ( )
7268 . map_err ( |_| HostFunctionNotFound ( "HostPrint" . to_string ( ) ) )
7369 }
@@ -84,96 +80,77 @@ impl HostFuncsWrapper {
8480 name : & str ,
8581 args : Vec < ParameterValue > ,
8682 ) -> Result < ReturnValue > {
87- call_host_func_impl ( & self . functions_map , name, args)
83+ self . call_host_func_impl ( name, args)
8884 }
89- }
90-
91- fn register_host_function_helper (
92- self_ : & mut HostFuncsWrapper ,
93- name : String ,
94- func : HyperlightFunction ,
95- extra_allowed_syscalls : Option < Vec < ExtraAllowedSyscall > > ,
96- ) -> Result < ( ) > {
97- if let Some ( _syscalls) = extra_allowed_syscalls {
98- #[ cfg( all( feature = "seccomp" , target_os = "linux" ) ) ]
99- self_. functions_map . insert ( name, ( func, Some ( _syscalls) ) ) ;
10085
86+ fn register_host_function_helper (
87+ & mut self ,
88+ name : String ,
89+ func : HyperlightFunction ,
90+ extra_allowed_syscalls : Option < Vec < ExtraAllowedSyscall > > ,
91+ ) -> Result < ( ) > {
10192 #[ cfg( not( all( feature = "seccomp" , target_os = "linux" ) ) ) ]
102- return Err ( new_error ! (
103- "Extra syscalls are only supported on Linux with seccomp"
104- ) ) ;
105- } else {
106- self_. functions_map . insert ( name, ( func, None ) ) ;
107- }
108-
109- Ok ( ( ) )
110- }
111-
112- #[ instrument( err( Debug ) , skip_all, parent = Span :: current( ) , level = "Trace" ) ]
113- fn call_host_func_impl (
114- host_funcs : & HashMap < String , ( HyperlightFunction , Option < Vec < ExtraAllowedSyscall > > ) > ,
115- name : & str ,
116- args : Vec < ParameterValue > ,
117- ) -> Result < ReturnValue > {
118- // Inner function containing the common logic
119- fn call_func (
120- host_funcs : & HashMap < String , ( HyperlightFunction , Option < Vec < ExtraAllowedSyscall > > ) > ,
121- name : & str ,
122- args : Vec < ParameterValue > ,
123- ) -> Result < ReturnValue > {
124- let func_with_syscalls = host_funcs
125- . get ( name)
126- . ok_or_else ( || HostFunctionNotFound ( name. to_string ( ) ) ) ?;
127-
128- let func = func_with_syscalls. 0 . clone ( ) ;
129-
130- #[ cfg( all( feature = "seccomp" , target_os = "linux" ) ) ]
131- {
132- let syscalls = func_with_syscalls. 1 . clone ( ) ;
133- let seccomp_filter =
134- crate :: seccomp:: guest:: get_seccomp_filter_for_host_function_worker_thread (
135- syscalls,
136- ) ?;
137- seccompiler:: apply_filter ( & seccomp_filter) ?;
93+ if extra_allowed_syscalls. is_some ( ) {
94+ return Err ( new_error ! (
95+ "Extra syscalls are only supported on Linux with seccomp"
96+ ) ) ;
13897 }
139-
140- crate :: metrics:: maybe_time_and_emit_host_call ( name, || func. call ( args) )
98+ self . functions_map
99+ . insert ( name, ( func, extra_allowed_syscalls) ) ;
100+ Ok ( ( ) )
141101 }
142102
143- cfg_if:: cfg_if! {
144- if #[ cfg( all( feature = "seccomp" , target_os = "linux" ) ) ] {
145- // Clone variables for the thread
146- let host_funcs_cloned = host_funcs. clone( ) ;
147- let name_cloned = name. to_string( ) ;
148- let args_cloned = args. clone( ) ;
149-
103+ #[ instrument( err( Debug ) , skip_all, parent = Span :: current( ) , level = "Trace" ) ]
104+ fn call_host_func_impl ( & self , name : & str , args : Vec < ParameterValue > ) -> Result < ReturnValue > {
105+ // Inner function containing the common logic
106+ let do_call = || {
107+ let ( func, syscalls) = self
108+ . functions_map
109+ . get ( name)
110+ . ok_or_else ( || HostFunctionNotFound ( name. to_string ( ) ) ) ?;
111+
112+ #[ cfg( all( feature = "seccomp" , target_os = "linux" ) ) ]
113+ {
114+ let seccomp_filter =
115+ crate :: seccomp:: guest:: get_seccomp_filter_for_host_function_worker_thread (
116+ syscalls. clone ( ) ,
117+ ) ?;
118+ seccompiler:: apply_filter ( & seccomp_filter) ?;
119+ }
120+
121+ crate :: metrics:: maybe_time_and_emit_host_call ( name, || func. call ( args) )
122+ } ;
123+
124+ if cfg ! ( all( feature = "seccomp" , target_os = "linux" ) ) {
150125 // Create a new thread when seccomp is enabled on Linux
151- let join_handle = std:: thread:: Builder :: new( )
152- . name( format!( "Host Function Worker Thread for: {:?}" , name_cloned) )
153- . spawn( move || {
126+ std:: thread:: scope ( |s| {
127+ s. spawn ( move || {
154128 // We have a `catch_unwind` here because, if a disallowed syscall is issued,
155129 // we handle it by panicking. This is to avoid returning execution to the
156130 // offending host function—for two reasons: (1) if a host function is issuing
157131 // disallowed syscalls, it could be unsafe to return to, and (2) returning
158132 // execution after trapping the disallowed syscall can lead to UB (e.g., try
159133 // running a host function that attempts to sleep without `SYS_clock_nanosleep`,
160134 // you'll block the syscall but panic in the aftermath).
161- match std:: panic:: catch_unwind( std:: panic:: AssertUnwindSafe ( || call_func ( & host_funcs_cloned , & name_cloned , args_cloned ) ) ) {
135+ match std:: panic:: catch_unwind ( std:: panic:: AssertUnwindSafe ( do_call ) ) {
162136 Ok ( val) => val,
163137 Err ( err) => {
164- if let Some ( crate :: HyperlightError :: DisallowedSyscall ) = err. downcast_ref:: <crate :: HyperlightError >( ) {
165- return Err ( crate :: HyperlightError :: DisallowedSyscall )
138+ if let Some ( crate :: HyperlightError :: DisallowedSyscall ) =
139+ err. downcast_ref :: < crate :: HyperlightError > ( )
140+ {
141+ return Err ( crate :: HyperlightError :: DisallowedSyscall ) ;
166142 }
167143
168- crate :: log_then_return!( "Host function {} panicked" , name_cloned ) ;
144+ crate :: log_then_return!( "Host function {} panicked" , name ) ;
169145 }
170146 }
171- } ) ?;
172-
173- join_handle. join( ) . map_err( |_| new_error!( "Error joining thread executing host function" ) ) ?
147+ } )
148+ . join ( )
149+ . map_err ( |_| new_error ! ( "Error joining thread executing host function" ) ) ?
150+ } )
174151 } else {
175152 // Directly call the function without creating a new thread
176- call_func ( host_funcs , name , args )
153+ do_call ( )
177154 }
178155 }
179156}
0 commit comments