Skip to content

Commit 0a72f81

Browse files
committed
make all guest function calls persistent
Signed-off-by: Jorge Prendes <[email protected]>
1 parent b91ea93 commit 0a72f81

File tree

4 files changed

+73
-73
lines changed

4 files changed

+73
-73
lines changed

src/hyperlight_host/src/func/call_ctx.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -183,27 +183,27 @@ mod tests {
183183
let sbox: MultiUseSandbox = new_uninit().unwrap().evolve(Noop::default()).unwrap();
184184
Self { sandbox: sbox }
185185
}
186-
pub fn call_add_to_static_multiple_times(mut self, i: i32) -> Result<TestSandbox> {
187-
let mut ctx = self.sandbox.new_call_context();
186+
pub fn call_add_to_static_multiple_times(&mut self, i: i32) -> Result<()> {
187+
let snapshot = self.sandbox.snapshot()?;
188188
let mut sum: i32 = 0;
189189
for n in 0..i {
190-
let result = ctx.call::<i32>("AddToStatic", n);
190+
let result = self.sandbox.call_guest_function_by_name::<i32>("AddToStatic", n);
191191
sum += n;
192192
println!("{:?}", result);
193193
let result = result.unwrap();
194194
assert_eq!(result, sum);
195195
}
196-
let result = ctx.finish();
197-
assert!(result.is_ok());
198-
self.sandbox = result.unwrap();
199-
Ok(self)
196+
self.sandbox.restore(&snapshot)?;
197+
Ok(())
200198
}
201199

202-
pub fn call_add_to_static(mut self, i: i32) -> Result<()> {
200+
pub fn call_add_to_static(&mut self, i: i32) -> Result<()> {
201+
let snapshot = self.sandbox.snapshot()?;
203202
for n in 0..i {
204203
let result = self
205204
.sandbox
206205
.call_guest_function_by_name::<i32>("AddToStatic", n);
206+
self.sandbox.restore(&snapshot)?;
207207
println!("{:?}", result);
208208
let result = result.unwrap();
209209
assert_eq!(result, n);
@@ -214,14 +214,14 @@ mod tests {
214214

215215
#[test]
216216
fn ensure_multiusesandbox_multi_calls_dont_reset_state() {
217-
let sandbox = TestSandbox::new();
217+
let mut sandbox = TestSandbox::new();
218218
let result = sandbox.call_add_to_static_multiple_times(5);
219219
assert!(result.is_ok());
220220
}
221221

222222
#[test]
223223
fn ensure_multiusesandbox_single_calls_do_reset_state() {
224-
let sandbox = TestSandbox::new();
224+
let mut sandbox = TestSandbox::new();
225225
let result = sandbox.call_add_to_static(5);
226226
assert!(result.is_ok());
227227
}

src/hyperlight_host/src/mem/mgr.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,4 +561,25 @@ impl SandboxMemoryManager<HostSharedMemory> {
561561
self.layout.sandbox_memory_config.get_output_data_size(),
562562
)
563563
}
564+
565+
pub(crate) fn clear_io_buffers(&mut self) {
566+
// Clear the output data buffer
567+
loop {
568+
let Ok(_) = self.shared_mem.try_pop_buffer_into::<Vec<u8>>(
569+
self.layout.output_data_buffer_offset,
570+
self.layout.sandbox_memory_config.get_output_data_size(),
571+
) else {
572+
break;
573+
};
574+
}
575+
// Clear the input data buffer
576+
loop {
577+
let Ok(_) = self.shared_mem.try_pop_buffer_into::<Vec<u8>>(
578+
self.layout.input_data_buffer_offset,
579+
self.layout.sandbox_memory_config.get_input_data_size(),
580+
) else {
581+
break;
582+
};
583+
}
584+
}
564585
}

src/hyperlight_host/src/sandbox/initialized_multi_use.rs

Lines changed: 42 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -180,28 +180,10 @@ impl MultiUseSandbox {
180180
Ok(())
181181
}
182182

183-
/// Call a guest function by name, with the given return type and arguments.
184-
#[instrument(err(Debug), skip(self, args), parent = Span::current())]
185-
pub fn call_guest_function_by_name<Output: SupportedReturnType>(
186-
&mut self,
187-
func_name: &str,
188-
args: impl ParameterTuple,
189-
) -> Result<Output> {
190-
maybe_time_and_emit_guest_call(func_name, || {
191-
let ret = self.call_guest_function_by_name_no_reset(
192-
func_name,
193-
Output::TYPE,
194-
args.into_value(),
195-
);
196-
self.restore_state()?;
197-
Output::from_value(ret?)
198-
})
199-
}
200-
201183
/// Call a guest function by name, with the given return type and arguments.
202184
/// The changes made to the sandbox are persisted
203185
#[instrument(err(Debug), skip(self, args), parent = Span::current())]
204-
pub fn persist_call_guest_function_by_name<Output: SupportedReturnType>(
186+
pub fn call_guest_function_by_name<Output: SupportedReturnType + std::fmt::Debug>(
205187
&mut self,
206188
func_name: &str,
207189
args: impl ParameterTuple,
@@ -213,7 +195,7 @@ impl MultiUseSandbox {
213195
args.into_value(),
214196
);
215197
let ret = Output::from_value(ret?);
216-
self.mem_mgr.unwrap_mgr_mut().push_state()?;
198+
self.mem_mgr.unwrap_mgr_mut().push_state().unwrap();
217199
ret
218200
})
219201
}
@@ -298,7 +280,7 @@ impl MultiUseSandbox {
298280
) -> Result<ReturnValue> {
299281
maybe_time_and_emit_guest_call(func_name, || {
300282
let ret = self.call_guest_function_by_name_no_reset(func_name, ret_type, args);
301-
self.restore_state()?;
283+
self.mem_mgr.unwrap_mgr_mut().push_state()?;
302284
ret
303285
})
304286
}
@@ -318,35 +300,43 @@ impl MultiUseSandbox {
318300
return_type: ReturnType,
319301
args: Vec<ParameterValue>,
320302
) -> Result<ReturnValue> {
321-
let fc = FunctionCall::new(
322-
function_name.to_string(),
323-
Some(args),
324-
FunctionCallType::Guest,
325-
return_type,
326-
);
327-
328-
let buffer: Vec<u8> = fc
329-
.try_into()
330-
.map_err(|_| HyperlightError::Error("Failed to serialize FunctionCall".to_string()))?;
331-
332-
self.get_mgr_wrapper_mut()
333-
.as_mut()
334-
.write_guest_function_call(&buffer)?;
335-
336-
self.vm.dispatch_call_from_host(
337-
self.dispatch_ptr.clone(),
338-
self.out_hdl.clone(),
339-
self.mem_hdl.clone(),
340-
#[cfg(gdb)]
341-
self.dbg_mem_access_fn.clone(),
342-
)?;
303+
let res = (|| {
304+
let fc = FunctionCall::new(
305+
function_name.to_string(),
306+
Some(args),
307+
FunctionCallType::Guest,
308+
return_type,
309+
);
310+
311+
let buffer: Vec<u8> = fc.try_into().map_err(|_| {
312+
HyperlightError::Error("Failed to serialize FunctionCall".to_string())
313+
})?;
343314

344-
self.check_stack_guard()?;
345-
check_for_guest_error(self.get_mgr_wrapper_mut())?;
315+
self.get_mgr_wrapper_mut()
316+
.as_mut()
317+
.write_guest_function_call(&buffer)?;
346318

347-
self.get_mgr_wrapper_mut()
348-
.as_mut()
349-
.get_guest_function_call_result()
319+
self.vm.dispatch_call_from_host(
320+
self.dispatch_ptr.clone(),
321+
self.out_hdl.clone(),
322+
self.mem_hdl.clone(),
323+
#[cfg(gdb)]
324+
self.dbg_mem_access_fn.clone(),
325+
)?;
326+
327+
self.check_stack_guard()?;
328+
check_for_guest_error(self.get_mgr_wrapper_mut())?;
329+
330+
self
331+
.get_mgr_wrapper_mut()
332+
.as_mut()
333+
.get_guest_function_call_result()
334+
})();
335+
336+
// TODO: Do we want to allow re-entrant guest function calls?
337+
self.get_mgr_wrapper_mut().as_mut().clear_io_buffers();
338+
339+
res
350340
}
351341

352342
/// Get a handle to the interrupt handler for this sandbox,
@@ -447,13 +437,14 @@ mod tests {
447437

448438
let snapshot = sbox.snapshot().unwrap();
449439

450-
let _ = sbox.persist_call_guest_function_by_name::<i32>("AddToStatic", 5i32).unwrap();
440+
let _ = sbox.call_guest_function_by_name::<i32>("AddToStatic", 5i32)
441+
.unwrap();
451442

452-
let res: i32 = sbox.persist_call_guest_function_by_name("GetStatic", ()).unwrap();
443+
let res: i32 = sbox.call_guest_function_by_name("GetStatic", ()).unwrap();
453444
assert_eq!(res, 5);
454445

455446
sbox.restore(&snapshot).unwrap();
456-
let res: i32 = sbox.persist_call_guest_function_by_name("GetStatic", ()).unwrap();
447+
let res: i32 = sbox.call_guest_function_by_name("GetStatic", ()).unwrap();
457448
assert_eq!(res, 0);
458449
}
459450

src/hyperlight_host/tests/integration_test.rs

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -688,18 +688,6 @@ fn execute_on_heap() {
688688
}
689689
}
690690

691-
#[test]
692-
fn memory_resets_after_failed_guestcall() {
693-
let mut sbox1 = new_uninit_rust().unwrap().evolve(Noop::default()).unwrap();
694-
sbox1
695-
.call_guest_function_by_name::<String>("AddToStaticAndFail", ())
696-
.unwrap_err();
697-
let res = sbox1
698-
.call_guest_function_by_name::<i32>("GetStatic", ())
699-
.unwrap();
700-
assert_eq!(res, 0, "Expected 0, got {:?}", res);
701-
}
702-
703691
// checks that a recursive function with stack allocation eventually fails with stackoverflow
704692
#[test]
705693
fn recursive_stack_allocate_overflow() {

0 commit comments

Comments
 (0)