diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index b31cd37d3..43f3c7211 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -163,9 +163,9 @@ impl MultiUseSandbox { func_ret_type: ReturnType, args: Option>, ) -> Result { - let res = call_function_on_guest(self, func_name, func_ret_type, args)?; + let res = call_function_on_guest(self, func_name, func_ret_type, args); self.restore_state()?; - Ok(res) + res } /// Restore the Sandbox's state diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index be8bb261e..73ad0b21d 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -429,6 +429,22 @@ fn execute_on_heap() { } } +#[test] +fn memory_resets_after_failed_guestcall() { + let mut sbox1 = new_uninit_rust().unwrap().evolve(Noop::default()).unwrap(); + sbox1 + .call_guest_function_by_name("AddToStaticAndFail", ReturnType::String, None) + .unwrap_err(); + let res = sbox1 + .call_guest_function_by_name("GetStatic", ReturnType::Int, None) + .unwrap(); + assert!( + matches!(res, ReturnValue::Int(0)), + "Expected 0, got {:?}", + res + ); +} + // checks that a recursive function with stack allocation eventually fails with stackoverflow #[test] fn recursive_stack_allocate_overflow() { diff --git a/src/tests/rust_guests/simpleguest/src/main.rs b/src/tests/rust_guests/simpleguest/src/main.rs index dffa2bd34..a713e7266 100644 --- a/src/tests/rust_guests/simpleguest/src/main.rs +++ b/src/tests/rust_guests/simpleguest/src/main.rs @@ -665,6 +665,16 @@ fn get_static(function_call: &FunctionCall) -> Result> { } } +fn add_to_static_and_fail(_: &FunctionCall) -> Result> { + unsafe { + COUNTER += 10; + }; + Err(HyperlightGuestError::new( + ErrorCode::GuestError, + "Crash on purpose".to_string(), + )) +} + fn violate_seccomp_filters(function_call: &FunctionCall) -> Result> { if function_call.parameters.is_none() { call_host_function("MakeGetpidSyscall", None, ReturnType::ULong)?; @@ -1036,6 +1046,7 @@ pub extern "C" fn hyperlight_main() { add_to_static as i64, ); register_function(add_to_static_def); + let get_static_def = GuestFunctionDefinition::new( "GetStatic".to_string(), Vec::new(), @@ -1044,6 +1055,14 @@ pub extern "C" fn hyperlight_main() { ); register_function(get_static_def); + let add_to_static_and_fail_def = GuestFunctionDefinition::new( + "AddToStaticAndFail".to_string(), + Vec::new(), + ReturnType::Int, + add_to_static_and_fail as i64, + ); + register_function(add_to_static_and_fail_def); + let violate_seccomp_filters_def = GuestFunctionDefinition::new( "ViolateSeccompFilters".to_string(), Vec::new(),