diff --git a/src/hyperlight_host/src/hypervisor/hyperv_linux.rs b/src/hyperlight_host/src/hypervisor/hyperv_linux.rs index 104f75cab..55b6f8444 100644 --- a/src/hyperlight_host/src/hypervisor/hyperv_linux.rs +++ b/src/hyperlight_host/src/hypervisor/hyperv_linux.rs @@ -631,7 +631,7 @@ impl Hypervisor for HypervLinuxDriver { let cancel_requested = self .interrupt_handle .cancel_requested - .swap(false, Ordering::Relaxed); + .load(Ordering::Relaxed); // Note: if a `InterruptHandle::kill()` called while this thread is **here** // Then `cancel_requested` will be set to true again, which will cancel the **next vcpu run**. // Additionally signals will be sent to this thread until `running` is set to false. @@ -722,6 +722,9 @@ impl Hypervisor for HypervLinuxDriver { // If cancellation was not requested for this specific vm, the vcpu was interrupted because of stale signal // that was meant to be delivered to a previous/other vcpu on this same thread, so let's ignore it if cancel_requested { + self.interrupt_handle + .cancel_requested + .store(false, Ordering::Relaxed); HyperlightExit::Cancelled() } else { HyperlightExit::Retry() diff --git a/src/hyperlight_host/src/hypervisor/kvm.rs b/src/hyperlight_host/src/hypervisor/kvm.rs index c7aefebc8..a0f6c07ba 100644 --- a/src/hyperlight_host/src/hypervisor/kvm.rs +++ b/src/hyperlight_host/src/hypervisor/kvm.rs @@ -560,7 +560,7 @@ impl Hypervisor for KVMDriver { let cancel_requested = self .interrupt_handle .cancel_requested - .swap(false, Ordering::Relaxed); + .load(Ordering::Relaxed); // Note: if a `InterruptHandle::kill()` called while this thread is **here** // Then `cancel_requested` will be set to true again, which will cancel the **next vcpu run**. // Additionally signals will be sent to this thread until `running` is set to false. @@ -625,6 +625,9 @@ impl Hypervisor for KVMDriver { // If cancellation was not requested for this specific vm, the vcpu was interrupted because of stale signal // that was meant to be delivered to a previous/other vcpu on this same thread, so let's ignore it if cancel_requested { + self.interrupt_handle + .cancel_requested + .store(false, Ordering::Relaxed); HyperlightExit::Cancelled() } else { HyperlightExit::Retry() diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index edd6a6b51..e4c41da83 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -327,6 +327,43 @@ fn interrupt_custom_signal_no_and_retry_delay() { thread.join().expect("Thread should finish"); } +#[test] +fn interrupt_spamming_host_call() { + let mut uninit = UninitializedSandbox::new( + GuestBinary::FilePath(callback_guest_as_string().unwrap()), + None, + ) + .unwrap(); + + uninit + .register("HostFunc1", || { + // do nothing + }) + .unwrap(); + let mut sbox1: MultiUseSandbox = uninit.evolve(Noop::default()).unwrap(); + + let interrupt_handle = sbox1.interrupt_handle(); + + let barrier = Arc::new(Barrier::new(2)); + let barrier2 = barrier.clone(); + + let thread = thread::spawn(move || { + barrier2.wait(); + thread::sleep(Duration::from_secs(1)); + interrupt_handle.kill(); + }); + + barrier.wait(); + // This guest call calls "HostFunc1" in a loop + let res = sbox1 + .call_guest_function_by_name::("HostCallLoop", "HostFunc1".to_string()) + .unwrap_err(); + + assert!(matches!(res, HyperlightError::ExecutionCanceledByHost())); + + thread.join().expect("Thread should finish"); +} + #[test] fn print_four_args_c_guest() { let path = c_simple_guest_as_string().unwrap(); diff --git a/src/tests/rust_guests/callbackguest/src/main.rs b/src/tests/rust_guests/callbackguest/src/main.rs index 93518b782..0b79d578e 100644 --- a/src/tests/rust_guests/callbackguest/src/main.rs +++ b/src/tests/rust_guests/callbackguest/src/main.rs @@ -157,6 +157,19 @@ fn call_host_spin(_: &FunctionCall) -> Result> { Ok(get_flatbuffer_result(())) } +fn host_call_loop(function_call: &FunctionCall) -> Result> { + if let ParameterValue::String(message) = &function_call.parameters.as_ref().unwrap()[0] { + loop { + call_host_function::<()>(message, None, ReturnType::Void).unwrap(); + } + } else { + Err(HyperlightGuestError::new( + ErrorCode::GuestFunctionParameterTypeMismatch, + "Invalid parameters passed to host_call_loop".to_string(), + )) + } +} + #[no_mangle] pub extern "C" fn hyperlight_main() { let print_output_def = GuestFunctionDefinition::new( @@ -234,6 +247,14 @@ pub extern "C" fn hyperlight_main() { call_host_spin as usize, ); register_function(call_host_spin_def); + + let host_call_loop_def = GuestFunctionDefinition::new( + "HostCallLoop".to_string(), + Vec::from(&[ParameterType::String]), + ReturnType::Void, + host_call_loop as usize, + ); + register_function(host_call_loop_def); } #[no_mangle]