Skip to content

Commit 2dc1401

Browse files
committed
Reset guest memory when guest function fails
1 parent 2154923 commit 2dc1401

File tree

3 files changed

+37
-2
lines changed

3 files changed

+37
-2
lines changed

src/hyperlight_host/src/sandbox/initialized_multi_use.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,9 @@ impl MultiUseSandbox {
163163
func_ret_type: ReturnType,
164164
args: Option<Vec<ParameterValue>>,
165165
) -> Result<ReturnValue> {
166-
let res = call_function_on_guest(self, func_name, func_ret_type, args)?;
166+
let res = call_function_on_guest(self, func_name, func_ret_type, args);
167167
self.restore_state()?;
168-
Ok(res)
168+
res
169169
}
170170

171171
/// Restore the Sandbox's state

src/hyperlight_host/tests/integration_test.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,22 @@ fn execute_on_heap() {
429429
}
430430
}
431431

432+
#[test]
433+
fn memory_resets_after_failed_guestcall() {
434+
let mut sbox1 = new_uninit_rust().unwrap().evolve(Noop::default()).unwrap();
435+
sbox1
436+
.call_guest_function_by_name("AddToStaticAndFail", ReturnType::String, None)
437+
.unwrap_err();
438+
let res = sbox1
439+
.call_guest_function_by_name("GetStatic", ReturnType::Int, None)
440+
.unwrap();
441+
assert!(
442+
matches!(res, ReturnValue::Int(0)),
443+
"Expected 0, got {:?}",
444+
res
445+
);
446+
}
447+
432448
// checks that a recursive function with stack allocation eventually fails with stackoverflow
433449
#[test]
434450
fn recursive_stack_allocate_overflow() {

src/tests/rust_guests/simpleguest/src/main.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,16 @@ fn get_static(function_call: &FunctionCall) -> Result<Vec<u8>> {
665665
}
666666
}
667667

668+
fn add_to_static_and_fail(_: &FunctionCall) -> Result<Vec<u8>> {
669+
unsafe {
670+
COUNTER += 10;
671+
};
672+
Err(HyperlightGuestError::new(
673+
ErrorCode::GuestError,
674+
format!("Crash on purpose"),
675+
))
676+
}
677+
668678
fn violate_seccomp_filters(function_call: &FunctionCall) -> Result<Vec<u8>> {
669679
if function_call.parameters.is_none() {
670680
call_host_function("MakeGetpidSyscall", None, ReturnType::ULong)?;
@@ -1036,6 +1046,7 @@ pub extern "C" fn hyperlight_main() {
10361046
add_to_static as i64,
10371047
);
10381048
register_function(add_to_static_def);
1049+
10391050
let get_static_def = GuestFunctionDefinition::new(
10401051
"GetStatic".to_string(),
10411052
Vec::new(),
@@ -1044,6 +1055,14 @@ pub extern "C" fn hyperlight_main() {
10441055
);
10451056
register_function(get_static_def);
10461057

1058+
let add_to_static_and_fail_def = GuestFunctionDefinition::new(
1059+
"AddToStaticAndFail".to_string(),
1060+
Vec::new(),
1061+
ReturnType::Int,
1062+
add_to_static_and_fail as i64,
1063+
);
1064+
register_function(add_to_static_and_fail_def);
1065+
10471066
let violate_seccomp_filters_def = GuestFunctionDefinition::new(
10481067
"ViolateSeccompFilters".to_string(),
10491068
Vec::new(),

0 commit comments

Comments
 (0)