diff --git a/src/hyperlight_guest/src/host_function_call.rs b/src/hyperlight_guest/src/host_function_call.rs index df31307b3..97c03e117 100644 --- a/src/hyperlight_guest/src/host_function_call.rs +++ b/src/hyperlight_guest/src/host_function_call.rs @@ -32,7 +32,9 @@ use crate::shared_input_data::try_pop_shared_input_data_into; use crate::shared_output_data::push_shared_output_data; /// Get a return value from a host function call. -/// This usually requires a host function to be called first using `call_host_function`. +/// This usually requires a host function to be called first using `call_host_function_internal`. +/// When calling `call_host_function`, this function is called internally to get the return +/// value. pub fn get_host_return_value>() -> Result { let return_value = try_pop_shared_input_data_into::() .expect("Unable to deserialize a return value from host"); @@ -47,10 +49,12 @@ pub fn get_host_return_value>() -> Result { }) } -// TODO: Make this generic, return a Result this should allow callers to call this function and get the result type they expect -// without having to do the conversion themselves - -pub fn call_host_function( +/// Internal function to call a host function without generic type parameters. +/// This is used by both the Rust and C APIs to reduce code duplication. +/// +/// This function doesn't return the host function result directly, instead it just +/// performs the call. The result must be obtained by calling `get_host_return_value`. +pub fn call_host_function_internal( function_name: &str, parameters: Option>, return_type: ReturnType, @@ -73,6 +77,20 @@ pub fn call_host_function( Ok(()) } +/// Call a host function with the given parameters and return type. +/// This function serializes the function call and its parameters, +/// sends it to the host, and then retrieves the return value. +/// +/// The return value is deserialized into the specified type `T`. +pub fn call_host_function>( + function_name: &str, + parameters: Option>, + return_type: ReturnType, +) -> Result { + call_host_function_internal(function_name, parameters, return_type)?; + get_host_return_value::() +} + pub fn outb(port: u16, data: &[u8]) { unsafe { let mut i = 0; @@ -109,13 +127,13 @@ pub fn debug_print(msg: &str) { /// existence of the input and output memory regions. pub fn print_output_with_host_print(function_call: &FunctionCall) -> Result> { if let ParameterValue::String(message) = function_call.parameters.clone().unwrap()[0].clone() { - call_host_function( + let res = call_host_function::( "HostPrint", Some(Vec::from(&[ParameterValue::String(message.to_string())])), ReturnType::Int, )?; - let res_i = get_host_return_value::()?; - Ok(get_flatbuffer_result(res_i)) + + Ok(get_flatbuffer_result(res)) } else { Err(HyperlightGuestError::new( ErrorCode::GuestError, diff --git a/src/hyperlight_guest/src/print.rs b/src/hyperlight_guest/src/print.rs index b7dd8404f..740e4561f 100644 --- a/src/hyperlight_guest/src/print.rs +++ b/src/hyperlight_guest/src/print.rs @@ -55,10 +55,11 @@ pub unsafe extern "C" fn _putchar(c: c_char) { .expect("Failed to convert buffer to string") }; - call_host_function( + // HostPrint returns an i32, but we don't care about the return value + let _ = call_host_function::( "HostPrint", Some(Vec::from(&[ParameterValue::String(str)])), - ReturnType::Void, + ReturnType::Int, ) .expect("Failed to call HostPrint"); diff --git a/src/hyperlight_guest_capi/src/dispatch.rs b/src/hyperlight_guest_capi/src/dispatch.rs index eacd7b374..feba47658 100644 --- a/src/hyperlight_guest_capi/src/dispatch.rs +++ b/src/hyperlight_guest_capi/src/dispatch.rs @@ -10,7 +10,7 @@ use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; use hyperlight_guest::error::{HyperlightGuestError, Result}; use hyperlight_guest::guest_function_definition::GuestFunctionDefinition; use hyperlight_guest::guest_function_register::GuestFunctionRegister; -use hyperlight_guest::host_function_call::call_host_function; +use hyperlight_guest::host_function_call::call_host_function_internal; use crate::types::{FfiFunctionCall, FfiVec}; static mut REGISTERED_C_GUEST_FUNCTIONS: GuestFunctionRegister = GuestFunctionRegister::new(); @@ -89,5 +89,9 @@ pub extern "C" fn hl_call_host_function(function_call: &FfiFunctionCall) { let parameters = unsafe { function_call.copy_parameters() }; let func_name = unsafe { function_call.copy_function_name() }; let return_type = unsafe { function_call.copy_return_type() }; - let _ = call_host_function(&func_name, Some(parameters), return_type); + + // Use the non-generic internal implementation + // The C API will then call specific getter functions to fetch the properly typed return value + let _ = call_host_function_internal(&func_name, Some(parameters), return_type) + .expect("Failed to call host function"); } diff --git a/src/hyperlight_host/src/func/utils.rs b/src/hyperlight_host/src/func/utils.rs index b02666bf5..d34521e5d 100644 --- a/src/hyperlight_host/src/func/utils.rs +++ b/src/hyperlight_host/src/func/utils.rs @@ -4,6 +4,8 @@ /// /// Usage: /// ```rust +/// use hyperlight_host::func::for_each_tuple; +/// /// macro_rules! my_macro { /// ([$count:expr] ($($name:ident: $type:ident),*)) => { /// // $count is the arity of the tuple diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index e6007742d..8769e1643 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -112,7 +112,6 @@ impl MultiUseSandbox { /// let u_sbox = UninitializedSandbox::new( /// GuestBinary::FilePath("some_guest_binary".to_string()), /// None, - /// None, /// ).unwrap(); /// let sbox: MultiUseSandbox = u_sbox.evolve(Noop::default()).unwrap(); /// // Next, create a new call context from the single-use sandbox. diff --git a/src/tests/rust_guests/callbackguest/src/main.rs b/src/tests/rust_guests/callbackguest/src/main.rs index a96b957f5..bb2c61b50 100644 --- a/src/tests/rust_guests/callbackguest/src/main.rs +++ b/src/tests/rust_guests/callbackguest/src/main.rs @@ -34,9 +34,7 @@ use hyperlight_common::flatbuffer_wrappers::util::get_flatbuffer_result; use hyperlight_guest::error::{HyperlightGuestError, Result}; use hyperlight_guest::guest_function_definition::GuestFunctionDefinition; use hyperlight_guest::guest_function_register::register_function; -use hyperlight_guest::host_function_call::{ - call_host_function, get_host_return_value, print_output_with_host_print, -}; +use hyperlight_guest::host_function_call::{call_host_function, print_output_with_host_print}; use hyperlight_guest::logging::log_message; fn send_message_to_host_method( @@ -45,15 +43,13 @@ fn send_message_to_host_method( message: &str, ) -> Result> { let message = format!("{}{}", guest_message, message); - call_host_function( + let res = call_host_function::( method_name, Some(Vec::from(&[ParameterValue::String(message.to_string())])), ReturnType::Int, )?; - let result = get_host_return_value::()?; - - Ok(get_flatbuffer_result(result)) + Ok(get_flatbuffer_result(res)) } fn guest_function(function_call: &FunctionCall) -> Result> { @@ -101,7 +97,7 @@ fn guest_function3(function_call: &FunctionCall) -> Result> { } fn guest_function4(_: &FunctionCall) -> Result> { - call_host_function( + call_host_function::<()>( "HostMethod4", Some(Vec::from(&[ParameterValue::String( "Hello from GuestFunction4".to_string(), @@ -157,7 +153,7 @@ fn call_error_method(function_call: &FunctionCall) -> Result> { } fn call_host_spin(_: &FunctionCall) -> Result> { - call_host_function("Spin", None, ReturnType::Void)?; + call_host_function::<()>("Spin", None, ReturnType::Void)?; Ok(get_flatbuffer_result(())) } diff --git a/src/tests/rust_guests/simpleguest/src/main.rs b/src/tests/rust_guests/simpleguest/src/main.rs index 1d73d7d7f..ce7bd6aa7 100644 --- a/src/tests/rust_guests/simpleguest/src/main.rs +++ b/src/tests/rust_guests/simpleguest/src/main.rs @@ -45,7 +45,7 @@ use hyperlight_guest::entrypoint::{abort_with_code, abort_with_code_and_message} use hyperlight_guest::error::{HyperlightGuestError, Result}; use hyperlight_guest::guest_function_definition::GuestFunctionDefinition; use hyperlight_guest::guest_function_register::register_function; -use hyperlight_guest::host_function_call::{call_host_function, get_host_return_value}; +use hyperlight_guest::host_function_call::{call_host_function, call_host_function_internal}; use hyperlight_guest::memory::malloc; use hyperlight_guest::{logging, MIN_STACK_ADDRESS}; use log::{error, LevelFilter}; @@ -86,13 +86,13 @@ fn echo_float(function_call: &FunctionCall) -> Result> { } fn print_output(message: &str) -> Result> { - call_host_function( + let res = call_host_function::( "HostPrint", Some(Vec::from(&[ParameterValue::String(message.to_string())])), ReturnType::Int, )?; - let result = get_host_return_value::()?; - Ok(get_flatbuffer_result(result)) + + Ok(get_flatbuffer_result(res)) } fn simple_print_output(function_call: &FunctionCall) -> Result> { @@ -679,9 +679,7 @@ fn add_to_static_and_fail(_: &FunctionCall) -> Result> { fn violate_seccomp_filters(function_call: &FunctionCall) -> Result> { if function_call.parameters.is_none() { - call_host_function("MakeGetpidSyscall", None, ReturnType::ULong)?; - - let res = get_host_return_value::()?; + let res = call_host_function::("MakeGetpidSyscall", None, ReturnType::ULong)?; Ok(get_flatbuffer_result(res)) } else { @@ -697,14 +695,11 @@ fn add(function_call: &FunctionCall) -> Result> { function_call.parameters.clone().unwrap()[0].clone(), function_call.parameters.clone().unwrap()[1].clone(), ) { - call_host_function( + let res = call_host_function::( "HostAdd", Some(Vec::from(&[ParameterValue::Int(a), ParameterValue::Int(b)])), ReturnType::Int, )?; - - let res = get_host_return_value::()?; - Ok(get_flatbuffer_result(res)) } else { Err(HyperlightGuestError::new( @@ -1156,12 +1151,11 @@ pub fn guest_dispatch_function(function_call: FunctionCall) -> Result> { 1, ); - call_host_function( + let result = call_host_function::( "HostPrint", Some(Vec::from(&[ParameterValue::String(message.to_string())])), ReturnType::Int, )?; - let result = get_host_return_value::()?; let function_name = function_call.function_name.clone(); let param_len = function_call.parameters.clone().unwrap_or_default().len(); let call_type = function_call.function_call_type().clone(); @@ -1195,7 +1189,12 @@ fn fuzz_host_function(func: FunctionCall) -> Result> { )) } }; - call_host_function(&host_func_name, Some(params), func.expected_return_type) + + // Because we do not know at compile time the actual return type of the host function to be called + // we cannot use the `call_host_function` generic function. + // We need to use the `call_host_function_internal` function that does not retrieve the return + // value + call_host_function_internal(&host_func_name, Some(params), func.expected_return_type) .expect("failed to call host function"); Ok(get_flatbuffer_result(())) }